[Mlir-commits] [mlir] 1bc5fe6 - [mlir][python] implement GenericOp bindings (#124496)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 28 09:02:30 PST 2025


Author: Maksim Levental
Date: 2025-01-28T12:02:26-05:00
New Revision: 1bc5fe669f5477eadd84270e971591a718693bba

URL: https://github.com/llvm/llvm-project/commit/1bc5fe669f5477eadd84270e971591a718693bba
DIFF: https://github.com/llvm/llvm-project/commit/1bc5fe669f5477eadd84270e971591a718693bba.diff

LOG: [mlir][python] implement GenericOp bindings (#124496)

Added: 
    

Modified: 
    mlir/python/mlir/dialects/linalg/__init__.py
    mlir/test/python/dialects/linalg/ops.py

Removed: 
    


################################################################################
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..742262a9c49695 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -10,6 +10,7 @@
 #   DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
 from .._linalg_ops_gen import *
 from .._linalg_enum_gen import *
+from .._linalg_enum_gen import _iteratortypeenum
 
 # These are the ground truth functions defined as:
 # ```
@@ -58,6 +59,7 @@
 
 from ...ir import *
 from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from ...extras.meta import region_op
 
 
 def transpose(
@@ -102,3 +104,46 @@ def broadcast(
     )
     fill_builtin_region(op.operation)
     return op
+
+
+@register_attribute_builder("IteratorTypeArrayAttr")
+def _IteratorTypeArrayAttr(x, context):
+    return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
+
+
+# The underscore is needed here so that there's no collision with opdsl generation.
+class GenericOp_(GenericOp):
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        indexing_maps,
+        iterator_types,
+        *,
+        doc=None,
+        library_call=None,
+        loc=None,
+        ip=None,
+    ):
+        result_types = []
+        if isinstance(outputs[0].type, RankedTensorType):
+            result_types = [o.type for o in outputs]
+
+        super().__init__(
+            result_types,
+            inputs,
+            outputs,
+            indexing_maps,
+            iterator_types,
+            doc=doc,
+            library_call=library_call,
+            loc=loc,
+            ip=ip,
+        )
+        element_types = [i.type.element_type for i in inputs] + [
+            o.type.element_type for o in outputs
+        ]
+        self.regions[0].blocks.append(*element_types)
+
+
+generic = region_op(GenericOp_, terminator=YieldOp)

diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 72045a07b2da80..ac7186c24bed84 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -1,6 +1,6 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from mlir.dialects import arith, builtin, func, linalg, tensor
+from mlir.dialects import arith, func, linalg, tensor, memref
 from mlir.dialects.linalg.opdsl.lang import *
 from mlir.ir import *
 
@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
 
     print(module)
 
+
 # CHECK-LABEL: TEST: testIdentityRegionOps
 @run
 def testIdentityRegionOps():
@@ -161,3 +162,97 @@ def broadcast_op(op1, op2, op3):
                 op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testGenericOp
+@run
+def testGenericOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        memref_t = MemRefType.get([10, 10], f32)
+        with InsertionPoint(module.body):
+            id_map_1 = AffineMap.get_identity(2)
+            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
+            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
+            x = tensor.empty((16, 16), f32)
+            y = tensor.empty((16, 16), f32)
+
+            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32):
+            # CHECK:   linalg.yield %in : f32
+            # CHECK: } -> tensor<16x16xf32>
+            @linalg.generic(
+                [x],
+                [y],
+                [id_map_1, id_map_1],
+                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+            )
+            def f(a, b):
+                assert isinstance(a, Value)
+                assert isinstance(a.type, F32Type)
+                assert isinstance(b, Value)
+                assert isinstance(b.type, F32Type)
+                return a
+
+            assert isinstance(f, Value)
+            assert isinstance(f.type, RankedTensorType)
+
+            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
+            z = tensor.empty((16, 16, 16), f32)
+
+            minor_id = AffineMap.get_minor_identity(3, 2)
+            id_map_2 = AffineMap.get_identity(3)
+
+            # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
+            # CHECK:   linalg.yield %in, %out : f32, f32
+            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
+            @linalg.generic(
+                [x],
+                [z, z],
+                [minor_id, id_map_2, id_map_2],
+                [
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                ],
+            )
+            def g(a, b, c):
+                assert isinstance(a, Value)
+                assert isinstance(a.type, F32Type)
+                assert isinstance(b, Value)
+                assert isinstance(b.type, F32Type)
+                assert isinstance(c, Value)
+                assert isinstance(c.type, F32Type)
+                return a, b
+
+            assert isinstance(g, OpResultList)
+            assert len(g) == 2
+            assert isinstance(g[0].type, RankedTensorType)
+            assert isinstance(g[1].type, RankedTensorType)
+
+            # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
+            # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
+            xx = memref.alloc(memref_t, [], [])
+            yy = memref.alloc(memref_t, [], [])
+
+            # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32):
+            # CHECK:   linalg.yield %in : f32
+            # CHECK: }
+            @linalg.generic(
+                [xx],
+                [yy],
+                [id_map_1, id_map_1],
+                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+            )
+            def f(a, b):
+                assert isinstance(a, Value)
+                assert isinstance(a.type, F32Type)
+                assert isinstance(b, Value)
+                assert isinstance(b.type, F32Type)
+                return a
+
+    module.operation.verify()
+    print(module)


        


More information about the Mlir-commits mailing list