[Mlir-commits] [mlir] [mlir][python] implement GenericOp bindings (PR #124496)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 26 17:03:30 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/124496.diff
2 Files Affected:
- (modified) mlir/python/mlir/dialects/linalg/__init__.py (+44)
- (modified) mlir/test/python/dialects/linalg/ops.py (+60)
``````````diff
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..946094e2e9f691 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -10,6 +10,7 @@
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
from .._linalg_ops_gen import *
from .._linalg_enum_gen import *
+from .._linalg_enum_gen import _iteratortypeenum
# These are the ground truth functions defined as:
# ```
@@ -58,6 +59,7 @@
from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from ...extras.meta import region_op
def transpose(
@@ -102,3 +104,45 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op
+
+
+@register_attribute_builder("IteratorTypeArrayAttr")
+def _IteratorTypeArrayAttr(x, context):
+    """Build an ``IteratorTypeArrayAttr`` from a sequence of iterator types.
+
+    Each element of ``x`` is converted with ``_iteratortypeenum`` in the given
+    ``context`` and the results are wrapped in an ``ArrayAttr``. Registering
+    this builder lets ``iterator_types`` be passed as a plain Python list
+    (e.g. ``[linalg.IteratorType.parallel, ...]``).
+    """
+    return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
+
+
+# NOTE(review): this subclass rebinds the module-level name ``GenericOp`` and
+# drops the generated constructor's leading ``result_types`` parameter, so
+# existing callers using the generated signature will get a TypeError —
+# consider exposing this under a distinct name instead.
+class GenericOp(GenericOp):
+    """Convenience builder for ``linalg.generic``.
+
+    Compared with the generated base class, result types are derived from the
+    ``outputs`` operands, and the body block is created up front with one
+    argument per input/output element type, so the op can be driven by
+    ``region_op`` (see ``generic``).
+    """
+
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        indexing_maps,
+        iterator_types,
+        *,
+        doc=None,
+        library_call=None,
+        loc=None,
+        ip=None,
+    ):
+        """Create the op and append its (empty) body block.
+
+        Args:
+          inputs: input operands; Values or anything accepted by
+            ``_get_op_result_or_value``.
+          outputs: init operands, same accepted forms as ``inputs``. Each
+            ranked-tensor output contributes one result of the same type;
+            other outputs contribute none.
+          indexing_maps: affine maps, inputs first then outputs.
+          iterator_types: iterator types, one per loop dimension.
+          doc, library_call: forwarded unchanged to the generated builder.
+          loc, ip: location and insertion point for the new op.
+        """
+        # Normalize op/op-view arguments to Values so ``.type`` is available.
+        inputs = [_get_op_result_or_value(v) for v in inputs]
+        outputs = [_get_op_result_or_value(v) for v in outputs]
+        # Decide per output (not from ``outputs[0]`` alone): avoids an
+        # IndexError on empty ``outputs`` and never gives non-tensor outputs
+        # a result.
+        result_types = [
+            o.type for o in outputs if isinstance(o.type, RankedTensorType)
+        ]
+
+        super().__init__(
+            result_types,
+            inputs,
+            outputs,
+            indexing_maps,
+            iterator_types,
+            doc=doc,
+            library_call=library_call,
+            loc=loc,
+            ip=ip,
+        )
+        # Body block arguments: one per input, then one per output, typed by
+        # each operand's element type.
+        element_types = [v.type.element_type for v in inputs + outputs]
+        self.regions[0].blocks.append(*element_types)
+
+
+# Decorator form: builds a GenericOp, runs the decorated function on the body
+# block's arguments and terminates the body with a YieldOp.
+generic = region_op(GenericOp, terminator=YieldOp)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 72045a07b2da80..b7e0f2884bb249 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
print(module)
+
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
@@ -161,3 +162,62 @@ def broadcast_op(op1, op2, op3):
op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
print(module)
+
+
+# CHECK-LABEL: TEST: testGenericOp
+@run
+def testGenericOp():
+    """Build ``linalg.generic`` ops via the ``@linalg.generic`` decorator."""
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            id_map = AffineMap.get_identity(2)
+            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
+            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
+            x = tensor.empty((16, 16), f32)
+            y = tensor.empty((16, 16), f32)
+
+            # Single output: the decorated name is bound to the result Value.
+            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32):
+            # CHECK: linalg.yield %in : f32
+            # CHECK: } -> tensor<16x16xf32>
+            @linalg.generic(
+                [x],
+                [y],
+                [id_map, id_map],
+                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+            )
+            def f(x, y):
+                return x
+
+            assert isinstance(f, Value)
+
+            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
+            z = tensor.empty((16, 16, 16), f32)
+
+            minor_id = AffineMap.get_minor_identity(3, 2)
+            id_map_3d = AffineMap.get_identity(3)
+
+            # Two outputs: the decorated name is bound to the OpResultList.
+            # CHECK: %[[VAL_4:.*]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32):
+            # CHECK: linalg.yield %in, %out : f32, f32
+            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
+            @linalg.generic(
+                [x],
+                [z, z],
+                [minor_id, id_map_3d, id_map_3d],
+                [
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                ],
+            )
+            def g(x, z1, z2):
+                return x, z1
+
+            assert isinstance(g, OpResultList)
+            assert len(g) == 2
+            assert isinstance(g[0].type, RankedTensorType)
+            assert isinstance(g[1].type, RankedTensorType)
+
+        print(module)
``````````
</details>
https://github.com/llvm/llvm-project/pull/124496
More information about the Mlir-commits mailing list