[Mlir-commits] [mlir] 34cf67a - [mlir][tensor] TrackingListener: Find replacement ops through cast-like ExtractSliceOps

Matthias Springer llvmlistbot at llvm.org
Thu Jun 1 00:07:01 PDT 2023


Author: Matthias Springer
Date: 2023-06-01T09:00:56+02:00
New Revision: 34cf67aef5a3655b57e52842a1bb4913295076e4

URL: https://github.com/llvm/llvm-project/commit/34cf67aef5a3655b57e52842a1bb4913295076e4
DIFF: https://github.com/llvm/llvm-project/commit/34cf67aef5a3655b57e52842a1bb4913295076e4.diff

LOG: [mlir][tensor] TrackingListener: Find replacement ops through cast-like ExtractSliceOps

ExtractSliceOps that extract all elements of the source tensor are treated like casts when looking for replacement ops. Such ExtractSliceOps are typically rank reductions.

Differential Revision: https://reviews.llvm.org/D151804

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
    mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
    mlir/lib/Dialect/Tensor/Utils/Utils.cpp
    mlir/test/Dialect/Tensor/tracking-listener.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index a037d40f901b..c610b5d0f737 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -47,6 +47,10 @@ computeTransposedType(RankedTensorType rankedTensorType,
 /// the same shape.
 bool isCastLikeInsertSliceOp(InsertSliceOp op);
 
+/// A tensor.extract_slice is a cast-like operation if it merely rank-reduces
+/// the source tensor or extracts the entire source tensor.
+bool isCastLikeExtractSliceOp(ExtractSliceOp op);
+
 } // namespace tensor
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 9b609a2f55f4..09a6b5049955 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -38,8 +38,6 @@ tensor::TrackingListener::findReplacementOp(Operation *op,
       return nullptr;
 
     // Skip cast-like operations.
-    // TODO: CastOpInterface could be used if CollapseShapeOp and ExpandShapeOp
-    // implement that interface
     values.clear();
     llvm::TypeSwitch<Operation *>(defOp)
         .Case<CastOp>([&](CastOp op) { values.push_back(op.getSource()); })
@@ -53,6 +51,10 @@ tensor::TrackingListener::findReplacementOp(Operation *op,
           if (isCastLikeInsertSliceOp(op))
             values.push_back(op.getSource());
         })
+        .Case<ExtractSliceOp>([&](ExtractSliceOp op) {
+          if (isCastLikeExtractSliceOp(op))
+            values.push_back(op.getSource());
+        })
         .Default([](Operation *op) {});
   } while (!values.empty());
 

diff  --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 165cf9b0b2f7..4d5404a3be2d 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -123,3 +123,30 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
 
   return true;
 }
+
+// An ExtractSliceOp is cast-like if it extracts every element of its source:
+// it may drop only static unit dims, and every other source dim must match
+// the size of the corresponding result dim.
+bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
+  llvm::SmallBitVector droppedDims = op.getDroppedDims();
+  RankedTensorType sourceType = op.getSourceType();
+  int64_t resultDim = 0;
+  for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
+    if (droppedDims.test(dim)) {
+      // ExtractSliceOp may also drop a unit dim produced by taking a size-1
+      // slice of a non-size-1 source dim. That discards source elements, so it
+      // is not cast-like. Dynamic source dims are conservatively rejected.
+      if (sourceType.getDimSize(dim) != 1)
+        return false;
+      continue;
+    }
+    // Non-dropped dims must provably have the same size in source and result.
+    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+        op.getSource(), op.getResult(), dim, resultDim);
+    if (failed(equalDimSize) || !*equalDimSize)
+      return false;
+    ++resultDim;
+  }
+
+  return true;
+}

diff  --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir
index 369dcec45e3a..6341b7aaad71 100644
--- a/mlir/test/Dialect/Tensor/tracking-listener.mlir
+++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir
@@ -105,3 +105,49 @@ func.func @cast_like_insert_slice_dynamic(
       {replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
   return
 }
+
+// -----
+
+func.func @cast_like_extract_slice() {
+  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+  // expected-remark @below {{replacement found}}
+  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 5, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<5xf32>
+  return
+}
+
+// -----
+
+func.func @cast_like_extract_slice_dynamic() {
+  %0 = "test.foo"() {replaced} : () -> (tensor<?xf32>)
+  // expected-remark @below {{replacement found}}
+  %1 = "test.foo"() : () -> (tensor<1x?x1x1xf32>)
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %1, %c1 : tensor<1x?x1x1xf32>
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, %dim, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x?x1x1xf32> to tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @non_cast_like_extract_slice() {
+  // expected-error @below {{listener could not find replacement op}}
+  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 3, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @non_cast_like_extract_slice_dropped_dim() {
+  // expected-error @below {{listener could not find replacement op}}
+  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+  %1 = "test.foo"() : () -> (tensor<3x5xf32>)
+  %2 = tensor.extract_slice %1[1, 0][1, 5][1, 1]
+      {replacement_0 = 0} : tensor<3x5xf32> to tensor<5xf32>
+  return
+}


        


More information about the Mlir-commits mailing list