[Mlir-commits] [mlir] f2f8975 - [mlir][linalg] Use explicit replace in canonicalization pattern (NFC).

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 14 00:17:36 PDT 2022


Author: gysit
Date: 2022-03-14T07:09:51Z
New Revision: f2f89751e44a626e9e360826283d68b1b71b868c

URL: https://github.com/llvm/llvm-project/commit/f2f89751e44a626e9e360826283d68b1b71b868c
DIFF: https://github.com/llvm/llvm-project/commit/f2f89751e44a626e9e360826283d68b1b71b868c.diff

LOG: [mlir][linalg] Use explicit replace in canonicalization pattern (NFC).

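Introduce an explicit, unconditional `replaceOp` call on the producer LinalgOp, in place of the `replaceOpWithIf` call that was only made when the result had more than one use, to enable the tracking of the producer LinalgOp.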

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D121369

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1e8c29f58417d..becf321456248 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1642,20 +1642,13 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
     resultTypes[resultNumber] = resultType;
     Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
 
-    if (!resultValue.hasOneUse()) {
-      SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
-      // Create a tensor.cast operation back to the original type.
-      Value castBack = rewriter.create<tensor::CastOp>(
-          loc, resultValue.getType(), newOp->getResult(resultNumber));
-      results[resultNumber] = castBack;
-      // Replace all uses except the use in the cast op that is matched by the
-      // pattern. Note that this cast is from a more static shape to a more
-      // dynamic shape. These are expected to be pulled into their consumers.
-      rewriter.replaceOpWithIf(linalgOp, results,
-                               [&castOp](OpOperand &use) -> bool {
-                                 return use.getOwner() != castOp.getOperation();
-                               });
-    }
+    // Create a tensor.cast operation back to the original type.
+    Value castBack = rewriter.create<tensor::CastOp>(
+        loc, resultValue.getType(), newOp->getResult(resultNumber));
+
+    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
+    results[resultNumber] = castBack;
+    rewriter.replaceOp(linalgOp, results);
     rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
     return success();
   }


        


More information about the Mlir-commits mailing list