[Mlir-commits] [mlir] [mlir][python] implement GenericOp bindings (PR #124496)
Maksim Levental
llvmlistbot at llvm.org
Mon Jan 27 10:54:13 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/124496
>From 30534395864030def62d8dbec72756876a2a2946 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Sun, 26 Jan 2025 20:01:51 -0500
Subject: [PATCH 1/2] [mlir][python] implement GenericOp bindings
---
mlir/python/mlir/dialects/linalg/__init__.py | 44 +++++++++
mlir/test/python/dialects/linalg/ops.py | 97 +++++++++++++++++++-
2 files changed, 140 insertions(+), 1 deletion(-)
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..29b1386e383a22 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -10,6 +10,7 @@
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
from .._linalg_ops_gen import *
from .._linalg_enum_gen import *
+from .._linalg_enum_gen import _iteratortypeenum
# These are the ground truth functions defined as:
# ```
@@ -58,6 +59,7 @@
from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from ...extras.meta import region_op
def transpose(
@@ -102,3 +104,45 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op
+
+
+ at register_attribute_builder("IteratorTypeArrayAttr")
+def _IteratorTypeArrayAttr(x, context):
+ return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
+
+
+class GenericOp_(GenericOp):
+ def __init__(
+ self,
+ inputs,
+ outputs,
+ indexing_maps,
+ iterator_types,
+ *,
+ doc=None,
+ library_call=None,
+ loc=None,
+ ip=None,
+ ):
+ result_types = []
+ if isinstance(outputs[0].type, RankedTensorType):
+ result_types = [o.type for o in outputs]
+
+ super().__init__(
+ result_types,
+ inputs,
+ outputs,
+ indexing_maps,
+ iterator_types,
+ doc=doc,
+ library_call=library_call,
+ loc=loc,
+ ip=ip,
+ )
+ element_types = [i.type.element_type for i in inputs] + [
+ o.type.element_type for o in outputs
+ ]
+ self.regions[0].blocks.append(*element_types)
+
+
+# Decorator-style builder for linalg.generic, e.g.
+#   @linalg.generic(inputs, outputs, indexing_maps, iterator_types)
+#   def body(*block_args): ...
+# It builds a GenericOp_ and calls the decorated function with the body
+# block's arguments. It then terminates the block with linalg.yield of the
+# function's return value(s), and binds the decorated name to the op's
+# result(s). This usage is exercised by testGenericOp.
+generic = region_op(GenericOp_, terminator=YieldOp)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 72045a07b2da80..ac7186c24bed84 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -1,6 +1,6 @@
# RUN: %PYTHON %s | FileCheck %s
-from mlir.dialects import arith, builtin, func, linalg, tensor
+from mlir.dialects import arith, func, linalg, tensor, memref
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *
@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
print(module)
+
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
@@ -161,3 +162,97 @@ def broadcast_op(op1, op2, op3):
op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
print(module)
+
+
+# CHECK-LABEL: TEST: testGenericOp
+ at run
+def testGenericOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ memref_t = MemRefType.get([10, 10], f32)
+ with InsertionPoint(module.body):
+ id_map_1 = AffineMap.get_identity(2)
+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
+ # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
+ x = tensor.empty((16, 16), f32)
+ y = tensor.empty((16, 16), f32)
+
+ # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
+ # CHECK: ^bb0(%in: f32, %out: f32):
+ # CHECK: linalg.yield %in : f32
+ # CHECK: } -> tensor<16x16xf32>
+ @linalg.generic(
+ [x],
+ [y],
+ [id_map_1, id_map_1],
+ [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+ )
+ def f(a, b):
+ assert isinstance(a, Value)
+ assert isinstance(a.type, F32Type)
+ assert isinstance(b, Value)
+ assert isinstance(b.type, F32Type)
+ return a
+
+ assert isinstance(f, Value)
+ assert isinstance(f.type, RankedTensorType)
+
+ # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
+ z = tensor.empty((16, 16, 16), f32)
+
+ minor_id = AffineMap.get_minor_identity(3, 2)
+ id_map_2 = AffineMap.get_identity(3)
+
+ # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
+ # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
+ # CHECK: linalg.yield %in, %out : f32, f32
+ # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
+ @linalg.generic(
+ [x],
+ [z, z],
+ [minor_id, id_map_2, id_map_2],
+ [
+ linalg.IteratorType.parallel,
+ linalg.IteratorType.parallel,
+ linalg.IteratorType.parallel,
+ ],
+ )
+ def g(a, b, c):
+ assert isinstance(a, Value)
+ assert isinstance(a.type, F32Type)
+ assert isinstance(b, Value)
+ assert isinstance(b.type, F32Type)
+ assert isinstance(c, Value)
+ assert isinstance(c.type, F32Type)
+ return a, b
+
+ assert isinstance(g, OpResultList)
+ assert len(g) == 2
+ assert isinstance(g[0].type, RankedTensorType)
+ assert isinstance(g[1].type, RankedTensorType)
+
+ # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
+ # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
+ xx = memref.alloc(memref_t, [], [])
+ yy = memref.alloc(memref_t, [], [])
+
+ # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
+ # CHECK: ^bb0(%in: f32, %out: f32):
+ # CHECK: linalg.yield %in : f32
+ # CHECK: }
+ @linalg.generic(
+ [xx],
+ [yy],
+ [id_map_1, id_map_1],
+ [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+ )
+ def f(a, b):
+ assert isinstance(a, Value)
+ assert isinstance(a.type, F32Type)
+ assert isinstance(b, Value)
+ assert isinstance(b.type, F32Type)
+ return a
+
+ module.operation.verify()
+ print(module)
>From 94b957520052e9dc25ff06f64271053502717146 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Mon, 27 Jan 2025 13:54:05 -0500
Subject: [PATCH 2/2] Update mlir/python/mlir/dialects/linalg/__init__.py
---
mlir/python/mlir/dialects/linalg/__init__.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 29b1386e383a22..742262a9c49695 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -111,6 +111,7 @@ def _IteratorTypeArrayAttr(x, context):
return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
+# The underscore is needed here so that there's no collision with opdsl generation.
class GenericOp_(GenericOp):
def __init__(
self,
More information about the Mlir-commits
mailing list