[Mlir-commits] [mlir] 9fa61cb - [mlir] Insert tensor.cast only when needed when folding tensor.cast into extract_slice.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Feb 27 06:18:22 PST 2023
Author: Alexander Belyaev
Date: 2023-02-27T15:18:01+01:00
New Revision: 9fa61cbb2e8dbf00e9320145d38331b7da8d552f
URL: https://github.com/llvm/llvm-project/commit/9fa61cbb2e8dbf00e9320145d38331b7da8d552f
DIFF: https://github.com/llvm/llvm-project/commit/9fa61cbb2e8dbf00e9320145d38331b7da8d552f.diff
LOG: [mlir] Insert tensor.cast only when needed when folding tensor.cast into extract_slice.
Differential Revision: https://reviews.llvm.org/D144868
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b6359a935782..b31ef3b98ed0 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1829,7 +1829,7 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
}))
return failure();
- auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
+ auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
if (!castOp)
return failure();
@@ -1837,17 +1837,20 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
return failure();
/// Deduce the type of the result to use for the canonicalized operation.
+ Location loc = sliceOp.getLoc();
+ auto sliceOpType = sliceOp.getType();
RankedTensorType resultType =
ExtractSliceOp::inferCanonicalRankReducedResultType(
- sliceOp.getType().getRank(), sliceOp.getSourceType(),
+ sliceOpType.getRank(), sliceOp.getSourceType(),
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
sliceOp.getMixedStrides());
- Value newSlice = rewriter.create<ExtractSliceOp>(
- sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
+ Value newResult = rewriter.create<ExtractSliceOp>(
+ loc, resultType, castOp.getSource(), sliceOp.getOffsets(),
sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
- rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
- newSlice);
+ if (newResult.getType() != sliceOpType)
+ newResult = rewriter.create<CastOp>(loc, sliceOpType, newResult);
+ rewriter.replaceOp(sliceOp, newResult);
return success();
}
};
More information about the Mlir-commits mailing list