[Mlir-commits] [mlir] [MLIR][Linalg][Python] Improve bindings for linalg.elementwise (PR #139462)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun May 11 12:26:06 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

<details>
<summary>Changes</summary>

Adds Python wrappers for `ElementwiseOp` (`linalg.elementwise`), in particular to ensure that appropriate default indexing maps are derived.

---
Full diff: https://github.com/llvm/llvm-project/pull/139462.diff


2 Files Affected:

- (modified) mlir/python/mlir/dialects/linalg/__init__.py (+62) 
- (modified) mlir/test/python/dialects/linalg/ops.py (+186) 


``````````diff
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 63586a5bb8bbb..6049e2ba923ba 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -216,6 +216,110 @@ def contract(
     )
 
 
+# Extend and shadow the TableGen-derived version to make sure correct default
+# indexing_maps are derived (as there is no mechanism for doing so given the
+# Python API bypasses the C++-builders).
+class ElementWiseOp_(ElementwiseOp):
+    """Builder for `linalg.elementwise` that derives default indexing maps.
+
+    If `indexing_maps` is omitted, every input and output must have the same
+    type, and each operand is given the identity map of that type's rank.
+    Pass `indexing_maps` explicitly for broadcasts, transposes, or operands
+    of differing types.
+    """
+
+    def __init__(
+        self,
+        result_tensors,
+        inputs,
+        outputs,
+        kind,
+        indexing_maps=None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if indexing_maps is None:
+            inputs = [_get_op_result_or_value(in_) for in_ in inputs]
+            if outputs:
+                outputs = [_get_op_result_or_value(out) for out in outputs]
+            operands = inputs + list(outputs or [])
+            if not operands:
+                raise ValueError(
+                    "cannot derive default indexing_maps without operands"
+                )
+            # Identity maps only make sense when all operands share one type
+            # (and hence one shape); anything else needs explicit maps.
+            common_type = operands[0].type
+            for operand in operands[1:]:
+                if operand.type != common_type:
+                    raise ValueError(
+                        "default indexing_maps require all operands to have "
+                        f"the same type, got {common_type} and {operand.type}"
+                    )
+            indexing_maps = [AffineMap.get_identity(common_type.rank)] * len(operands)
+        super().__init__(
+            result_tensors=result_tensors,
+            inputs=inputs,
+            outputs=outputs,
+            kind=kind,
+            indexing_maps=indexing_maps,
+            loc=loc,
+            ip=ip,
+        )
+
+
+ElementwiseOp = ElementWiseOp_
+
+
+def elementwise(
+    *ins: Union[Operation, OpView, Value],
+    outs: Sequence[Union[Operation, OpView, Value]],
+    kind: Union[ElementwiseKind, Attribute],
+    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+    loc=None,
+    ip=None,
+):
+    """Create a `linalg.elementwise` op and populate it via fill_builtin_region.
+
+    Args:
+      *ins: Input operands.
+      outs: Sequence holding exactly one init operand.
+      kind: The elementwise operation to apply, e.g. `ElementwiseKind.exp`.
+      indexing_maps: Optional maps, one per input followed by one for the
+        init. When omitted, identity maps are derived, which requires all
+        operands to share the same type.
+      loc: Optional location for the created op.
+      ip: Optional insertion point for the created op.
+
+    Returns:
+      The value of `_get_op_result_or_op_results(op)` -- presumably the result
+      for a ranked-tensor init, or the op itself when it has no results.
+
+    Raises:
+      ValueError: if `outs` does not contain exactly one operand.
+    """
+    ins = [_get_op_result_or_value(in_) for in_ in ins]
+    if len(outs) != 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    # A ranked-tensor init yields one result of the same type; other inits
+    # (e.g. memrefs) yield no results.
+    result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+    op = ElementwiseOp(
+        result_tensors=result_types,
+        inputs=ins,
+        outputs=[init],
+        kind=kind,
+        indexing_maps=indexing_maps,
+        loc=loc,
+        ip=ip,
+    )
+    fill_builtin_region(op.operation)
+    return _get_op_result_or_op_results(op)
+
+
 def pack(
     source,
     dest,
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..5a163474210a6 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -606,3 +606,196 @@ def tensor_pack(src, dst):
         # CHECK:           return %[[VAL_4]] : tensor<128x128xf32>
         # CHECK:         }
         print(module)
+
+
+# CHECK-LABEL: TEST: testElementwiseOp
+@run
+def testElementwiseOp():
+    """Check the printed IR of linalg.ElementwiseOp and linalg.elementwise.
+
+    Covers derived (default) identity indexing maps, explicitly provided
+    identity maps, broadcasting maps and a transposed input map, on both
+    tensor and memref operands.
+    """
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            rect_shape = (8, 16)
+            vert_line_shape = (8,)
+            hor_line_shape = (16,)
+            transposed_rect_shape = (16, 8)
+
+            # CHECK-DAG: #[[$IdentMap2D:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+            # CHECK-DAG: #[[$TransMap2D:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+            # CHECK-DAG: #[[$VertLineBCastMap:.*]] = affine_map<(d0, d1) -> (d0)>
+            # CHECK-DAG: #[[$HorLineBCastMap:.*]] = affine_map<(d0, d1) -> (d1)>
+
+            ident_map_2d = AffineMap.get_identity(2)
+            transposed_map_2d = AffineMap.get_permutation((1, 0))
+            vert_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(0)])
+            hor_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(1)])
+
+            # CHECK: func.func @elementwise_op(
+            @func.FuncOp.from_py_func(
+                # CHECK-SAME:                         %[[Rect:.*]]: tensor<8x16xf32>,
+                RankedTensorType.get(rect_shape, f32),
+                # CHECK-SAME:                         %[[RectMem:.*]]: memref<8x16xf32>,
+                MemRefType.get(rect_shape, f32),
+                # CHECK-SAME:                         %[[VertLine:.*]]: tensor<8xf32>,
+                RankedTensorType.get(vert_line_shape, f32),
+                # CHECK-SAME:                         %[[VertLineMem:.*]]: memref<8xf32>,
+                MemRefType.get(vert_line_shape, f32),
+                # CHECK-SAME:                         %[[HorLine:.*]]: tensor<16xf32>,
+                RankedTensorType.get(hor_line_shape, f32),
+                # CHECK-SAME:                         %[[HorLineMem:.*]]: memref<16xf32>,
+                MemRefType.get(hor_line_shape, f32),
+                # CHECK-SAME:                         %[[TransRect:.*]]: tensor<16x8xf32>,
+                RankedTensorType.get(transposed_rect_shape, f32),
+                # CHECK-SAME:                         %[[TransRectMem:.*]]: memref<16x8xf32>)
+                MemRefType.get(transposed_rect_shape, f32),
+            )
+            def elementwise_op(
+                rect,
+                rect_mem,
+                vert_line,
+                vert_line_mem,
+                hor_line,
+                hor_line_mem,
+                trans_rect,
+                trans_rect_mem,
+            ):
+                # CHECK: %[[OutRect:.*]] = tensor.empty() : tensor<8x16xf32>
+                out_rect = tensor.EmptyOp(rect_shape, f32)
+                # CHECK: %[[OutRectMem:.*]] = memref.alloca() : memref<8x16xf32>
+                out_rect_mem = memref.alloca(MemRefType.get(rect_shape, f32), [], [])
+
+                # The `if <name> := True:` blocks below only group related cases.
+                if _inferred_affine_maps := True:
+                    # CHECK: linalg.elementwise
+                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+                    # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
+                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+                    op1 = linalg.ElementwiseOp(
+                        result_tensors=(out_rect.result.type,),
+                        inputs=(rect,),
+                        outputs=(out_rect,),
+                        kind=linalg.ElementwiseKind.exp,
+                    )
+                    linalg.fill_builtin_region(op1.operation)
+
+                    # CHECK: linalg.elementwise
+                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+                    # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
+                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+                    linalg.elementwise(
+                        rect,
+                        outs=(out_rect,),
+                        kind=linalg.ElementwiseKind.exp,
+                    )
+
+                    # CHECK: linalg.elementwise
+                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+                    # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
+                    # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
+                    linalg.elementwise(
+                        rect_mem,
+                        outs=(out_rect_mem,),
+                        kind=linalg.ElementwiseKind.exp,
+                    )
+
+                if _explicit_ident_affine_maps := True:
+                    # Same as above but with default identity indexing_maps explicitly provided.
+                    # CHECK: linalg.elementwise
+                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+                    # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
+                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+                    op3 = linalg.ElementwiseOp(
+                        result_tensors=(out_rect.result.type,),
+                        inputs=(rect,),
+                        outputs=(out_rect,),
+                        kind=linalg.ElementwiseKind.exp,
+                        indexing_maps=[ident_map_2d, ident_map_2d],
+                    )
+                    linalg.fill_builtin_region(op3.operation)
+
+                    # CHECK: linalg.elementwise
+                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+                    # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
+                    # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
+                    linalg.elementwise(
+                        rect_mem,
+                        outs=(out_rect_mem,),
+                        kind=linalg.ElementwiseKind.exp,
+                        indexing_maps=[ident_map_2d, ident_map_2d],
+                    )
+
+                if _ops_with_non_ident_input_maps := True:
+                    # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
+                    # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$IdentMap2D]]]
+                    # CHECK-SAME: ins(%[[VertLine]] : tensor<8xf32>)
+                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+                    op4 = linalg.ElementwiseOp(
+                        result_tensors=(out_rect.result.type,),
+                        inputs=(vert_line,),
+                        outputs=(out_rect,),
+                        kind=linalg.ElementwiseKind.exp,
+                        indexing_maps=[vert_line_bcast_map, ident_map_2d],
+                    )
+                    linalg.fill_builtin_region(op4.operation)
+
+                    # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+                    # CHECK-SAME: indexing_maps = [#[[$IdentMap2D]], #[[$VertLineBCastMap]], #[[$IdentMap2D]]]
+                    # CHECK-SAME: ins(%[[Rect]], %[[VertLine]] : tensor<8x16xf32>, tensor<8xf32>)
+                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+                    op5 = linalg.ElementwiseOp(
+                        result_tensors=(out_rect.result.type,),
+                        inputs=(rect, vert_line),
+                        outputs=(out_rect,),
+                        kind=linalg.ElementwiseKind.add,
+                        indexing_maps=[ident_map_2d, vert_line_bcast_map, ident_map_2d],
+                    )
+                    linalg.fill_builtin_region(op5.operation)
+
+                    # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
+                    # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]]
+                    # CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>)
+                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+                    linalg.elementwise(
+                        vert_line,
+                        hor_line,
+                        outs=(out_rect,),
+                        kind=linalg.ElementwiseKind.div,
+                        indexing_maps=[
+                            vert_line_bcast_map,
+                            hor_line_bcast_map,
+                            ident_map_2d,
+                        ],
+                    )
+
+                if _ops_with_non_ident_and_transposed_input_maps := True:
+                    # CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1>
+                    vert_line_bools_mem = memref.alloca(
+                        MemRefType.get(vert_line_shape, IntegerType.get_signless(1)),
+                        [],
+                        [],
+                    )
+                    # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<select>
+                    # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$TransMap2D]], #[[$IdentMap2D]]]
+                    # CHECK-SAME: ins(%[[VertLineBoolsMem]], %[[HorLineMem]], %[[TransRectMem]] : memref<8xi1>, memref<16xf32>, memref<16x8xf32>)
+                    # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
+                    linalg.elementwise(
+                        vert_line_bools_mem,
+                        hor_line_mem,
+                        trans_rect_mem,
+                        outs=(out_rect_mem,),
+                        kind=linalg.ElementwiseKind.select,
+                        indexing_maps=[
+                            vert_line_bcast_map,
+                            hor_line_bcast_map,
+                            transposed_map_2d,
+                            ident_map_2d,
+                        ],
+                    )
+
+        print(module)

``````````

</details>


https://github.com/llvm/llvm-project/pull/139462


More information about the Mlir-commits mailing list