[Mlir-commits] [mlir] [mlir][tensor] Apply `InsertSliceOfTransferWriteOpFolder` only when `transfer_write` overwrites all elements of `insert_slice` (PR #108803)

Rajveer Singh Bharadwaj llvmlistbot at llvm.org
Mon Sep 30 04:17:19 PDT 2024


https://github.com/Rajveer100 updated https://github.com/llvm/llvm-project/pull/108803

From b216f581499b1ec1ea1f7c52de0ee4a991611b24 Mon Sep 17 00:00:00 2001
From: Rajveer <rajveer.developer at icloud.com>
Date: Mon, 16 Sep 2024 13:51:57 +0530
Subject: [PATCH] [mlir][tensor] Apply `InsertSliceOfTransferWriteOpFolder`
 only when `transfer_write` overwrites all elements of `insert_slice`

Resolves #101708

The updated logic now correctly checks whether `transfer_write` completely
overwrites `insert_slice` and applies the rewrite for this pattern only in
that case.

This check currently covers static sizes; for dynamic sizes,
value bounds analysis is needed (see the `TODO:` in the code).
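For illustration, a minimal sketch of a case that still folds after this
change (the names `%v`, `%dest`, `%t`, `%i` and the shapes are hypothetical,
modeled on the negative test added below): the written vector spans every
element of the insert_slice source, so the pair can collapse into a single
`transfer_write` on the outer tensor.

  // Before: the vector covers all 100x100 elements of %dest.
  %w = vector.transfer_write %v, %dest[%c0, %c0] {in_bounds = [true, true]}
      : vector<100x100xf32>, tensor<100x100xf32>
  %r = tensor.insert_slice %w into %t[3, %i] [100, 100] [1, 1]
      : tensor<100x100xf32> into tensor<1000x1000xf32>

  // After folding: roughly a single write at the combined indices.
  %r = vector.transfer_write %v, %t[%c3, %i] {in_bounds = [true, true]}
      : vector<100x100xf32>, tensor<1000x1000xf32>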
---
 .../Tensor/Transforms/FoldTensorSubsetOps.cpp | 24 +++++++++++
 .../Tensor/fold-tensor-subset-ops.mlir        | 40 +++++++++++--------
 2 files changed, 48 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 5396531922aab3..770d3f43de2c2a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <type_traits>
@@ -67,6 +68,9 @@ class InsertSliceOfTransferWriteOpFolder final
 
   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                 PatternRewriter &rewriter) const override;
+
+  static bool
+  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
 };
 } // namespace
 
@@ -84,6 +88,15 @@ static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
                 "strides, this may result in needing to insert "
                 "vector.insert_strided_slice/extract_strided_slice ops");
   }
+  if constexpr (std::is_same_v<XferOp, vector::TransferWriteOp>) {
+    if constexpr (std::is_same_v<ExtractOrInsertOp, tensor::InsertSliceOp>) {
+      if (!InsertSliceOfTransferWriteOpFolder::
+              doesTransferWriteCoverInsertSlice(xferOp))
+        return rewriter.notifyMatchFailure(
+            xferOp, "transfer_write does not cover insert_slice");
+    }
+  }
+
   return success();
 }
 
@@ -154,6 +167,17 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   return success();
 }
 
+bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
+    vector::TransferWriteOp writeOp) {
+  if (writeOp.getShapedType().hasStaticShape())
+    return llvm::equal(writeOp.getVectorType().getShape(),
+                       writeOp.getShapedType().getShape());
+
+  // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
+
+  return false;
+}
+
 template <typename OpTy>
 struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index 1a84e141049325..988b5d835c16ed 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -144,8 +144,6 @@ func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor<?x?x?
 
 // -----
 
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
-
 //       CHECK: func @fold_vector_transfer_write_with_rank_reduced_insert_slice
 //  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
 //  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
@@ -155,6 +153,7 @@ func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor<?x?x?
 //  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
 //  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
 //  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG8:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
     %arg0 : tensor<?x?x?xf32>,
     %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
@@ -162,11 +161,8 @@ func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
     %st : tensor<?x?xf32>) -> tensor<?x?x?xf32> {
   %cst = arith.constant 0.0 : f32
 
-//   CHECK-NOT:    insert_slice
-//   CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:    %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
-//   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
-//   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] {in_bounds = [true]} : vector<4xf32>, tensor<?x?x?xf32
+  //   CHECK-DAG:  %[[r1:.*]] = vector.transfer_write %[[ARG1]], %[[ARG8]][%[[ARG6]], %[[ARG7]]] {in_bounds = [true]} : vector<4xf32>, tensor<?x?xf32>
+  //   CHECK-DAG:  %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[ARG0]][0, %[[ARG2]], %[[ARG3]]] [1, %[[ARG4]], %[[ARG5]]] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
   %0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]}
       : vector<4xf32>, tensor<?x?xf32>
   %1 = tensor.insert_slice %0 into %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1]
@@ -176,9 +172,6 @@ func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
 
 // -----
 
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
-//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
-
 //       CHECK: func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice
 //  CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
 //  CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
@@ -188,6 +181,7 @@ func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
 //  CHECK-SAME:    %[[ARG5:[a-zA-Z0-9]+]]: index
 //  CHECK-SAME:    %[[ARG6:[a-zA-Z0-9]+]]: index
 //  CHECK-SAME:    %[[ARG7:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:    %[[ARG8:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 func.func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice(
     %arg0 : tensor<?x?x?xf32>,
     %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
@@ -195,12 +189,8 @@ func.func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice(
     %st : tensor<?x?xf32>) -> tensor<?x?x?xf32> {
   %cst = arith.constant 0.0 : f32
 
-  //   CHECK-NOT: insert_slice
-  //   CHECK-DAG:  %[[C0:.+]] = arith.constant 0 : index
-  //   CHECK-DAG:  %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
-  //   CHECK-DAG:  %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
-  //   CHECK-DAG:  vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]]
-  //  CHECK-SAME:    {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, tensor<?x?x?xf32
+  //   CHECK-DAG:  %[[r1:.*]] = vector.transfer_write %[[ARG1]], %[[ARG8]][%[[ARG6]], %[[ARG7]]] {in_bounds = [true]} : vector<4xf32>, tensor<?x?xf32>
+  //   CHECK-DAG:  %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[ARG0]][%[[ARG2]], %[[ARG3]], 0] [%[[ARG4]], %[[ARG5]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
   %0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]}
       : vector<4xf32>, tensor<?x?xf32>
   %1 = tensor.insert_slice %0 into %arg0[%arg2, %arg3, 0] [%arg4, %arg5, 1] [1, 1, 1]
@@ -226,6 +216,24 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
 
 // -----
 
+// Negative test: `transfer_write` writes to only a `5x6` portion of the
+// `100x100` elements of `%arg3`, so the fold must not apply.
+// CHECK-LABEL: func @insert_slice_of_transfer_write_overwrite_all(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<1000x1000xf32>, %[[arg1:.*]]: vector<5x6xf32>, %[[arg2:.*]]: index, %[[arg3:.*]]: tensor<100x100xf32>
+func.func @insert_slice_of_transfer_write_overwrite_all(%arg0: tensor<1000x1000xf32>, %arg1: vector<5x6xf32>, %arg2: index, %arg3: tensor<100x100xf32>) -> tensor<1000x1000xf32> {
+  %c0 = arith.constant 0 : index
+
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[r1:.*]] = vector.transfer_write %[[arg1]], %[[arg3]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
+//       CHECK:   %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[arg0]][3, %[[arg2]]] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
+//       CHECK:   return %[[r2]] : tensor<1000x1000xf32>
+  %0 = vector.transfer_write %arg1, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
+  %inserted_slice = tensor.insert_slice %0 into %arg0[3, %arg2] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
+  return %inserted_slice : tensor<1000x1000xf32>
+}
+
+// -----
+
 //   CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 
 // CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(
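As a sketch of the dynamic case that the new check conservatively rejects
(hypothetical IR, mirroring the updated tests above; `%v`, `%st`, `%t` and
the index operands are made-up names): the destination of the write has
dynamic extents, so full coverage cannot be proven statically, and the
pattern now bails out. The `TODO:` names ValueBoundsConstraintSet as the
tool that could later prove coverage here.

  %0 = vector.transfer_write %v, %st[%i, %j] {in_bounds = [true]}
      : vector<4xf32>, tensor<?x?xf32>
  // Kept as-is after this patch: tensor<?x?xf32> has dynamic extents, so
  // the vector<4xf32> write cannot be shown to overwrite every element.
  %1 = tensor.insert_slice %0 into %t[0, %a, %b] [1, %s0, %s1] [1, 1, 1]
      : tensor<?x?xf32> into tensor<?x?x?xf32>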


