[Mlir-commits] [mlir] [mlir][linalg] Refactor `EraseIdentityGenericOp` to be reused by other `LinalgOp`s (PR #80466)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 12 08:20:38 PST 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/80466

>From f51bb7b15be55e682be76f2289b991ed42ab4d41 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 2 Feb 2024 11:37:03 -0600
Subject: [PATCH 1/3] [mlir][linalg] Implement canonicalizer for
 `linalg::BroadcastOp` on tensors

---
 .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td  |  1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp           | 14 ++++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..11b6f50032c093 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -531,6 +531,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e86b9762d8581f..cddb0671dd58f9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1907,6 +1907,20 @@ void BroadcastOp::getEffects(
                         getDpsInits());
 }
 
+LogicalResult BroadcastOp::canonicalize(BroadcastOp op,
+                                        PatternRewriter &rewriter) {
+  // For tensor semantics, if op's input and init are same shape, it is a no op.
+  // Otherwise, with buffer semantics, the op does a copy and we don't
+  // canonicalize.
+  if (op.hasPureTensorSemantics() &&
+      (op.getInput().getType() == op.getInit().getType())) {
+    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
+    rewriter.eraseOp(op);
+    return success();
+  }
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//

>From c40476b2b7a186af7237a3fdc1599a129d65f749 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 2 Feb 2024 11:51:52 -0600
Subject: [PATCH 2/3] Add regression test

---
 mlir/test/Dialect/Linalg/canonicalize.mlir | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 052dc367ca6779..a2777a035320fa 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,3 +1017,15 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
   %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %copy : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @broadcast_same_shape(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)
+//       CHECK-NOT:   linalg.broadcast
+//       CHECK:       return %[[ARG0]] : tensor<2x3xf32>
+func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = []
+  return %0 : tensor<2x3xf32>
+}
\ No newline at end of file

>From 458e93a3a6cf1f4b28984ff9d1d7da2e9ce60a30 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 2 Feb 2024 19:24:07 -0600
Subject: [PATCH 3/3] Refactor EraseIdentityGenericOp to be reused by any
 LinalgOp

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 54 ++++++++-----------
 2 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 11b6f50032c093..272bc3116c5fdc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -531,7 +531,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
-  let hasCanonicalizeMethod = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cddb0671dd58f9..a0f02f6a7f259d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1087,24 +1087,25 @@ LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
 
-/// Remove generic operations (on tensors) that are just copying
+/// Remove any linalg operation (on tensors) that is just copying
 /// the values from inputs to the results. Requirements are
 /// 1) All iterator types are parallel
 /// 2) The body contains just a yield operation with the yielded values being
 ///    the arguments corresponding to the operands.
-struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
+template <typename OpTy>
+struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(OpTy linalgOp,
                                 PatternRewriter &rewriter) const override {
     // Check all indexing maps are identity.
-    if (llvm::any_of(genericOp.getIndexingMapsArray(),
+    if (llvm::any_of(linalgOp.getIndexingMapsArray(),
                      [](AffineMap map) { return !map.isIdentity(); }))
       return failure();
 
     // Check that the body of the linalg operation is just a linalg.yield
     // operation.
-    Block &body = genericOp.getRegion().front();
+    Block &body = linalgOp->getRegion(0).front();
     if (!llvm::hasSingleElement(body))
       return failure();
     auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
@@ -1112,18 +1113,18 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
       return failure();
 
     // In the buffer case, we need to check exact buffer equality.
-    if (genericOp.hasPureBufferSemantics()) {
-      if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
-          genericOp.getDpsInputOperand(0)->get() ==
-              genericOp.getDpsInitOperand(0)->get()) {
-        rewriter.eraseOp(genericOp);
+    if (linalgOp.hasPureBufferSemantics()) {
+      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
+          linalgOp.getDpsInputOperand(0)->get() ==
+              linalgOp.getDpsInitOperand(0)->get()) {
+        rewriter.eraseOp(linalgOp);
         return success();
       }
       return failure();
     }
 
     // Mixed semantics is not supported yet.
-    if (!genericOp.hasPureTensorSemantics())
+    if (!linalgOp.hasPureTensorSemantics())
       return failure();
 
     // Get the argument number of the returned values. That is the operand
@@ -1134,8 +1135,8 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
       if (!yieldArg || yieldArg.getOwner() != &body)
         return failure();
       unsigned argumentNumber = yieldArg.getArgNumber();
-      Value returnedArg = genericOp->getOperand(argumentNumber);
-      Type resultType = genericOp->getResult(yieldVal.index()).getType();
+      Value returnedArg = linalgOp->getOperand(argumentNumber);
+      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
       // The input can have a different type than the result, e.g. a dynamic
       // input dimension can be turned into a static output dimension.
       Type returnType = returnedArg.getType();
@@ -1145,21 +1146,21 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
         if (sparse_tensor::getSparseTensorEncoding(returnType) ||
             sparse_tensor::getSparseTensorEncoding(resultType))
           returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
-              genericOp.getLoc(), resultType, returnedArg);
+              linalgOp.getLoc(), resultType, returnedArg);
         else {
           if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                  resultType))
             return failure();
           returnedArg = rewriter.create<tensor::CastOp>(
-              genericOp.getLoc(), resultType, returnedArg);
+              linalgOp.getLoc(), resultType, returnedArg);
         }
       }
       returnedArgs.push_back(returnedArg);
     }
 
-    if (returnedArgs.size() != genericOp->getNumResults())
+    if (returnedArgs.size() != linalgOp->getNumResults())
       return failure();
-    rewriter.replaceOp(genericOp, returnedArgs);
+    rewriter.replaceOp(linalgOp, returnedArgs);
     return success();
   }
 };
@@ -1168,7 +1169,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<EraseIdentityGenericOp>(context);
+  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
 }
 
 LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
@@ -1907,18 +1908,9 @@ void BroadcastOp::getEffects(
                         getDpsInits());
 }
 
-LogicalResult BroadcastOp::canonicalize(BroadcastOp op,
-                                        PatternRewriter &rewriter) {
-  // For tensor semantics, if op's input and init are same shape, it is a no op.
-  // Otherwise, with buffer semantics, the op does a copy and we don't
-  // canonicalize.
-  if (op.hasPureTensorSemantics() &&
-      (op.getInput().getType() == op.getInit().getType())) {
-    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
-    rewriter.eraseOp(op);
-    return success();
-  }
-  return failure();
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list