[Mlir-commits] [mlir] [mlir] Replace dynamic sizes in insert_slice of tensor.cast canonicalization (PR #91352)

llvmlistbot at llvm.org
Wed May 8 07:03:25 PDT 2024


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/91352

From 953c3e4cb195ba6624501d3bd284ebc2e61d74ff Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 7 May 2024 10:44:46 -0400
Subject: [PATCH 1/3] [mlir] Replace dynamic sizes in insert_slice of
 tensor.cast canonicalization

---
 mlir/include/mlir/IR/BuiltinTypes.h      |  8 ++++++-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 29 +++++++++++++++++++++---
 mlir/lib/IR/BuiltinTypes.cpp             | 24 ++++++++++----------
 3 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 2361cf1371237..5579b138668d2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -360,9 +360,15 @@ class VectorType::Builder {
 /// which dimensions must be kept when e.g. computing MemRef strides under
 /// rank-reducing operations. Return std::nullopt if reducedShape cannot be
 /// obtained by dropping only `1` entries in `originalShape`.
+/// If `matchDynamic` is true, a dynamic dim in `originalShape` or
+/// `reducedShape` is considered to match any dim in the other shape, unless
+/// the `originalShape` entry is a static 1. For example, for
+/// ([1, 3, ?], [?, 5]) the returned mask would be {0} (only the unit dim is
+/// dropped), since 3 and 5 match the corresponding dynamic dims.
 std::optional<llvm::SmallDenseSet<unsigned>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
-                         ArrayRef<int64_t> reducedShape);
+                         ArrayRef<int64_t> reducedShape,
+                         bool matchDynamic = false);
 
 /// Enum that captures information related to verifier error conditions on
 /// slice insert/extract type of ops.
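
To make the new `matchDynamic` semantics concrete, here is a small usage
sketch (not part of the patch; it assumes the MLIR headers and the signature
declared above):

#include "mlir/IR/BuiltinTypes.h"
#include <cassert>
#include <cstdint>

void rankReductionMaskExample() {
  // originalShape = [1, 3, ?], reducedShape = [?, 5].
  const int64_t kDyn = mlir::ShapedType::kDynamic;
  int64_t original[] = {1, 3, kDyn};
  int64_t reduced[] = {kDyn, 5};
  // With matchDynamic, the unit dim 0 is dropped, 3 matches the leading `?`
  // of the reduced shape, and the trailing `?` matches 5.
  auto mask = mlir::computeRankReductionMask(original, reduced,
                                             /*matchDynamic=*/true);
  assert(mask && mask->size() == 1 && mask->contains(0));
  // Without matchDynamic, 3 vs ? and ? vs 5 do not match, so no mask exists.
  assert(!mlir::computeRankReductionMask(original, reduced));
}
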
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4c65045084dc5..d560c11464f1c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2711,15 +2711,38 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
     auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
     if (!srcType || !dstType)
       return failure();
+
+    // The tensor.cast source may carry static size information that the
+    // insert_slice static sizes lack, so dynamic dims are allowed to match
+    // static dims when computing the rank reduction mask.
+    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
+    auto rankReductionMask = computeRankReductionMask(
+        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
+    if (!rankReductionMask.has_value())
+      return failure();
+    // Replace dimensions in the insert slice op with corresponding static dims
+    // from the cast source type. If the insert slice sizes have static dims
+    // that are not static in the tensor.cast source (i.e., when the cast op
+    // casts a dynamic dim to static), the dim should not be replaced, and the
+    // pattern will fail later in `verifyInsertSliceOp`.
+    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+    int64_t rankReducedIdx = 0;
+    for (auto [idx, size] : enumerate(staticSizes)) {
+      if (!rankReductionMask.value().contains(idx) &&
+          !srcType.isDynamicDim(rankReducedIdx)) {
+        mixedSizes[idx] = getAsIndexOpFoldResult(
+            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
+        size = srcType.getDimSize(rankReducedIdx++);
+      }
+    }
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
-                            insertSliceOp.getStaticSizes(),
-                            insertSliceOp.getStaticStrides()) !=
+                            staticSizes, insertSliceOp.getStaticStrides()) !=
         SliceVerificationResult::Success)
       return failure();
 
     Operation *replacement = rewriter.create<InsertOpTy>(
         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
-        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        mixedSizes, insertSliceOp.getMixedStrides());
 
     // In the parallel case there is no result and so nothing to cast.
     bool isParallelInsert =
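
As a standalone illustration of the size-replacement loop above, here is a
sketch using plain C++ containers in place of the MLIR types (`kDyn` stands in
for `ShapedType::kDynamic`, and the helper name is made up for this example):

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

constexpr int64_t kDyn = INT64_MIN; // stand-in for ShapedType::kDynamic

// For each insert_slice size that has a counterpart in the cast source shape
// (i.e., is not in the rank-reduction mask `dropped`), take the source dim
// when it is static. As in the loop above, `rankReducedIdx` only advances
// when a static source dim is consumed. Assumes the mask invariant from
// computeRankReductionMask: #non-dropped entries == srcShape.size().
std::vector<int64_t> replaceStaticSizes(std::vector<int64_t> staticSizes,
                                        const std::vector<int64_t> &srcShape,
                                        const std::set<unsigned> &dropped) {
  unsigned rankReducedIdx = 0;
  for (unsigned idx = 0; idx < staticSizes.size(); ++idx) {
    if (!dropped.count(idx) && srcShape[rankReducedIdx] != kDyn)
      staticSizes[idx] = srcShape[rankReducedIdx++];
  }
  return staticSizes;
}

int main() {
  // Mirrors the updated lit test: tensor<1x?xf32> cast to tensor<?x?xf32>,
  // then inserted with fully dynamic sizes [?, ?]. The unit dim becomes a
  // static 1; the dynamic source dim is left alone.
  std::vector<int64_t> result =
      replaceStaticSizes({kDyn, kDyn}, {1, kDyn}, /*dropped=*/{});
  assert(result[0] == 1 && result[1] == kDyn);
  return 0;
}
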
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a2738946de410..179797cb943a1 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -408,24 +408,24 @@ unsigned BaseMemRefType::getMemorySpaceAsInt() const {
 // MemRefType
 //===----------------------------------------------------------------------===//
 
-/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
-/// `originalShape` with some `1` entries erased, return the set of indices
-/// that specifies which of the entries of `originalShape` are dropped to obtain
-/// `reducedShape`. The returned mask can be applied as a projection to
-/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
-/// which dimensions must be kept when e.g. compute MemRef strides under
-/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
-/// obtained by dropping only `1` entries in `originalShape`.
 std::optional<llvm::SmallDenseSet<unsigned>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
-                               ArrayRef<int64_t> reducedShape) {
+                               ArrayRef<int64_t> reducedShape,
+                               bool matchDynamic) {
   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
   llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     // Greedily insert `originalIdx` if match.
-    if (reducedIdx < reducedRank &&
-        originalShape[originalIdx] == reducedShape[reducedIdx]) {
+    int64_t origSize = originalShape[originalIdx];
+    // If `matchDynamic`, count dynamic dims as a match, unless `origSize` is 1.
+    if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
+        (ShapedType::isDynamic(reducedShape[reducedIdx]) ||
+         ShapedType::isDynamic(origSize))) {
+      reducedIdx++;
+      continue;
+    }
+    if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
       reducedIdx++;
       continue;
     }
@@ -433,7 +433,7 @@ mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
     unusedDims.insert(originalIdx);
     // If no match on `originalIdx`, the `originalShape` at this dimension
     // must be 1, otherwise we bail.
-    if (originalShape[originalIdx] != 1)
+    if (origSize != 1)
       return std::nullopt;
   }
   // The whole reducedShape must be scanned, otherwise we bail.
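
For readers without an MLIR build handy, here is a self-contained
re-implementation of the matching loop above (same logic with standard
containers; `kDyn` stands in for `ShapedType::kDynamic`):

#include <cassert>
#include <cstdint>
#include <optional>
#include <set>
#include <vector>

constexpr int64_t kDyn = INT64_MIN; // stand-in for ShapedType::kDynamic

std::optional<std::set<unsigned>>
rankReductionMask(const std::vector<int64_t> &originalShape,
                  const std::vector<int64_t> &reducedShape,
                  bool matchDynamic = false) {
  size_t reducedRank = reducedShape.size();
  std::set<unsigned> unusedDims;
  unsigned reducedIdx = 0;
  for (unsigned originalIdx = 0; originalIdx < originalShape.size();
       ++originalIdx) {
    int64_t origSize = originalShape[originalIdx];
    // If `matchDynamic`, a dynamic dim on either side matches, unless
    // `origSize` is 1 (a unit dim stays a rank-reduction candidate).
    if (matchDynamic && reducedIdx < reducedRank && origSize != 1 &&
        (reducedShape[reducedIdx] == kDyn || origSize == kDyn)) {
      ++reducedIdx;
      continue;
    }
    if (reducedIdx < reducedRank && origSize == reducedShape[reducedIdx]) {
      ++reducedIdx;
      continue;
    }
    // An unmatched dim must be a static 1, otherwise bail.
    unusedDims.insert(originalIdx);
    if (origSize != 1)
      return std::nullopt;
  }
  // The whole reducedShape must be consumed, otherwise bail.
  if (reducedIdx != reducedRank)
    return std::nullopt;
  return unusedDims;
}

int main() {
  // The documented example: ([1, 3, ?], [?, 5]) with matchDynamic drops
  // only dim 0.
  auto mask = rankReductionMask({1, 3, kDyn}, {kDyn, 5}, /*matchDynamic=*/true);
  assert(mask && *mask == std::set<unsigned>{0});
  // Without matchDynamic, 3 vs ? and ? vs 5 fail to match, so no mask.
  assert(!rankReductionMask({1, 3, kDyn}, {kDyn, 5}));
  return 0;
}
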

From e8344cefa6293796774561ba175f514184756165 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 7 May 2024 11:51:05 -0400
Subject: [PATCH 2/3] update lit tests

---
 mlir/test/Dialect/Tensor/canonicalize.mlir | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6177fe3c752c9..53c8a65d39e63 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1890,14 +1890,13 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
 
 // -----
 
-// There was an issue in cast + insert_slice folding generating invalid ir.
-// https://github.com/llvm/llvm-project/issues/53099
 // CHECK-LABEL: func @insert_slice_cast
 func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
-  // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
   %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
-  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
-  // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
+  // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
+  // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
   %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
   // CHECK: return %[[RES]] : tensor<?x?xf32>
   return %1 : tensor<?x?xf32>

From 8b3406f720de2af9edf0334eec3200b3521546a2 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 8 May 2024 10:03:11 -0400
Subject: [PATCH 3/3] add additional test

---
 mlir/test/Dialect/Tensor/canonicalize.mlir | 42 ++++++++++++++--------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 53c8a65d39e63..8036d996d2324 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -755,6 +755,34 @@ func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
 
 // -----
 
+// CHECK-LABEL: func @insert_slice_cast
+func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+  // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
+  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
+  // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
+  // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RES]] : tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_cast_no_fold
+func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
+  // CHECK-SAME: [{{.*}}, {{.*}}] [{{.*}}, 5] [{{.*}}, {{.*}}]
+  // CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RES]] : tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
 // CHECK-SAME:      %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
 //      CHECK:    %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
@@ -1890,20 +1918,6 @@ func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
 
 // -----
 
-// CHECK-LABEL: func @insert_slice_cast
-func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
-  // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
-  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
-  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
-  // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
-  // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
-  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
-  // CHECK: return %[[RES]] : tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @cast_extract_slice
 func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
     -> tensor<16x512xf32> {


