[Mlir-commits] [mlir] aeac932 - [SCF] Clean up ForOpTensorCastFolder and harden it against nop tensor casts
Benjamin Kramer
llvmlistbot at llvm.org
Wed Apr 19 07:06:09 PDT 2023
Author: Benjamin Kramer
Date: 2023-04-19T16:05:59+02:00
New Revision: aeac932943bc7e0c54903d2c4b754e3a87e90fb0
URL: https://github.com/llvm/llvm-project/commit/aeac932943bc7e0c54903d2c4b754e3a87e90fb0
DIFF: https://github.com/llvm/llvm-project/commit/aeac932943bc7e0c54903d2c4b754e3a87e90fb0.diff
LOG: [SCF] Clean up ForOpTensorCastFolder and harden it against nop tensor casts
The code was inserting a new cast, discarding it, then inserting it
again.
The self-cast issue is the root cause of #62135: the pattern would end up
dropping the loop and inserting an invalid cast to itself. As far as I
can tell, a tensor.cast with the same src and dst types is not invalid,
but it can't really be tested in isolation because it is immediately folded.
Fixes #62135
Differential Revision: https://reviews.llvm.org/D148714
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 06d4addae84f0..76b3589b15f49 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -886,9 +886,9 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
/// Perform a replacement of one iter OpOperand of an scf.for to the
/// `replacement` value which is expected to be the source of a tensor.cast.
/// tensor.cast ops are inserted inside the block to account for the type cast.
-static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
- OpOperand &operand,
- Value replacement) {
+static SmallVector<Value>
+replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
+ Value replacement) {
Type oldType = operand.get().getType(), newType = replacement.getType();
assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
"expected ranked tensor types");
@@ -897,8 +897,8 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
ForOp forOp = cast<ForOp>(operand.getOwner());
assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
"expected an iter OpOperand");
- if (operand.get().getType() == replacement.getType())
- return forOp;
+ assert(operand.get().getType() != replacement.getType() &&
+ "Expected a
diff erent type");
SmallVector<Value> newIterOperands;
for (OpOperand &opOperand : forOp.getIterOpOperands()) {
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
@@ -949,7 +949,7 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
newForOp.getLoc(), oldType, newResults[yieldIdx]);
- return newForOp;
+ return newResults;
}
/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
@@ -986,7 +986,8 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
OpOperand &iterOpOperand = std::get<0>(it);
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
- if (!incomingCast)
+ if (!incomingCast ||
+ incomingCast.getSource().getType() == incomingCast.getType())
continue;
// If the dest type of the cast does not preserve static information in
// the source type.
@@ -998,18 +999,9 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
continue;
// Create a new ForOp with that iter operand replaced.
- auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
- incomingCast.getSource());
-
- // Insert outgoing cast and use it to replace the corresponding result.
- rewriter.setInsertionPointAfter(newForOp);
- SmallVector<Value> replacements = newForOp.getResults();
- unsigned returnIdx =
- iterOpOperand.getOperandNumber() - op.getNumControlOperands();
- replacements[returnIdx] = rewriter.create<tensor::CastOp>(
- op.getLoc(), incomingCast.getDest().getType(),
- replacements[returnIdx]);
- rewriter.replaceOp(op, replacements);
+ rewriter.replaceOp(
+ op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
+ incomingCast.getSource()));
return success();
}
return failure();
More information about the Mlir-commits mailing list