Skip to content

Commit 104b5e8

Browse files
mehrdad2mmudit2812
andauthored
Add data movement operations to StableHLO dialect (#8217)
**Context:** Adding some of data movement operations to StableHLO dialect. **Description of the Change:** Operations: - [x] Add BroadcastInDimOp - [x] Add ConcatenateOp - [x] Add GatherOp - [x] Add ReshapeOp - [x] Add ScatterOp - [x] Add SliceOp traits: - [x] Add SliceArraysSameSizeTrait Attributes: - [x] ScatterDimensionNumbers - [x] GatherDimensionNumbers **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-98427] --------- Co-authored-by: Mudit Pandey <18223836+mudit2812@users.noreply.github.com>
1 parent c2f3714 commit 104b5e8

File tree

11 files changed

+1214
-15
lines changed

11 files changed

+1214
-15
lines changed

‎pennylane/compiler/python_compiler/dialects/stablehlo/__init__.py‎

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@
6666
OptimizationBarrierOp,
6767
)
6868

69+
from .data_movement import (
70+
BroadcastInDimOp,
71+
ConcatenateOp,
72+
DynamicSliceOp,
73+
GatherOp,
74+
ReshapeOp,
75+
ScatterOp,
76+
SliceOp,
77+
)
78+
79+
from .attributes import (
80+
GatherDimensionNumbers,
81+
ResultAccuracyModeAttr,
82+
ScatterDimensionNumbers,
83+
)
84+
6985
# Import the main StableHLO dialect
7086
from .dialect import StableHLO
7187

@@ -111,4 +127,16 @@
111127
"IfOp",
112128
"WhileOp",
113129
"OptimizationBarrierOp",
130+
# Data movement operations
131+
"BroadcastInDimOp",
132+
"ConcatenateOp",
133+
"DynamicSliceOp",
134+
"GatherOp",
135+
"ReshapeOp",
136+
"ScatterOp",
137+
"SliceOp",
138+
# Attributes
139+
"GatherDimensionNumbers",
140+
"ResultAccuracyModeAttr",
141+
"ScatterDimensionNumbers",
114142
]

‎pennylane/compiler/python_compiler/dialects/stablehlo/attributes.py‎

Lines changed: 241 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,34 @@
2222

2323
# pylint: disable=too-few-public-methods
2424

25+
from collections.abc import Sequence
2526
from enum import StrEnum
2627

27-
from xdsl.ir import EnumAttribute, SpacedOpaqueSyntaxAttribute
28+
from xdsl.dialects.builtin import I64, ArrayAttr, IntegerAttr, i64
29+
from xdsl.ir import Attribute, EnumAttribute, ParametrizedAttribute, SpacedOpaqueSyntaxAttribute
2830
from xdsl.irdl import irdl_attr_definition
31+
from xdsl.parser import AttrParser
32+
from xdsl.printer import Printer
33+
34+
35+
# Utility functions for dimension array parsing/printing
36+
def parse_dims(parser: AttrParser) -> ArrayAttr[IntegerAttr[I64]]:
37+
"""Parse dimension array in [1, 2, 3] format"""
38+
value = parser.parse_comma_separated_list(
39+
AttrParser.Delimiter.SQUARE,
40+
lambda: IntegerAttr(parser.parse_integer(), i64),
41+
)
42+
return ArrayAttr(value)
43+
44+
45+
def print_dims(printer: Printer, dims: ArrayAttr[IntegerAttr[I64]]):
46+
"""Print dimension array in [1, 2, 3] format"""
47+
printer.print_string("[")
48+
printer.print_list(
49+
dims.data,
50+
lambda dim: printer.print_string(f"{dim.value.data}"),
51+
)
52+
printer.print_string("]")
2953

3054

3155
class ResultAccuracyMode(StrEnum):
@@ -47,3 +71,219 @@ class ResultAccuracyModeAttr(EnumAttribute[ResultAccuracyMode], SpacedOpaqueSynt
4771
"""
4872

4973
name = "stablehlo.result_accuracy_mode"
74+
75+
76+
@irdl_attr_definition
77+
class GatherDimensionNumbers(ParametrizedAttribute):
78+
"""
79+
XLA gather dimension numbers.
80+
81+
This attribute models the dimension information for gather operations.
82+
See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L42).
83+
"""
84+
85+
name = "stablehlo.gather"
86+
87+
offset_dims: ArrayAttr[IntegerAttr[I64]]
88+
collapsed_slice_dims: ArrayAttr[IntegerAttr[I64]]
89+
operand_batching_dims: ArrayAttr[IntegerAttr[I64]]
90+
start_indices_batching_dims: ArrayAttr[IntegerAttr[I64]]
91+
start_index_map: ArrayAttr[IntegerAttr[I64]]
92+
index_vector_dim: IntegerAttr[I64]
93+
94+
def print_parameters(self, printer: Printer) -> None:
95+
"""Print gather dimension numbers in structured format"""
96+
with printer.in_angle_brackets():
97+
with printer.indented():
98+
# Print offset_dims
99+
printer.print_string("\noffset_dims = ")
100+
print_dims(printer, self.offset_dims)
101+
printer.print_string(",")
102+
103+
# Print collapsed_slice_dims
104+
printer.print_string("\ncollapsed_slice_dims = ")
105+
print_dims(printer, self.collapsed_slice_dims)
106+
printer.print_string(",")
107+
108+
# Print operand_batching_dims
109+
printer.print_string("\noperand_batching_dims = ")
110+
print_dims(printer, self.operand_batching_dims)
111+
printer.print_string(",")
112+
113+
# Print start_indices_batching_dims
114+
printer.print_string("\nstart_indices_batching_dims = ")
115+
print_dims(printer, self.start_indices_batching_dims)
116+
printer.print_string(",")
117+
118+
# Print start_index_map
119+
printer.print_string("\nstart_index_map = ")
120+
print_dims(printer, self.start_index_map)
121+
printer.print_string(",")
122+
123+
# Print index_vector_dim
124+
printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}")
125+
printer.print_string("\n")
126+
127+
@classmethod
128+
def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
129+
"""Parse gather dimension numbers from structured format"""
130+
with parser.in_angle_brackets():
131+
# Initialize default values for all fields
132+
offset_dims = ArrayAttr([])
133+
collapsed_slice_dims = ArrayAttr([])
134+
operand_batching_dims = ArrayAttr([])
135+
start_indices_batching_dims = ArrayAttr([])
136+
start_index_map = ArrayAttr([])
137+
index_vector_dim = IntegerAttr(0, i64)
138+
139+
# Try to parse offset_dims
140+
if parser.parse_optional_characters("offset_dims") is not None:
141+
parser.parse_punctuation("=")
142+
offset_dims = parse_dims(parser)
143+
parser.parse_optional_punctuation(",")
144+
145+
# Try to parse collapsed_slice_dims
146+
if parser.parse_optional_characters("collapsed_slice_dims") is not None:
147+
parser.parse_punctuation("=")
148+
collapsed_slice_dims = parse_dims(parser)
149+
parser.parse_optional_punctuation(",")
150+
151+
# Try to parse operand_batching_dims
152+
if parser.parse_optional_characters("operand_batching_dims") is not None:
153+
parser.parse_punctuation("=")
154+
operand_batching_dims = parse_dims(parser)
155+
parser.parse_optional_punctuation(",")
156+
157+
# Try to parse start_indices_batching_dims
158+
if parser.parse_optional_characters("start_indices_batching_dims") is not None:
159+
parser.parse_punctuation("=")
160+
start_indices_batching_dims = parse_dims(parser)
161+
parser.parse_optional_punctuation(",")
162+
163+
# Try to parse start_index_map
164+
if parser.parse_optional_characters("start_index_map") is not None:
165+
parser.parse_punctuation("=")
166+
start_index_map = parse_dims(parser)
167+
parser.parse_optional_punctuation(",")
168+
169+
# Try to parse index_vector_dim
170+
if parser.parse_optional_characters("index_vector_dim") is not None:
171+
parser.parse_punctuation("=")
172+
index_vector_dim = IntegerAttr(parser.parse_integer(), i64)
173+
174+
return (
175+
offset_dims,
176+
collapsed_slice_dims,
177+
operand_batching_dims,
178+
start_indices_batching_dims,
179+
start_index_map,
180+
index_vector_dim,
181+
)
182+
183+
184+
@irdl_attr_definition
185+
class ScatterDimensionNumbers(ParametrizedAttribute):
186+
"""
187+
XLA scatter dimension numbers.
188+
189+
This attribute models the dimension information for scatter operations.
190+
See external [documentation](https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L28).
191+
"""
192+
193+
name = "stablehlo.scatter"
194+
195+
update_window_dims: ArrayAttr[IntegerAttr[I64]]
196+
inserted_window_dims: ArrayAttr[IntegerAttr[I64]]
197+
input_batching_dims: ArrayAttr[IntegerAttr[I64]]
198+
scatter_indices_batching_dims: ArrayAttr[IntegerAttr[I64]]
199+
scatter_dims_to_operand_dims: ArrayAttr[IntegerAttr[I64]]
200+
index_vector_dim: IntegerAttr[I64]
201+
202+
def print_parameters(self, printer: Printer) -> None:
203+
"""Print scatter dimension numbers in structured format"""
204+
with printer.in_angle_brackets():
205+
with printer.indented():
206+
# Print update_window_dims
207+
printer.print_string("\nupdate_window_dims = ")
208+
print_dims(printer, self.update_window_dims)
209+
printer.print_string(",")
210+
211+
# Print inserted_window_dims
212+
printer.print_string("\ninserted_window_dims = ")
213+
print_dims(printer, self.inserted_window_dims)
214+
printer.print_string(",")
215+
216+
# Print input_batching_dims
217+
printer.print_string("\ninput_batching_dims = ")
218+
print_dims(printer, self.input_batching_dims)
219+
printer.print_string(",")
220+
221+
# Print scatter_indices_batching_dims
222+
printer.print_string("\nscatter_indices_batching_dims = ")
223+
print_dims(printer, self.scatter_indices_batching_dims)
224+
printer.print_string(",")
225+
226+
# Print scatter_dims_to_operand_dims
227+
printer.print_string("\nscatter_dims_to_operand_dims = ")
228+
print_dims(printer, self.scatter_dims_to_operand_dims)
229+
printer.print_string(",")
230+
231+
# Print index_vector_dim
232+
printer.print_string(f"\nindex_vector_dim = {self.index_vector_dim.value.data}")
233+
printer.print_string("\n")
234+
235+
@classmethod
236+
def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
237+
"""Parse scatter dimension numbers from structured format"""
238+
with parser.in_angle_brackets():
239+
# Initialize default values for all fields
240+
update_window_dims = ArrayAttr([])
241+
inserted_window_dims = ArrayAttr([])
242+
input_batching_dims = ArrayAttr([])
243+
scatter_indices_batching_dims = ArrayAttr([])
244+
scatter_dims_to_operand_dims = ArrayAttr([])
245+
index_vector_dim = IntegerAttr(0, i64)
246+
247+
# Try to parse update_window_dims
248+
if parser.parse_optional_characters("update_window_dims") is not None:
249+
parser.parse_punctuation("=")
250+
update_window_dims = parse_dims(parser)
251+
parser.parse_optional_punctuation(",")
252+
253+
# Try to parse inserted_window_dims
254+
if parser.parse_optional_characters("inserted_window_dims") is not None:
255+
parser.parse_punctuation("=")
256+
inserted_window_dims = parse_dims(parser)
257+
parser.parse_optional_punctuation(",")
258+
259+
# Try to parse input_batching_dims
260+
if parser.parse_optional_characters("input_batching_dims") is not None:
261+
parser.parse_punctuation("=")
262+
input_batching_dims = parse_dims(parser)
263+
parser.parse_optional_punctuation(",")
264+
265+
# Try to parse scatter_indices_batching_dims
266+
if parser.parse_optional_characters("scatter_indices_batching_dims") is not None:
267+
parser.parse_punctuation("=")
268+
scatter_indices_batching_dims = parse_dims(parser)
269+
parser.parse_optional_punctuation(",")
270+
271+
# Try to parse scatter_dims_to_operand_dims
272+
if parser.parse_optional_characters("scatter_dims_to_operand_dims") is not None:
273+
parser.parse_punctuation("=")
274+
scatter_dims_to_operand_dims = parse_dims(parser)
275+
parser.parse_optional_punctuation(",")
276+
277+
# Try to parse index_vector_dim
278+
if parser.parse_optional_characters("index_vector_dim") is not None:
279+
parser.parse_punctuation("=")
280+
index_vector_dim = IntegerAttr(parser.parse_integer(), i64)
281+
282+
return (
283+
update_window_dims,
284+
inserted_window_dims,
285+
input_batching_dims,
286+
scatter_indices_batching_dims,
287+
scatter_dims_to_operand_dims,
288+
index_vector_dim,
289+
)

‎pennylane/compiler/python_compiler/dialects/stablehlo/control_flow.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class IfOp(IRDLOperation):
6767

6868
pred = operand_def(HLO_PredTensor)
6969

70-
results = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
70+
res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
7171

7272
true_branch = region_def("single_block")
7373

@@ -110,7 +110,7 @@ class WhileOp(IRDLOperation):
110110

111111
operand = var_operand_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
112112

113-
results = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
113+
res = var_result_def(HLO_TensorOrPerAxisQuantizedTensorOrToken)
114114

115115
cond = region_def("single_block")
116116

@@ -146,7 +146,7 @@ class OptimizationBarrierOp(IRDLOperation):
146146

147147
operand = var_operand_def(HLO_TensorOrToken)
148148

149-
results = var_result_def(HLO_TensorOrToken)
149+
res = var_result_def(HLO_TensorOrToken)
150150

151151
traits = traits_def(
152152
Pure(),

0 commit comments

Comments
 (0)