2222
2323# pylint: disable=too-few-public-methods
2424
25+ from collections .abc import Sequence
2526from enum import StrEnum
2627
27- from xdsl .ir import EnumAttribute , SpacedOpaqueSyntaxAttribute
28+ from xdsl .dialects .builtin import I64 , ArrayAttr , IntegerAttr , i64
29+ from xdsl .ir import Attribute , EnumAttribute , ParametrizedAttribute , SpacedOpaqueSyntaxAttribute
2830from xdsl .irdl import irdl_attr_definition
31+ from xdsl .parser import AttrParser
32+ from xdsl .printer import Printer
33+
34+
35+ # Utility functions for dimension array parsing/printing
36+ def parse_dims (parser : AttrParser ) -> ArrayAttr [IntegerAttr [I64 ]]:
37+ """Parse dimension array in [1, 2, 3] format"""
38+ value = parser .parse_comma_separated_list (
39+ AttrParser .Delimiter .SQUARE ,
40+ lambda : IntegerAttr (parser .parse_integer (), i64 ),
41+ )
42+ return ArrayAttr (value )
43+
44+
45+ def print_dims (printer : Printer , dims : ArrayAttr [IntegerAttr [I64 ]]):
46+ """Print dimension array in [1, 2, 3] format"""
47+ printer .print_string ("[" )
48+ printer .print_list (
49+ dims .data ,
50+ lambda dim : printer .print_string (f"{ dim .value .data } " ),
51+ )
52+ printer .print_string ("]" )
2953
3054
3155class ResultAccuracyMode (StrEnum ):
@@ -47,3 +71,219 @@ class ResultAccuracyModeAttr(EnumAttribute[ResultAccuracyMode], SpacedOpaqueSynt
4771 """
4872
4973 name = "stablehlo.result_accuracy_mode"
74+
75+
76+ @irdl_attr_definition
77+ class GatherDimensionNumbers (ParametrizedAttribute ):
78+ """
79+ XLA gather dimension numbers.
80+
81+ This attribute models the dimension information for gather operations.
82+ See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L42).
83+ """
84+
85+ name = "stablehlo.gather"
86+
87+ offset_dims : ArrayAttr [IntegerAttr [I64 ]]
88+ collapsed_slice_dims : ArrayAttr [IntegerAttr [I64 ]]
89+ operand_batching_dims : ArrayAttr [IntegerAttr [I64 ]]
90+ start_indices_batching_dims : ArrayAttr [IntegerAttr [I64 ]]
91+ start_index_map : ArrayAttr [IntegerAttr [I64 ]]
92+ index_vector_dim : IntegerAttr [I64 ]
93+
94+ def print_parameters (self , printer : Printer ) -> None :
95+ """Print gather dimension numbers in structured format"""
96+ with printer .in_angle_brackets ():
97+ with printer .indented ():
98+ # Print offset_dims
99+ printer .print_string ("\n offset_dims = " )
100+ print_dims (printer , self .offset_dims )
101+ printer .print_string ("," )
102+
103+ # Print collapsed_slice_dims
104+ printer .print_string ("\n collapsed_slice_dims = " )
105+ print_dims (printer , self .collapsed_slice_dims )
106+ printer .print_string ("," )
107+
108+ # Print operand_batching_dims
109+ printer .print_string ("\n operand_batching_dims = " )
110+ print_dims (printer , self .operand_batching_dims )
111+ printer .print_string ("," )
112+
113+ # Print start_indices_batching_dims
114+ printer .print_string ("\n start_indices_batching_dims = " )
115+ print_dims (printer , self .start_indices_batching_dims )
116+ printer .print_string ("," )
117+
118+ # Print start_index_map
119+ printer .print_string ("\n start_index_map = " )
120+ print_dims (printer , self .start_index_map )
121+ printer .print_string ("," )
122+
123+ # Print index_vector_dim
124+ printer .print_string (f"\n index_vector_dim = { self .index_vector_dim .value .data } " )
125+ printer .print_string ("\n " )
126+
127+ @classmethod
128+ def parse_parameters (cls , parser : AttrParser ) -> Sequence [Attribute ]:
129+ """Parse gather dimension numbers from structured format"""
130+ with parser .in_angle_brackets ():
131+ # Initialize default values for all fields
132+ offset_dims = ArrayAttr ([])
133+ collapsed_slice_dims = ArrayAttr ([])
134+ operand_batching_dims = ArrayAttr ([])
135+ start_indices_batching_dims = ArrayAttr ([])
136+ start_index_map = ArrayAttr ([])
137+ index_vector_dim = IntegerAttr (0 , i64 )
138+
139+ # Try to parse offset_dims
140+ if parser .parse_optional_characters ("offset_dims" ) is not None :
141+ parser .parse_punctuation ("=" )
142+ offset_dims = parse_dims (parser )
143+ parser .parse_optional_punctuation ("," )
144+
145+ # Try to parse collapsed_slice_dims
146+ if parser .parse_optional_characters ("collapsed_slice_dims" ) is not None :
147+ parser .parse_punctuation ("=" )
148+ collapsed_slice_dims = parse_dims (parser )
149+ parser .parse_optional_punctuation ("," )
150+
151+ # Try to parse operand_batching_dims
152+ if parser .parse_optional_characters ("operand_batching_dims" ) is not None :
153+ parser .parse_punctuation ("=" )
154+ operand_batching_dims = parse_dims (parser )
155+ parser .parse_optional_punctuation ("," )
156+
157+ # Try to parse start_indices_batching_dims
158+ if parser .parse_optional_characters ("start_indices_batching_dims" ) is not None :
159+ parser .parse_punctuation ("=" )
160+ start_indices_batching_dims = parse_dims (parser )
161+ parser .parse_optional_punctuation ("," )
162+
163+ # Try to parse start_index_map
164+ if parser .parse_optional_characters ("start_index_map" ) is not None :
165+ parser .parse_punctuation ("=" )
166+ start_index_map = parse_dims (parser )
167+ parser .parse_optional_punctuation ("," )
168+
169+ # Try to parse index_vector_dim
170+ if parser .parse_optional_characters ("index_vector_dim" ) is not None :
171+ parser .parse_punctuation ("=" )
172+ index_vector_dim = IntegerAttr (parser .parse_integer (), i64 )
173+
174+ return (
175+ offset_dims ,
176+ collapsed_slice_dims ,
177+ operand_batching_dims ,
178+ start_indices_batching_dims ,
179+ start_index_map ,
180+ index_vector_dim ,
181+ )
182+
183+
184+ @irdl_attr_definition
185+ class ScatterDimensionNumbers (ParametrizedAttribute ):
186+ """
187+ XLA scatter dimension numbers.
188+
189+ This attribute models the dimension information for scatter operations.
190+ See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L28).
191+ """
192+
193+ name = "stablehlo.scatter"
194+
195+ update_window_dims : ArrayAttr [IntegerAttr [I64 ]]
196+ inserted_window_dims : ArrayAttr [IntegerAttr [I64 ]]
197+ input_batching_dims : ArrayAttr [IntegerAttr [I64 ]]
198+ scatter_indices_batching_dims : ArrayAttr [IntegerAttr [I64 ]]
199+ scatter_dims_to_operand_dims : ArrayAttr [IntegerAttr [I64 ]]
200+ index_vector_dim : IntegerAttr [I64 ]
201+
202+ def print_parameters (self , printer : Printer ) -> None :
203+ """Print scatter dimension numbers in structured format"""
204+ with printer .in_angle_brackets ():
205+ with printer .indented ():
206+ # Print update_window_dims
207+ printer .print_string ("\n update_window_dims = " )
208+ print_dims (printer , self .update_window_dims )
209+ printer .print_string ("," )
210+
211+ # Print inserted_window_dims
212+ printer .print_string ("\n inserted_window_dims = " )
213+ print_dims (printer , self .inserted_window_dims )
214+ printer .print_string ("," )
215+
216+ # Print input_batching_dims
217+ printer .print_string ("\n input_batching_dims = " )
218+ print_dims (printer , self .input_batching_dims )
219+ printer .print_string ("," )
220+
221+ # Print scatter_indices_batching_dims
222+ printer .print_string ("\n scatter_indices_batching_dims = " )
223+ print_dims (printer , self .scatter_indices_batching_dims )
224+ printer .print_string ("," )
225+
226+ # Print scatter_dims_to_operand_dims
227+ printer .print_string ("\n scatter_dims_to_operand_dims = " )
228+ print_dims (printer , self .scatter_dims_to_operand_dims )
229+ printer .print_string ("," )
230+
231+ # Print index_vector_dim
232+ printer .print_string (f"\n index_vector_dim = { self .index_vector_dim .value .data } " )
233+ printer .print_string ("\n " )
234+
235+ @classmethod
236+ def parse_parameters (cls , parser : AttrParser ) -> Sequence [Attribute ]:
237+ """Parse scatter dimension numbers from structured format"""
238+ with parser .in_angle_brackets ():
239+ # Initialize default values for all fields
240+ update_window_dims = ArrayAttr ([])
241+ inserted_window_dims = ArrayAttr ([])
242+ input_batching_dims = ArrayAttr ([])
243+ scatter_indices_batching_dims = ArrayAttr ([])
244+ scatter_dims_to_operand_dims = ArrayAttr ([])
245+ index_vector_dim = IntegerAttr (0 , i64 )
246+
247+ # Try to parse update_window_dims
248+ if parser .parse_optional_characters ("update_window_dims" ) is not None :
249+ parser .parse_punctuation ("=" )
250+ update_window_dims = parse_dims (parser )
251+ parser .parse_optional_punctuation ("," )
252+
253+ # Try to parse inserted_window_dims
254+ if parser .parse_optional_characters ("inserted_window_dims" ) is not None :
255+ parser .parse_punctuation ("=" )
256+ inserted_window_dims = parse_dims (parser )
257+ parser .parse_optional_punctuation ("," )
258+
259+ # Try to parse input_batching_dims
260+ if parser .parse_optional_characters ("input_batching_dims" ) is not None :
261+ parser .parse_punctuation ("=" )
262+ input_batching_dims = parse_dims (parser )
263+ parser .parse_optional_punctuation ("," )
264+
265+ # Try to parse scatter_indices_batching_dims
266+ if parser .parse_optional_characters ("scatter_indices_batching_dims" ) is not None :
267+ parser .parse_punctuation ("=" )
268+ scatter_indices_batching_dims = parse_dims (parser )
269+ parser .parse_optional_punctuation ("," )
270+
271+ # Try to parse scatter_dims_to_operand_dims
272+ if parser .parse_optional_characters ("scatter_dims_to_operand_dims" ) is not None :
273+ parser .parse_punctuation ("=" )
274+ scatter_dims_to_operand_dims = parse_dims (parser )
275+ parser .parse_optional_punctuation ("," )
276+
277+ # Try to parse index_vector_dim
278+ if parser .parse_optional_characters ("index_vector_dim" ) is not None :
279+ parser .parse_punctuation ("=" )
280+ index_vector_dim = IntegerAttr (parser .parse_integer (), i64 )
281+
282+ return (
283+ update_window_dims ,
284+ inserted_window_dims ,
285+ input_batching_dims ,
286+ scatter_indices_batching_dims ,
287+ scatter_dims_to_operand_dims ,
288+ index_vector_dim ,
289+ )
0 commit comments