[Mlir-commits] [mlir] [mlir][linalg] Implement canonicalizer for `BroadcastOp` on tensors (PR #80466)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 2 09:52:07 PST 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/80466

>From f51bb7b15be55e682be76f2289b991ed42ab4d41 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 2 Feb 2024 11:37:03 -0600
Subject: [PATCH 1/2] [mlir][linalg] Implement canonicalizer for
 `linalg::BroadcastOp` on tensors

---
 .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td  |  1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp           | 14 ++++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd0228830..11b6f50032c09 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -531,6 +531,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e86b9762d8581..cddb0671dd58f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1907,6 +1907,20 @@ void BroadcastOp::getEffects(
                         getDpsInits());
 }
 
+LogicalResult BroadcastOp::canonicalize(BroadcastOp op,
+                                        PatternRewriter &rewriter) {
+  // Fold only when the op has pure tensor semantics and the input and init
+  // types are identical: the result is then replaced by the input. In every
+  // other case (including buffer semantics) the op is left untouched.
+  if (!op.hasPureTensorSemantics() ||
+      op.getInput().getType() != op.getInit().getType())
+    return failure();
+
+  // `replaceOp` rewires all uses to the input and erases the op in one step.
+  rewriter.replaceOp(op, op.getInput());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//

>From c40476b2b7a186af7237a3fdc1599a129d65f749 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 2 Feb 2024 11:51:52 -0600
Subject: [PATCH 2/2] Add regression test

---
 mlir/test/Dialect/Linalg/canonicalize.mlir | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 052dc367ca677..a2777a035320f 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,3 +1017,15 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
   %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %copy : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @broadcast_same_shape(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)
+//       CHECK-NOT:   linalg.broadcast
+//       CHECK:       return %[[ARG0]] : tensor<2x3xf32>
+func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = []
+  return %0 : tensor<2x3xf32>
+}
\ No newline at end of file



More information about the Mlir-commits mailing list