[Mlir-commits] [mlir] [mlir][python] implement GenericOp bindings (PR #124496)
Maksim Levental
llvmlistbot at llvm.org
Sun Jan 26 17:02:41 PST 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/124496
None
>From 206cb67df44b6dade7408a4adb8337ca01e00c82 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Sun, 26 Jan 2025 20:01:51 -0500
Subject: [PATCH] [mlir][python] implement GenericOp bindings
---
mlir/python/mlir/dialects/linalg/__init__.py | 44 ++++++++++++++
mlir/test/python/dialects/linalg/ops.py | 60 ++++++++++++++++++++
2 files changed, 104 insertions(+)
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..946094e2e9f691 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -10,6 +10,7 @@
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
from .._linalg_ops_gen import *
from .._linalg_enum_gen import *
+from .._linalg_enum_gen import _iteratortypeenum
# These are the ground truth functions defined as:
# ```
@@ -58,6 +59,7 @@
from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from ...extras.meta import region_op
def transpose(
@@ -102,3 +104,45 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op
+
+
+@register_attribute_builder("IteratorTypeArrayAttr")
+def _IteratorTypeArrayAttr(x, context):
+    """Attribute builder for ODS ``IteratorTypeArrayAttr`` operands.
+
+    Converts each element of ``x`` (e.g. ``IteratorType.parallel``) with the
+    generated ``_iteratortypeenum`` builder and wraps the results in an
+    ``ArrayAttr``. ``context`` is forwarded to ``ArrayAttr.get`` as well, so
+    the array is created in the caller-supplied context rather than requiring
+    an ambient ``with Context():`` (which ``ArrayAttr.get`` otherwise resolves
+    implicitly, even when ``x`` is empty).
+    """
+    return ArrayAttr.get(
+        [_iteratortypeenum(v, context) for v in x], context=context
+    )
+
+
+class GenericOp(GenericOp):
+    """``linalg.generic`` builder that derives result types and the body block.
+
+    Wraps the ``GenericOp`` constructor from the generated op module so the
+    caller only supplies operands, indexing maps and iterator types:
+
+    * result types are taken from the ranked-tensor outputs (one result per
+      such output); other outputs (e.g. memrefs) produce no result;
+    * an entry block is appended to region 0 with one argument per input and
+      then per output: the element type for shaped operands, or the operand's
+      own type for non-shaped (scalar) operands.
+
+    Args:
+        inputs: input operands; each may be a ``Value`` or a single-result op.
+        outputs: output (init) operands; same accepted forms as ``inputs``.
+        indexing_maps: one ``AffineMap`` per operand, inputs first.
+        iterator_types: one iterator kind per loop dimension, e.g.
+            ``IteratorType.parallel``.
+        doc, library_call, loc, ip: forwarded unchanged to the generated
+            constructor.
+    """
+
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        indexing_maps,
+        iterator_types,
+        *,
+        doc=None,
+        library_call=None,
+        loc=None,
+        ip=None,
+    ):
+        # Accept single-result ops as well as values, and materialize the
+        # sequences because they are traversed more than once below.
+        inputs = [_get_op_result_or_value(v) for v in inputs]
+        outputs = [_get_op_result_or_value(v) for v in outputs]
+
+        # Inspect every output instead of only outputs[0]: this avoids an
+        # IndexError for an empty ``outputs`` and emits exactly one result per
+        # ranked-tensor output.
+        result_types = [
+            o.type for o in outputs if isinstance(o.type, RankedTensorType)
+        ]
+
+        super().__init__(
+            result_types,
+            inputs,
+            outputs,
+            indexing_maps,
+            iterator_types,
+            doc=doc,
+            library_call=library_call,
+            loc=loc,
+            ip=ip,
+        )
+
+        # Body block arguments: element types of shaped operands; non-shaped
+        # (scalar) operands have no ``element_type`` and keep their own type.
+        element_types = [
+            v.type.element_type if isinstance(v.type, ShapedType) else v.type
+            for v in inputs + outputs
+        ]
+        self.regions[0].blocks.append(*element_types)
+
+
+# Decorator-style builder for linalg.generic, used as
+# ``@linalg.generic(inputs, outputs, indexing_maps, iterator_types)`` over a
+# function whose parameters are the body block's arguments. The function's
+# return values are yielded through a terminating ``YieldOp``, and the
+# decorated name is bound to the op's result(s).
+generic = region_op(GenericOp, terminator=YieldOp)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 72045a07b2da80..b7e0f2884bb249 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
print(module)
+
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
@@ -161,3 +162,62 @@ def broadcast_op(op1, op2, op3):
op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
print(module)
+
+
+# CHECK-LABEL: TEST: testGenericOp
+@run
+def testGenericOp():
+    """Build linalg.generic through the decorator form and check the IR.
+
+    Covers a single-result op (one tensor input, one tensor output) and a
+    two-result op whose two outputs share the same init value.
+    """
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            id_map = AffineMap.get_identity(2)
+            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
+            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
+            x = tensor.empty((16, 16), f32)
+            y = tensor.empty((16, 16), f32)
+
+            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32):
+            # CHECK: linalg.yield %in : f32
+            # CHECK: } -> tensor<16x16xf32>
+            @linalg.generic(
+                [x],
+                [y],
+                [id_map, id_map],
+                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+            )
+            def f(x, y):
+                return x
+
+            # A single tensor output yields a single Value.
+            assert isinstance(f, Value)
+
+            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
+            z = tensor.empty((16, 16, 16), f32)
+
+            minor_id = AffineMap.get_minor_identity(3, 2)
+            id_map = AffineMap.get_identity(3)
+
+            # CHECK: %[[VAL_4:.*]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32):
+            # CHECK: linalg.yield %in, %out : f32, f32
+            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
+            @linalg.generic(
+                [x],
+                [z, z],
+                [minor_id, id_map, id_map],
+                [
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                ],
+            )
+            def g(x, z1, z2):
+                return x, z1
+
+            # Two tensor outputs yield an OpResultList with one tensor each.
+            assert isinstance(g, OpResultList)
+            assert len(g) == 2
+            assert isinstance(g[0].type, RankedTensorType)
+            assert isinstance(g[1].type, RankedTensorType)
+
+        print(module)
More information about the Mlir-commits
mailing list