[Mlir-commits] [mlir] 9bd19bb - [mlir][tensor] Fix bug in utility `tensor::isCastLikeExtractSliceOp`
Christopher Bate
llvmlistbot at llvm.org
Mon Aug 28 10:17:19 PDT 2023
Author: Christopher Bate
Date: 2023-08-28T11:17:11-06:00
New Revision: 9bd19bb703a437dfdac51823f26e25e0537d8c48
URL: https://github.com/llvm/llvm-project/commit/9bd19bb703a437dfdac51823f26e25e0537d8c48
DIFF: https://github.com/llvm/llvm-project/commit/9bd19bb703a437dfdac51823f26e25e0537d8c48.diff
LOG: [mlir][tensor] Fix bug in utility `tensor::isCastLikeExtractSliceOp`
Fixes an issue where `isCastLikeExtractSliceOp` did not account for the fact
that `tensor.extract_slice` may drop source dimensions whose extent is not 1
(by taking a size-1 slice of them). This change makes the utility function
behave in line with its name/description. The only user of this function is
the `FindPayloadReplacementOpInterface` implementation for
`tensor::ExtractSliceOp`. This may cause downstream projects to see more
"listener could not find replacement op" errors when interpreting Transform
IR, but the behavior is in line with the documented conservative behavior of the
Transform dialect's TrackingListener.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D158635
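As a rough illustration of the check described in the log message, the sketch below models the fixed predicate on static shapes only. `isCastLikeExtractSliceModel` is a hypothetical helper, not an MLIR API. Upstream compares extents with `ValueBoundsConstraintSet::areEqual`, and a dynamic source extent (`getDimSize` returning `ShapedType::kDynamic`) simply fails the new `!= 1` test, so it is rejected conservatively.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, static-shape-only model of the fixed isCastLikeExtractSliceOp.
// Assumes every extent is a known integer (no dynamic sizes).
//   sourceShape - extents of the source tensor
//   resultShape - extents of the (possibly rank-reduced) result tensor
//   droppedDims - droppedDims[i] is true if source dim i is rank-reduced away
bool isCastLikeExtractSliceModel(const std::vector<int64_t> &sourceShape,
                                 const std::vector<int64_t> &resultShape,
                                 const std::vector<bool> &droppedDims) {
  assert(droppedDims.size() == sourceShape.size() && "one flag per source dim");
  std::size_t resultDim = 0;
  for (std::size_t dim = 0; dim < sourceShape.size(); ++dim) {
    if (droppedDims[dim]) {
      // The fix: dropping a dim is only cast-like if the source extent there
      // is 1. A size-1 slice of, e.g., an extent-5 dim discards data.
      if (sourceShape[dim] != 1)
        return false;
      continue;
    }
    // Every kept dim must match the next result dim exactly.
    if (resultDim >= resultShape.size() ||
        sourceShape[dim] != resultShape[resultDim])
      return false;
    ++resultDim;
  }
  // Model-only sanity check: all result dims must be accounted for.
  return resultDim == resultShape.size();
}
```

Without the `sourceShape[dim] != 1` early return, a call such as `isCastLikeExtractSliceModel({1, 5, 1, 1}, {}, {true, true, true, true})` would skip every dropped dim and return true. That is the bug this commit fixes.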
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/test/Dialect/Tensor/tracking-listener.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index bdb988bd463152..04b4de4a33a52f 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -38,7 +38,7 @@ computeTransposedType(RankedTensorType rankedTensorType,
bool isCastLikeInsertSliceOp(InsertSliceOp op);
/// A tensor.extract_slice is a cast-like operation if it merely rank-reduces
-/// the source tensor or extracts the entire source tensor.
+/// unit dimensions of the source tensor or extracts the entire source tensor.
bool isCastLikeExtractSliceOp(ExtractSliceOp op);
} // namespace tensor
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index c814c08dceb78b..24cbceb3d11791 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -98,8 +98,13 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
int64_t resultDim = 0;
// Source dims and result dims (apart from dropped dims) must have the same
// size.
- for (int64_t dim = 0; dim < op.getSourceType().getRank(); ++dim) {
+ RankedTensorType sourceType = op.getSourceType();
+ for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
if (droppedDims.test(dim)) {
+ // ExtractSlice may drop unit dimensions that result from taking a size-1
+ // slice from a non-size-1 source dimension.
+ if (sourceType.getDimSize(dim) != 1)
+ return false;
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir
index 6341b7aaad713a..5d06ca8135dbd0 100644
--- a/mlir/test/Dialect/Tensor/tracking-listener.mlir
+++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir
@@ -140,3 +140,14 @@ func.func @non_cast_like_extract_slice() {
{replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32>
return
}
+
+// -----
+
+func.func @non_cast_like_extract_slice_drop_non_unit_dim() {
+ // expected-error @below {{listener could not find replacement op}}
+ %0 = "test.foo"() {replaced} : () -> (tensor<f32>)
+ %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+ %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 1, 1, 1][1, 1, 1, 1]
+ {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<f32>
+ return
+}
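The new test slices sizes `[1, 1, 1, 1]` out of `tensor<1x5x1x1xf32>` and produces `tensor<f32>`. All four source dims are therefore rank-reduced away, including dim 1, whose source extent is 5. The `droppedDims` bit vector tested in the hunk is computed before the lines shown. The sketch below is a hypothetical, simplified model of such a mask for static slice sizes. `droppedDimsModel` is not an MLIR function, and its greedy left-to-right matching of ambiguous unit dims is an assumption that may differ from upstream.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of which source dims a rank-reducing
// tensor.extract_slice drops, for static slice sizes only. It greedily
// matches slice sizes against result dims from left to right: a size-1 slice
// dim is dropped unless the next unmatched result dim also has extent 1.
std::vector<bool> droppedDimsModel(const std::vector<int64_t> &sliceSizes,
                                   const std::vector<int64_t> &resultShape) {
  std::vector<bool> dropped(sliceSizes.size(), false);
  std::size_t resultPos = 0;
  for (std::size_t i = 0; i < sliceSizes.size(); ++i) {
    bool unitSlice = sliceSizes[i] == 1;
    bool nextResultIsUnit =
        resultPos < resultShape.size() && resultShape[resultPos] == 1;
    if (!unitSlice || nextResultIsUnit) {
      // Kept: this slice dim maps to result dim `resultPos`.
      ++resultPos;
      continue;
    }
    // Size-1 slice with no matching unit result dim: rank-reduced away.
    dropped[i] = true;
  }
  assert(resultPos == resultShape.size() && "sizes do not match result shape");
  return dropped;
}
```

The mask depends only on the slice sizes and the result shape, never on the source extents. That is why `isCastLikeExtractSliceOp` has to check `sourceType.getDimSize(dim)` separately for each dropped dim.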