[Mlir-commits] [mlir] f2f8975 - [mlir][linalg] Use explicit replace in canonicalization pattern (NFC).
llvmlistbot at llvm.org
Mon Mar 14 00:17:36 PDT 2022
Author: gysit
Date: 2022-03-14T07:09:51Z
New Revision: f2f89751e44a626e9e360826283d68b1b71b868c
URL: https://github.com/llvm/llvm-project/commit/f2f89751e44a626e9e360826283d68b1b71b868c
DIFF: https://github.com/llvm/llvm-project/commit/f2f89751e44a626e9e360826283d68b1b71b868c.diff
LOG: [mlir][linalg] Use explicit replace in canonicalization pattern (NFC).
Introduce an explicit `replaceOp` call to enable the tracking of the producer LinalgOp.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D121369
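The commit message's point about tracking can be illustrated with a toy rewriter. This is a minimal sketch, not MLIR code: `ToyOp`, `TrackingListener`, `ToyRewriter`, and their members are hypothetical. It assumes, as in MLIR's rewriter around this revision, that an unconditional `replaceOp` reports the replaced op to a listener hook. It also assumes that a conditional `replaceOpWithIf` only redirects the selected uses, without notifying the listener or erasing the producer.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not MLIR): an op with a single result, where `users`
// lists the ops that currently consume that result.
struct ToyOp {
  explicit ToyOp(std::string n) : name(std::move(n)) {}
  std::string name;
  std::vector<ToyOp *> users;
  bool erased = false;
};

// Hypothetical listener that records (replaced op, replacement op) pairs,
// loosely modeled on a rewriter's "root replaced" notification.
struct TrackingListener {
  std::vector<std::pair<ToyOp *, ToyOp *>> replacements;
  void notifyReplaced(ToyOp *oldOp, ToyOp *newOp) {
    replacements.emplace_back(oldOp, newOp);
  }
};

struct ToyRewriter {
  TrackingListener *listener = nullptr;

  // Unconditional replacement: notify the listener, move every use to
  // `newOp`, and mark `oldOp` as erased.
  void replaceOp(ToyOp *oldOp, ToyOp *newOp) {
    if (listener)
      listener->notifyReplaced(oldOp, newOp);
    for (ToyOp *user : oldOp->users)
      newOp->users.push_back(user);
    oldOp->users.clear();
    oldOp->erased = true;
  }

  // Conditional replacement: move only the uses for which `shouldReplace`
  // returns true. No notification is sent and `oldOp` is not erased.
  void replaceOpWithIf(ToyOp *oldOp, ToyOp *newOp,
                       const std::function<bool(ToyOp *)> &shouldReplace) {
    std::vector<ToyOp *> kept;
    for (ToyOp *user : oldOp->users) {
      if (shouldReplace(user))
        newOp->users.push_back(user);
      else
        kept.push_back(user);
    }
    oldOp->users = kept;
  }
};
```

In this model, a tracker that maps old ops to their replacements only learns about the cloned producer when the replacement goes through `replaceOp`. The conditional path leaves the listener empty and the original producer alive.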
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1e8c29f58417d..becf321456248 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1642,20 +1642,13 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
resultTypes[resultNumber] = resultType;
Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
- if (!resultValue.hasOneUse()) {
- SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
- // Create a tensor.cast operation back to the original type.
- Value castBack = rewriter.create<tensor::CastOp>(
- loc, resultValue.getType(), newOp->getResult(resultNumber));
- results[resultNumber] = castBack;
- // Replace all uses except the use in the cast op that is matched by the
- // pattern. Note that this cast is from a more static shape to a more
- // dynamic shape. These are expected to be pulled into their consumers.
- rewriter.replaceOpWithIf(linalgOp, results,
- [&castOp](OpOperand &use) -> bool {
- return use.getOwner() != castOp.getOperation();
- });
- }
+ // Create a tensor.cast operation back to the original type.
+ Value castBack = rewriter.create<tensor::CastOp>(
+ loc, resultValue.getType(), newOp->getResult(resultNumber));
+
+ SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
+ results[resultNumber] = castBack;
+ rewriter.replaceOp(linalgOp, results);
rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
return success();
}
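The effect of the new code on uses can be traced with a toy def-use model. This is a hypothetical sketch, not MLIR's API: `ToyGraph`, `makeGraph`, `foldCastConsumer`, and the value names are illustrative. It assumes the matched `tensor.cast` makes the type more static (here `tensor<?xf32>` to `tensor<4xf32>`). It then mirrors the order of the two `replaceOp` calls: the producer is replaced first, so every old use, including the matched cast, reads `castBack`. Only after that is the cast replaced by the new, more static result.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical toy model of SSA uses (not MLIR): each user op, by name,
// records the value it reads, and each value, by name, has a type string.
struct ToyGraph {
  std::map<std::string, std::string> operandOf; // user op -> value it reads
  std::map<std::string, std::string> typeOf;    // value -> type

  // Like replaceAllUsesWith: every user of `from` now reads `to`.
  void replaceAllUses(const std::string &from, const std::string &to) {
    for (auto &entry : operandOf)
      if (entry.second == from)
        entry.second = to;
  }

  bool hasUses(const std::string &value) const {
    for (const auto &entry : operandOf)
      if (entry.second == value)
        return true;
    return false;
  }
};

// Builds: "old" (producer result, dynamic type) read by "castOp", whose
// result "cast" (static type) is read by "consumer"; optionally "otherUser"
// also reads "old".
ToyGraph makeGraph(bool withOtherUser) {
  ToyGraph g;
  g.typeOf["old"] = "tensor<?xf32>";
  g.typeOf["cast"] = "tensor<4xf32>";
  g.operandOf["castOp"] = "old";
  g.operandOf["consumer"] = "cast";
  if (withOtherUser)
    g.operandOf["otherUser"] = "old";
  return g;
}

// Mirrors the new code path: always create castBack, replace the producer
// unconditionally, then replace (and erase) the matched cast.
void foldCastConsumer(ToyGraph &g) {
  g.typeOf["new"] = g.typeOf["cast"];     // the clone has the static type
  g.typeOf["castBack"] = g.typeOf["old"]; // cast back to the original type
  g.operandOf["castBackOp"] = "new";
  g.replaceAllUses("old", "castBack");    // replaceOp(linalgOp, results)
  g.replaceAllUses("cast", "new");        // replaceOp(castOp, new result)
  g.operandOf.erase("castOp");            // replaceOp also erases castOp
}
```

When the producer's result had no users other than the matched cast, `castBack` ends up unused after the rewrite. In a greedy rewrite driver such trivially dead ops are typically cleaned up afterwards.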