[Mlir-commits] [mlir] a969404 - [mlir][linalg] regionBuilder for transpose, broadcast (#69742)
llvmlistbot at llvm.org
Fri Oct 20 14:14:51 PDT 2023
Author: Maksim Levental
Date: 2023-10-20T16:14:46-05:00
New Revision: a9694043c9b8625fbe0d1a34bc5afadf380cda97
URL: https://github.com/llvm/llvm-project/commit/a9694043c9b8625fbe0d1a34bc5afadf380cda97
DIFF: https://github.com/llvm/llvm-project/commit/a9694043c9b8625fbe0d1a34bc5afadf380cda97.diff
LOG: [mlir][linalg] regionBuilder for transpose, broadcast (#69742)
Currently, `linalg.transpose` and `linalg.broadcast` can't be emitted
through either the C API or the Python bindings (which, of course, go
through the C API). See
https://discourse.llvm.org/t/how-to-build-linalg-transposeop-in-mlir-pybind/73989/10.
The reason is that even though they're named ops, there is no opdsl
`@linalg_structured_op` for them, and thus, while they can be
instantiated, they cannot be passed to
[`mlirLinalgFillBuiltinNamedOpRegion`](https://github.com/llvm/llvm-project/blob/a7cccb9cbb2b9954684cbea37615303a59719973/mlir/lib/CAPI/Dialect/Linalg.cpp#L18).
I believe the issue is that each would need an `IndexAttrDef` (for
`permutation` and `dimensions`, respectively), but an `IndexAttrDef`
cannot represent dynamic rank. Note: if I'm mistaken and there is a way
to write the `@linalg_structured_op`, let me know.
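To illustrate the limitation, here is a hedged sketch of what an opdsl
definition would have to look like; `transpose_rank2` is hypothetical
and does not exist in `core_named_ops.py`. The subscripts pin the
iteration space to exactly rank 2 (an `IndexAttrDef` would likewise fix
its size at definition time), so no single `@linalg_structured_op` can
model the rank-polymorphic `permutation` of `linalg.transpose`:

```python
# Hypothetical sketch: a transpose specialized to rank 2. A separate
# definition would be needed for every rank and permutation, which is
# exactly what linalg.transpose avoids with its dynamic `permutation`.
from mlir.dialects.linalg.opdsl.lang import *


@linalg_structured_op
def transpose_rank2(
    I=TensorDef(T1, S.H, S.W),
    O=TensorDef(T1, S.W, S.H, output=True),
):
    domain(D.h, D.w)
    O[D.w, D.h] = I[D.h, D.w]
```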
The solution here simply implements the `regionBuilder` interface,
which is then picked up by
[`LinalgDialect::addNamedOpBuilders`](https://github.com/llvm/llvm-project/blob/7557530f428a2f226d8d925c33d527dfcfdcb0c5/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp#L116).
Extension functions that mirror the API of the `@linalg_structured_op`s
are added "by hand". Note: the extensions are added to
`dialects/linalg/__init__.py` instead of
`dialects/linalg/opdsl/ops/core_named_ops.py` so that they're not
confused with opdsl generators/emitters.
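With the region builders in place, emitting these ops from Python is a
one-liner. A minimal usage sketch, adapted from the test added below
(assuming the in-tree `mlir` Python package, whose `Context` has the
dialects registered):

```python
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
from mlir.dialects import linalg, tensor

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
        # linalg.transpose: swap the two dimensions of a 1x13 tensor.
        src = tensor.EmptyOp([1, 13], f32)
        init = tensor.EmptyOp([13, 1], f32)
        linalg.transpose(src, outs=[init], permutation=[1, 0])
        # linalg.broadcast: expand a 16-vector along a new dimension 1.
        vec = tensor.EmptyOp([16], f32)
        out = tensor.EmptyOp([16, 64], f32)
        linalg.broadcast(vec, outs=[out], dimensions=[1])
    print(module)
```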
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/python/mlir/dialects/linalg/__init__.py
mlir/test/python/dialects/linalg/ops.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 21a5e5cc47aeb5c..751edd022883011 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -442,10 +442,16 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+ mlir::ArrayRef<mlir::NamedAttribute>) {
+ OpBuilder::InsertionGuard guard(b);
+ b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
+ }
+
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return nullptr;
+ return regionBuilder;
}
static void createRegion(::mlir::OpBuilder &opBuilder,
@@ -510,10 +516,16 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+ mlir::ArrayRef<mlir::NamedAttribute>) {
+ OpBuilder::InsertionGuard guard(b);
+ b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
+ }
+
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return nullptr;
+ return regionBuilder;
}
}];
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 1353870ec7257a9..6e4cb1bd6267120 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -55,3 +55,51 @@
# TODO: guard against surprises and fail create Runtime Custom Ops with
# the same name as existing Core Named Ops.
from .opdsl.ops.core_named_ops import *
+from .opdsl.lang.emitter import isa
+
+from ...ir import *
+from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+
+
+def transpose(
+ input: Union[Operation, OpView, Sequence[Value]],
+ *,
+ outs: List[Union[Operation, OpView, Sequence[Value]]],
+ permutation: Union[DenseI64ArrayAttr, List[int]],
+):
+ input = _get_op_result_or_value(input)
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isa(RankedTensorType, init.type) else []
+
+ op = TransposeOp(
+ result=result_types,
+ input=input,
+ init=init,
+ permutation=permutation,
+ )
+ fill_builtin_region(op.operation)
+ return op
+
+
+def broadcast(
+ input: Union[Operation, OpView, Sequence[Value]],
+ *,
+ outs: List[Union[Operation, OpView, Sequence[Value]]],
+ dimensions: Union[DenseI64ArrayAttr, List[int]],
+):
+ input = _get_op_result_or_value(input)
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isa(RankedTensorType, init.type) else []
+
+ op = BroadcastOp(
+ result=result_types,
+ input=input,
+ init=init,
+ dimensions=dimensions,
+ )
+ fill_builtin_region(op.operation)
+ return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b728e0083781492..b147551c2e73dbd 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -157,3 +157,82 @@ def pass_an_op_directly(arg0, arg1):
return linalg.matmul(lhs, rhs, outs=init)
print(module)
+
+
+# CHECK-LABEL: TEST: testIdentityRegionOps
+ at run
+def testIdentityRegionOps():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
+ # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
+ op1 = tensor.EmptyOp([1, 13], f32)
+ op2 = tensor.EmptyOp([13, 1], f32)
+ # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
+ op3 = linalg.TransposeOp(
+ result=[RankedTensorType.get((13, 1), f32)],
+ input=op1,
+ init=op2,
+ permutation=[1, 0],
+ )
+ linalg.fill_builtin_region(op3.operation)
+
+ # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
+ op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+ # CHECK: func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
+ @func.FuncOp.from_py_func(
+ MemRefType.get((1, 13), f32),
+ MemRefType.get((13, 1), f32),
+ )
+ def transpose_op(op1, op2):
+ # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
+ op3 = linalg.TransposeOp(
+ result=[],
+ input=op1,
+ init=op2,
+ permutation=[1, 0],
+ )
+ linalg.fill_builtin_region(op3.operation)
+ # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
+ op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+ # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
+ # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
+ op1 = tensor.EmptyOp([16], f32)
+ op2 = tensor.EmptyOp([16, 64], f32)
+ # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
+ op3 = linalg.BroadcastOp(
+ result=[RankedTensorType.get((16, 64), f32)],
+ input=op1,
+ init=op2,
+ dimensions=[1],
+ )
+ linalg.fill_builtin_region(op3.operation)
+
+ # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
+ op4 = tensor.EmptyOp([64], f32)
+ # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
+ op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])
+
+ # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
+ @func.FuncOp.from_py_func(
+ MemRefType.get((16,), f32),
+ MemRefType.get((16, 64), f32),
+ MemRefType.get((64,), f32),
+ )
+ def broadcast_op(op1, op2, op3):
+ # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
+ op4 = linalg.BroadcastOp(
+ result=[],
+ input=op1,
+ init=op2,
+ dimensions=[1],
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
+ op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
+
+ print(module)