[Mlir-commits] [mlir] 1bc5fe6 - [mlir][python] implement GenericOp bindings (#124496)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 28 09:02:30 PST 2025
Author: Maksim Levental
Date: 2025-01-28T12:02:26-05:00
New Revision: 1bc5fe669f5477eadd84270e971591a718693bba
URL: https://github.com/llvm/llvm-project/commit/1bc5fe669f5477eadd84270e971591a718693bba
DIFF: https://github.com/llvm/llvm-project/commit/1bc5fe669f5477eadd84270e971591a718693bba.diff
LOG: [mlir][python] implement GenericOp bindings (#124496)
Added:
Modified:
mlir/python/mlir/dialects/linalg/__init__.py
mlir/test/python/dialects/linalg/ops.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..742262a9c49695 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -10,6 +10,7 @@
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
from .._linalg_ops_gen import *
from .._linalg_enum_gen import *
+from .._linalg_enum_gen import _iteratortypeenum
# These are the ground truth functions defined as:
# ```
@@ -58,6 +59,7 @@
from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from ...extras.meta import region_op
def transpose(
@@ -102,3 +104,46 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op
+
+
+ at register_attribute_builder("IteratorTypeArrayAttr")
+def _IteratorTypeArrayAttr(x, context):
+ return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
+
+
+# The underscore is needed here so that there's no collision with opdsl generation.
+class GenericOp_(GenericOp):
+ def __init__(
+ self,
+ inputs,
+ outputs,
+ indexing_maps,
+ iterator_types,
+ *,
+ doc=None,
+ library_call=None,
+ loc=None,
+ ip=None,
+ ):
+ result_types = []
+ if isinstance(outputs[0].type, RankedTensorType):
+ result_types = [o.type for o in outputs]
+
+ super().__init__(
+ result_types,
+ inputs,
+ outputs,
+ indexing_maps,
+ iterator_types,
+ doc=doc,
+ library_call=library_call,
+ loc=loc,
+ ip=ip,
+ )
+ element_types = [i.type.element_type for i in inputs] + [
+ o.type.element_type for o in outputs
+ ]
+ self.regions[0].blocks.append(*element_types)
+
+
+# Decorator-style builder for linalg.generic: used as
+# `@generic(inputs, outputs, indexing_maps, iterator_types)` on a function,
+# it builds a GenericOp_ and uses the function as the body of the op's region,
+# terminated with linalg.YieldOp.
+# NOTE(review): how the function's return values are yielded and what the
+# decorated name is rebound to are defined by extras.meta.region_op -- confirm
+# there.
+generic = region_op(GenericOp_, terminator=YieldOp)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 72045a07b2da80..ac7186c24bed84 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -1,6 +1,6 @@
# RUN: %PYTHON %s | FileCheck %s
-from mlir.dialects import arith, builtin, func, linalg, tensor
+from mlir.dialects import arith, func, linalg, tensor, memref
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *
@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
print(module)
+
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
@@ -161,3 +162,97 @@ def broadcast_op(op1, op2, op3):
op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
print(module)
+
+
+# CHECK-LABEL: TEST: testGenericOp
+ at run
+def testGenericOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ memref_t = MemRefType.get([10, 10], f32)
+ with InsertionPoint(module.body):
+ id_map_1 = AffineMap.get_identity(2)
+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
+ # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
+ x = tensor.empty((16, 16), f32)
+ y = tensor.empty((16, 16), f32)
+
+ # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
+ # CHECK: ^bb0(%in: f32, %out: f32):
+ # CHECK: linalg.yield %in : f32
+ # CHECK: } -> tensor<16x16xf32>
+ @linalg.generic(
+ [x],
+ [y],
+ [id_map_1, id_map_1],
+ [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+ )
+ def f(a, b):
+ assert isinstance(a, Value)
+ assert isinstance(a.type, F32Type)
+ assert isinstance(b, Value)
+ assert isinstance(b.type, F32Type)
+ return a
+
+ assert isinstance(f, Value)
+ assert isinstance(f.type, RankedTensorType)
+
+ # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
+ z = tensor.empty((16, 16, 16), f32)
+
+ minor_id = AffineMap.get_minor_identity(3, 2)
+ id_map_2 = AffineMap.get_identity(3)
+
+ # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
+ # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
+ # CHECK: linalg.yield %in, %out : f32, f32
+ # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
+ @linalg.generic(
+ [x],
+ [z, z],
+ [minor_id, id_map_2, id_map_2],
+ [
+ linalg.IteratorType.parallel,
+ linalg.IteratorType.parallel,
+ linalg.IteratorType.parallel,
+ ],
+ )
+ def g(a, b, c):
+ assert isinstance(a, Value)
+ assert isinstance(a.type, F32Type)
+ assert isinstance(b, Value)
+ assert isinstance(b.type, F32Type)
+ assert isinstance(c, Value)
+ assert isinstance(c.type, F32Type)
+ return a, b
+
+ assert isinstance(g, OpResultList)
+ assert len(g) == 2
+ assert isinstance(g[0].type, RankedTensorType)
+ assert isinstance(g[1].type, RankedTensorType)
+
+ # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
+ # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
+ xx = memref.alloc(memref_t, [], [])
+ yy = memref.alloc(memref_t, [], [])
+
+ # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
+ # CHECK: ^bb0(%in: f32, %out: f32):
+ # CHECK: linalg.yield %in : f32
+ # CHECK: }
+ @linalg.generic(
+ [xx],
+ [yy],
+ [id_map_1, id_map_1],
+ [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+ )
+ def f(a, b):
+ assert isinstance(a, Value)
+ assert isinstance(a.type, F32Type)
+ assert isinstance(b, Value)
+ assert isinstance(b.type, F32Type)
+ return a
+
+ module.operation.verify()
+ print(module)
More information about the Mlir-commits
mailing list