[Mlir-commits] [mlir] [MLIR][Linalg][Python] Improve bindings for linalg.elementwise (PR #139462)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 11 12:26:06 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Rolf Morel (rolfmorel)
<details>
<summary>Changes</summary>
Adds wrappers for ElementWiseOp, in particular to ensure appropriate default indexing maps are derived.
---
Full diff: https://github.com/llvm/llvm-project/pull/139462.diff
2 Files Affected:
- (modified) mlir/python/mlir/dialects/linalg/__init__.py (+62)
- (modified) mlir/test/python/dialects/linalg/ops.py (+186)
``````````diff
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 63586a5bb8bbb..6049e2ba923ba 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -216,6 +216,68 @@ def contract(
)
+# Extend and shadow the TableGen-derived version to make sure correct default
+# indexing_maps are derived (as there is no mechanism for doing so given the
+# Python API bypasses the C++-builders).
+class ElementWiseOp_(ElementwiseOp):
+ def __init__(
+ self,
+ result_tensors,
+ inputs,
+ outputs,
+ kind,
+ indexing_maps=None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if indexing_maps is None:
+ inputs = [_get_op_result_or_value(in_) for in_ in inputs]
+ num_args = len(inputs)
+ for in0, in1 in zip(inputs[:-1], inputs[1:]):
+ assert in0.type == in1.type
+ if outputs:
+ outputs = [_get_op_result_or_value(out) for out in outputs]
+ num_args += 1
+ assert inputs[0].type == outputs[0].type
+ indexing_maps = [ir.AffineMap.get_identity(inputs[0].type.rank)] * num_args
+ super().__init__(
+ result_tensors=result_tensors,
+ inputs=inputs,
+ outputs=outputs,
+ kind=kind,
+ indexing_maps=indexing_maps,
+ loc=loc,
+ ip=ip,
+ )
+
+
+# Rebind the module-level name so that `ElementwiseOp` now refers to the
+# extended class `ElementWiseOp_` rather than the class it derives from.
+ElementwiseOp = ElementWiseOp_
+
+
+def elementwise(
+ *ins: Union[Operation, OpView, Value],
+ outs: Sequence[Union[Operation, OpView, Value]],
+ kind: Union[ElementwiseKind, Attribute],
+ indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
+):
+ ins = [_get_op_result_or_value(input) for input in ins]
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length at most 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
+
+ op = ElementwiseOp(
+ result_tensors=result_types,
+ inputs=ins,
+ outputs=[init],
+ kind=kind,
+ indexing_maps=indexing_maps,
+ )
+ fill_builtin_region(op.operation)
+ return _get_op_result_or_op_results(op)
+
+
def pack(
source,
dest,
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..5a163474210a6 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -606,3 +606,189 @@ def tensor_pack(src, dst):
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)
+
+
+# CHECK-LABEL: TEST: testElementwiseOp
+@run
+def testElementwiseOp():
+ """Build linalg.elementwise ops via ElementwiseOp and elementwise(), then print the module for FileCheck."""
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ # Shapes used for the function arguments and the output buffers below.
+ rect_shape = (8, 16)
+ vert_line_shape = (8,)
+ hor_line_shape = (16,)
+ transposed_rect_shape = (16, 8)
+
+ # CHECK-DAG: #[[$IdentMap2D:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+ # CHECK-DAG: #[[$TransMap2D:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+ # CHECK-DAG: #[[$VertLineBCastMap:.*]] = affine_map<(d0, d1) -> (d0)>
+ # CHECK-DAG: #[[$HorLineBCastMap:.*]] = affine_map<(d0, d1) -> (d1)>
+
+ # Affine maps passed explicitly as indexing_maps further down.
+ ident_map_2d = AffineMap.get_identity(2)
+ transposed_map_2d = AffineMap.get_permutation((1, 0))
+ vert_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(0)])
+ hor_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(1)])
+
+ # CHECK: func.func @elementwise_op(
+ @func.FuncOp.from_py_func(
+ # CHECK-SAME: %[[Rect:.*]]: tensor<8x16xf32>,
+ RankedTensorType.get(rect_shape, f32),
+ # CHECK-SAME: %[[RectMem:.*]]: memref<8x16xf32>,
+ MemRefType.get(rect_shape, f32),
+ # CHECK-SAME: %[[VertLine:.*]]: tensor<8xf32>,
+ RankedTensorType.get(vert_line_shape, f32),
+ # CHECK-SAME: %[[VertLineMem:.*]]: memref<8xf32>,
+ MemRefType.get(vert_line_shape, f32),
+ # CHECK-SAME: %[[HorLine:.*]]: tensor<16xf32>,
+ RankedTensorType.get(hor_line_shape, f32),
+ # CHECK-SAME: %[[HorLineMem:.*]]: memref<16xf32>,
+ MemRefType.get(hor_line_shape, f32),
+ # CHECK-SAME: %[[TransRect:.*]]: tensor<16x8xf32>,
+ RankedTensorType.get(transposed_rect_shape, f32),
+ # CHECK-SAME: %[[TransRectMem:.*]]: memref<16x8xf32>)
+ MemRefType.get(transposed_rect_shape, f32),
+ )
+ def elementwise_op(
+ rect,
+ rect_mem,
+ vert_line,
+ vert_line_mem,
+ hor_line,
+ hor_line_mem,
+ trans_rect,
+ trans_rect_mem,
+ ):
+ # CHECK: %[[OutRect:.*]] = tensor.empty() : tensor<8x16xf32>
+ out_rect = tensor.EmptyOp(rect_shape, f32)
+ # CHECK: %[[OutRectMem:.*]] = memref.alloca() : memref<8x16xf32>
+ out_rect_mem = memref.alloca(MemRefType.get(rect_shape, f32), [], [])
+
+ # The walrus-assigned names in the `if` headers below are always true;
+ # they only serve as labels for groups of related cases.
+ # In this first group no indexing_maps argument is passed.
+ if _inferred_affine_maps := True:
+ # CHECK: linalg.elementwise
+ # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+ # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
+ # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+ op1 = linalg.ElementwiseOp(
+ result_tensors=(out_rect.result.type,),
+ inputs=(rect,),
+ outputs=(out_rect,),
+ kind=linalg.ElementwiseKind.exp,
+ )
+ linalg.fill_builtin_region(op1.operation)
+
+ # CHECK: linalg.elementwise
+ # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+ # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
+ # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+ linalg.elementwise(
+ rect,
+ outs=(out_rect,),
+ kind=linalg.ElementwiseKind.exp,
+ )
+
+ # CHECK: linalg.elementwise
+ # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+ # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
+ # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
+ linalg.elementwise(
+ rect_mem,
+ outs=(out_rect_mem,),
+ kind=linalg.ElementwiseKind.exp,
+ )
+
+ if _explicit_ident_affine_maps := True:
+ # Same as above but with default identity indexing_maps explicitly provided.
+ # CHECK: linalg.elementwise
+ # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+ # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
+ # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+ op3 = linalg.ElementwiseOp(
+ result_tensors=(out_rect.result.type,),
+ inputs=(rect,),
+ outputs=(out_rect,),
+ kind=linalg.ElementwiseKind.exp,
+ indexing_maps=[ident_map_2d, ident_map_2d],
+ )
+ linalg.fill_builtin_region(op3.operation)
+
+ # CHECK: linalg.elementwise
+ # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
+ # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
+ # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
+ linalg.elementwise(
+ rect_mem,
+ outs=(out_rect_mem,),
+ kind=linalg.ElementwiseKind.exp,
+ indexing_maps=[ident_map_2d, ident_map_2d],
+ )
+
+ # Inputs of lower rank than the output, mapped with explicit non-identity maps.
+ if _ops_with_non_ident_input_maps := True:
+ # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
+ # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$IdentMap2D]]]
+ # CHECK-SAME: ins(%[[VertLine]] : tensor<8xf32>)
+ # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+ op4 = linalg.ElementwiseOp(
+ result_tensors=(out_rect.result.type,),
+ inputs=(vert_line,),
+ outputs=(out_rect,),
+ kind=linalg.ElementwiseKind.exp,
+ indexing_maps=[vert_line_bcast_map, ident_map_2d],
+ )
+ linalg.fill_builtin_region(op4.operation)
+
+ # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
+ # CHECK-SAME: indexing_maps = [#[[$IdentMap2D]], #[[$VertLineBCastMap]], #[[$IdentMap2D]]]
+ # CHECK-SAME: ins(%[[Rect]], %[[VertLine]] : tensor<8x16xf32>, tensor<8xf32>)
+ # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+ # NOTE(review): the name `op4` is rebound here; the previous op has
+ # already had its region filled, so only the Python name is reused.
+ op4 = linalg.ElementwiseOp(
+ result_tensors=(out_rect.result.type,),
+ inputs=(rect, vert_line),
+ outputs=(out_rect,),
+ kind=linalg.ElementwiseKind.add,
+ indexing_maps=[ident_map_2d, vert_line_bcast_map, ident_map_2d],
+ )
+ linalg.fill_builtin_region(op4.operation)
+
+ # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
+ # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]]
+ # CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>)
+ # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
+ linalg.elementwise(
+ vert_line,
+ hor_line,
+ outs=(out_rect,),
+ kind=linalg.ElementwiseKind.div,
+ indexing_maps=[
+ vert_line_bcast_map,
+ hor_line_bcast_map,
+ ident_map_2d,
+ ],
+ )
+
+ # Three memref inputs with differing shapes and element types (i1 and f32),
+ # one of them accessed through the transposed map.
+ if _ops_with_non_ident_and_transposed_input_maps := True:
+ # CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1>
+ vert_line_bools_mem = memref.alloca(
+ MemRefType.get(vert_line_shape, IntegerType.get_signless(1)),
+ [],
+ [],
+ )
+ # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<select>
+ # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$TransMap2D]], #[[$IdentMap2D]]]
+ # CHECK-SAME: ins(%[[VertLineBoolsMem]], %[[HorLineMem]], %[[TransRectMem]] : memref<8xi1>, memref<16xf32>, memref<16x8xf32>)
+ # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
+ linalg.elementwise(
+ vert_line_bools_mem,
+ hor_line_mem,
+ trans_rect_mem,
+ outs=(out_rect_mem,),
+ kind=linalg.ElementwiseKind.select,
+ indexing_maps=[
+ vert_line_bcast_map,
+ hor_line_bcast_map,
+ transposed_map_2d,
+ ident_map_2d,
+ ],
+ )
+
+ # Printing the module emits the IR text that the directive comments above match against.
+ print(module)
``````````
</details>
https://github.com/llvm/llvm-project/pull/139462
More information about the Mlir-commits
mailing list