[Mlir-commits] [mlir] [mlir][linalg] regionBuilder for transpose, broadcast (PR #69742)
Maksim Levental
llvmlistbot at llvm.org
Fri Oct 20 11:49:27 PDT 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/69742
>From 55f260e883b33ab31bae22e56ec2ecb595e51886 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 20 Oct 2023 12:23:20 -0500
Subject: [PATCH] [mlir][linalg] regionBuilder for transpose, broadcast
---
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 16 +++-
mlir/python/mlir/dialects/linalg/__init__.py | 48 +++++++++++
mlir/test/python/dialects/linalg/ops.py | 79 +++++++++++++++++++
3 files changed, 141 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 21a5e5cc47aeb5c..751edd022883011 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -442,10 +442,16 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+ mlir::ArrayRef<mlir::NamedAttribute> /*attrs*/) {
+ // Identity payload: yield block argument 0 unchanged, at the builder's location.
+ b.create<linalg::YieldOp>(block.getArgument(0));
+ }
+
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return nullptr;
+ return &regionBuilder;
}
static void createRegion(::mlir::OpBuilder &opBuilder,
@@ -510,10 +516,16 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+ mlir::ArrayRef<mlir::NamedAttribute> /*attrs*/) {
+ // Identity payload: yield block argument 0 unchanged, at the builder's location.
+ b.create<linalg::YieldOp>(block.getArgument(0));
+ }
+
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return nullptr;
+ return &regionBuilder;
}
}];
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 1353870ec7257a9..6e4cb1bd6267120 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -55,3 +55,51 @@
# TODO: guard against surprises and fail create Runtime Custom Ops with
# the same name as existing Core Named Ops.
from .opdsl.ops.core_named_ops import *
+from typing import List, Sequence, Union
+
+from .opdsl.lang.emitter import isa
+from ...ir import *
+from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+
+
+def _create_identity_region_op(op_class, input, outs, **attributes):
+    """Builds an `op_class` op (TransposeOp/BroadcastOp) from `input` and `outs`.
+
+    `outs` must hold exactly one init; the op gets a result of the init's type
+    only when that init is a ranked tensor (e.g. none for a memref init). The
+    op's body is then filled via `fill_builtin_region`. Raises ValueError if
+    `outs` does not have exactly one element.
+    """
+    input = _get_op_result_or_value(input)
+    # Reject an empty `outs` here too, rather than failing on `outs[0]` below.
+    if len(outs) != 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isa(RankedTensorType, init.type) else []
+    op = op_class(result=result_types, input=input, init=init, **attributes)
+    fill_builtin_region(op.operation)
+    return op
+
+
+def transpose(
+    input: Union[Operation, OpView, Sequence[Value]],
+    *,
+    outs: List[Union[Operation, OpView, Sequence[Value]]],
+    permutation: Union[DenseI64ArrayAttr, List[int]],
+):
+    """Creates a `linalg.transpose` of `input` into `outs[0]` using `permutation`."""
+    return _create_identity_region_op(
+        TransposeOp, input, outs, permutation=permutation
+    )
+
+
+def broadcast(
+    input: Union[Operation, OpView, Sequence[Value]],
+    *,
+    outs: List[Union[Operation, OpView, Sequence[Value]]],
+    dimensions: Union[DenseI64ArrayAttr, List[int]],
+):
+    """Creates a `linalg.broadcast` of `input` into `outs[0]` over `dimensions`."""
+    return _create_identity_region_op(
+        BroadcastOp, input, outs, dimensions=dimensions
+    )
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b728e0083781492..b147551c2e73dbd 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -157,3 +157,82 @@ def pass_an_op_directly(arg0, arg1):
return linalg.matmul(lhs, rhs, outs=init)
print(module)
+
+
+# CHECK-LABEL: TEST: testIdentityRegionOps
+@run
+def testIdentityRegionOps():  # op classes + fill_builtin_region vs. helpers
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
+            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
+            op1 = tensor.EmptyOp([1, 13], f32)
+            op2 = tensor.EmptyOp([13, 1], f32)
+            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
+            op3 = linalg.TransposeOp(
+                result=[RankedTensorType.get((13, 1), f32)],
+                input=op1,
+                init=op2,
+                permutation=[1, 0],
+            )
+            linalg.fill_builtin_region(op3.operation)
+
+            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
+            op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+            # CHECK: func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
+            @func.FuncOp.from_py_func(
+                MemRefType.get((1, 13), f32),
+                MemRefType.get((13, 1), f32),
+            )
+            def transpose_op(op1, op2):
+                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
+                op3 = linalg.TransposeOp(
+                    result=[],
+                    input=op1,
+                    init=op2,
+                    permutation=[1, 0],
+                )
+                linalg.fill_builtin_region(op3.operation)
+                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
+                op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
+            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
+            op1 = tensor.EmptyOp([16], f32)
+            op2 = tensor.EmptyOp([16, 64], f32)
+            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
+            op3 = linalg.BroadcastOp(
+                result=[RankedTensorType.get((16, 64), f32)],
+                input=op1,
+                init=op2,
+                dimensions=[1],
+            )
+            linalg.fill_builtin_region(op3.operation)
+
+            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
+            op4 = tensor.EmptyOp([64], f32)
+            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
+            op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])
+
+            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
+            @func.FuncOp.from_py_func(
+                MemRefType.get((16,), f32),
+                MemRefType.get((16, 64), f32),
+                MemRefType.get((64,), f32),
+            )
+            def broadcast_op(op1, op2, op3):
+                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
+                op4 = linalg.BroadcastOp(
+                    result=[],
+                    input=op1,
+                    init=op2,
+                    dimensions=[1],
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
+                op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
+
+        print(module)
More information about the Mlir-commits
mailing list