[Mlir-commits] [mlir] [mlir][linalg] regionBuilder for transpose, broadcast (PR #69742)

Maksim Levental llvmlistbot at llvm.org
Fri Oct 20 11:49:27 PDT 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/69742

>From 55f260e883b33ab31bae22e56ec2ecb595e51886 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 20 Oct 2023 12:23:20 -0500
Subject: [PATCH] [mlir][linalg] regionBuilder for transpose, broadcast

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  | 16 +++-
 mlir/python/mlir/dialects/linalg/__init__.py  | 48 +++++++++++
 mlir/test/python/dialects/linalg/ops.py       | 79 +++++++++++++++++++
 3 files changed, 141 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 21a5e5cc47aeb5c..751edd022883011 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -442,10 +442,16 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
+    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+        mlir::ArrayRef<mlir::NamedAttribute>) {  // attribute list is unused
+      OpBuilder::InsertionGuard guard(b);  // restores b's insertion point
+      b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
+    }  // body is a single linalg::YieldOp of block argument 0
+
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
         mlir::ArrayRef<mlir::NamedAttribute>)>
       getRegionBuilder() {
-      return nullptr;
+      return regionBuilder;
     }
 
     static void createRegion(::mlir::OpBuilder &opBuilder,
@@ -510,10 +516,16 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
+    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+        mlir::ArrayRef<mlir::NamedAttribute>) {  // attribute list is unused
+      OpBuilder::InsertionGuard guard(b);  // restores b's insertion point
+      b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
+    }  // body is a single linalg::YieldOp of block argument 0
+
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
         mlir::ArrayRef<mlir::NamedAttribute>)>
       getRegionBuilder() {
-      return nullptr;
+      return regionBuilder;
     }
   }];
 
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 1353870ec7257a9..6e4cb1bd6267120 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -55,3 +55,51 @@
 #     TODO: guard against surprises and fail create Runtime Custom Ops with
 #     the same name as existing Core Named Ops.
 from .opdsl.ops.core_named_ops import *
+from .opdsl.lang.emitter import isa
+
+from ...ir import *
+from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+
+
+def transpose(
+    input: Union[Operation, OpView, Sequence[Value]],
+    *,
+    outs: List[Union[Operation, OpView, Sequence[Value]]],
+    permutation: Union[DenseI64ArrayAttr, List[int]],
+):
+    input = _get_op_result_or_value(input)
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isa(RankedTensorType, init.type) else []
+
+    op = TransposeOp(
+        result=result_types,
+        input=input,
+        init=init,
+        permutation=permutation,
+    )
+    fill_builtin_region(op.operation)
+    return op
+
+
+def broadcast(
+    input: Union[Operation, OpView, Sequence[Value]],
+    *,
+    outs: List[Union[Operation, OpView, Sequence[Value]]],
+    dimensions: Union[DenseI64ArrayAttr, List[int]],
+):
+    input = _get_op_result_or_value(input)
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isa(RankedTensorType, init.type) else []
+
+    op = BroadcastOp(
+        result=result_types,
+        input=input,
+        init=init,
+        dimensions=dimensions,
+    )
+    fill_builtin_region(op.operation)
+    return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b728e0083781492..b147551c2e73dbd 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -157,3 +157,82 @@ def pass_an_op_directly(arg0, arg1):
                 return linalg.matmul(lhs, rhs, outs=init)
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testIdentityRegionOps
+@run
+def testIdentityRegionOps():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
+            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
+            op1 = tensor.EmptyOp([1, 13], f32)
+            op2 = tensor.EmptyOp([13, 1], f32)
+            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
+            op3 = linalg.TransposeOp(
+                result=[RankedTensorType.get((13, 1), f32)],
+                input=op1,
+                init=op2,
+                permutation=[1, 0],
+            )
+            linalg.fill_builtin_region(op3.operation)
+
+            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
+            op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+            # CHECK:         func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
+            @func.FuncOp.from_py_func(
+                MemRefType.get((1, 13), f32),
+                MemRefType.get((13, 1), f32),
+            )
+            def transpose_op(op1, op2):
+                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
+                op3 = linalg.TransposeOp(
+                    result=[],
+                    input=op1,
+                    init=op2,
+                    permutation=[1, 0],
+                )
+                linalg.fill_builtin_region(op3.operation)
+                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
+                op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
+            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
+            op1 = tensor.EmptyOp([16], f32)
+            op2 = tensor.EmptyOp([16, 64], f32)
+            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
+            op3 = linalg.BroadcastOp(
+                result=[RankedTensorType.get((16, 64), f32)],
+                input=op1,
+                init=op2,
+                dimensions=[1],
+            )
+            linalg.fill_builtin_region(op3.operation)
+
+            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
+            op4 = tensor.EmptyOp([64], f32)
+            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
+            op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])
+
+            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
+            @func.FuncOp.from_py_func(
+                MemRefType.get((16,), f32),
+                MemRefType.get((16, 64), f32),
+                MemRefType.get((64,), f32),
+            )
+            def broadcast_op(op1, op2, op3):
+                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
+                op4 = linalg.BroadcastOp(
+                    result=[],
+                    input=op1,
+                    init=op2,
+                    dimensions=[1],
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
+                op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
+
+    print(module)



More information about the Mlir-commits mailing list