[Mlir-commits] [mlir] 34cf67a - [mlir][tensor] TrackingListener: Find replacement ops through cast-like ExtractSliceOps
Matthias Springer
llvmlistbot at llvm.org
Thu Jun 1 00:07:01 PDT 2023
Author: Matthias Springer
Date: 2023-06-01T09:00:56+02:00
New Revision: 34cf67aef5a3655b57e52842a1bb4913295076e4
URL: https://github.com/llvm/llvm-project/commit/34cf67aef5a3655b57e52842a1bb4913295076e4
DIFF: https://github.com/llvm/llvm-project/commit/34cf67aef5a3655b57e52842a1bb4913295076e4.diff
LOG: [mlir][tensor] TrackingListener: Find replacement ops through cast-like ExtractSliceOps
Certain ExtractSliceOps that extract all elements from the source tensor are treated like casts when looking for replacement ops. Such ExtractSliceOps are typically rank reductions that drop only unit dimensions.
Differential Revision: https://reviews.llvm.org/D151804
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/test/Dialect/Tensor/tracking-listener.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index a037d40f901b..c610b5d0f737 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -47,6 +47,10 @@ computeTransposedType(RankedTensorType rankedTensorType,
/// the same shape.
bool isCastLikeInsertSliceOp(InsertSliceOp op);
+/// A tensor.extract_slice is a cast-like operation if it extracts all elements
+/// of the source tensor. Such an op may be rank-reducing, as long as it only
+/// drops unit dimensions of the source.
+bool isCastLikeExtractSliceOp(ExtractSliceOp op);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 9b609a2f55f4..09a6b5049955 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -38,8 +38,6 @@ tensor::TrackingListener::findReplacementOp(Operation *op,
return nullptr;
// Skip cast-like operations.
- // TODO: CastOpInterface could be used if CollapseShapeOp and ExpandShapeOp
- // implement that interface
values.clear();
llvm::TypeSwitch<Operation *>(defOp)
.Case<CastOp>([&](CastOp op) { values.push_back(op.getSource()); })
@@ -53,6 +51,10 @@ tensor::TrackingListener::findReplacementOp(Operation *op,
if (isCastLikeInsertSliceOp(op))
values.push_back(op.getSource());
})
+ .Case<ExtractSliceOp>([&](ExtractSliceOp op) {
+ if (isCastLikeExtractSliceOp(op))
+ values.push_back(op.getSource());
+ })
.Default([](Operation *op) {});
} while (!values.empty());
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 165cf9b0b2f7..4d5404a3be2d 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -123,3 +123,22 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
return true;
}
+
+/// Returns true if `op` extracts every element of its source tensor, i.e., it
+/// behaves like a cast. The op may be rank-reducing, but only dimensions that
+/// are statically 1 in the source may be dropped; all other source dimensions
+/// must be taken in full.
+bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
+  llvm::SmallBitVector droppedDims = op.getDroppedDims();
+  RankedTensorType sourceType = op.getSourceType();
+  int64_t resultDim = 0;
+  // Source dims and result dims (apart from dropped dims) must have the same
+  // size.
+  for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
+    if (droppedDims.test(dim)) {
+      // A dropped dimension is sliced with size 1. Unless the source dimension
+      // is (statically) 1 as well, elements along it are discarded and the op
+      // does not behave like a cast.
+      if (sourceType.getDimSize(dim) != 1)
+        return false;
+      continue;
+    }
+    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+        op.getSource(), op.getResult(), dim, resultDim);
+    if (failed(equalDimSize) || !*equalDimSize)
+      return false;
+    ++resultDim;
+  }
+
+  return true;
+}
diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir
index 369dcec45e3a..6341b7aaad71 100644
--- a/mlir/test/Dialect/Tensor/tracking-listener.mlir
+++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir
@@ -105,3 +105,38 @@ func.func @cast_like_insert_slice_dynamic(
{replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
return
}
+
+// -----
+
+// The extract_slice keeps dim 1 of %1 in full and drops only unit dims, so it
+// extracts every element of %1. It is expected to be looked through, and the
+// defining op of %1 is reported as the replacement.
+func.func @cast_like_extract_slice() {
+ %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+ // expected-remark @below {{replacement found}}
+ %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+ %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 5, 1, 1][1, 1, 1, 1]
+ {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<5xf32>
+ return
+}
+
+// -----
+
+// Dynamic variant of @cast_like_extract_slice. The slice size of dim 1 is the
+// tensor.dim of %1 along that dim, so the extract_slice still extracts every
+// element of %1 and is expected to be looked through.
+func.func @cast_like_extract_slice_dynamic() {
+ %0 = "test.foo"() {replaced} : () -> (tensor<?xf32>)
+ // expected-remark @below {{replacement found}}
+ %1 = "test.foo"() : () -> (tensor<1x?x1x1xf32>)
+ %c1 = arith.constant 1 : index
+ %dim = tensor.dim %1, %c1 : tensor<1x?x1x1xf32>
+ %2 = tensor.extract_slice %1[0, 0, 0, 0][1, %dim, 1, 1][1, 1, 1, 1]
+ {replacement_0 = 0} : tensor<1x?x1x1xf32> to tensor<?xf32>
+ return
+}
+
+// -----
+
+// The extract_slice takes only 3 of the 5 elements along dim 1, so it is not
+// cast-like. The listener must not look through it, and no replacement op is
+// expected to be found.
+func.func @non_cast_like_extract_slice() {
+ // expected-error @below {{listener could not find replacement op}}
+ %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+ %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+ %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 3, 1, 1][1, 1, 1, 1]
+ {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32>
+ return
+}
More information about the Mlir-commits
mailing list