[Mlir-commits] [mlir] [mlir][tensor] Apply `InsertSliceOfTransferWriteOpFolder` only when `transfer_write` overwrites all elements of `insert_slice` (PR #108803)

Rajveer Singh Bharadwaj llvmlistbot at llvm.org
Wed Sep 18 04:57:04 PDT 2024


https://github.com/Rajveer100 updated https://github.com/llvm/llvm-project/pull/108803

From 6b4aec50410e86f72f0e1a7dc1e3113b3bb4a3e7 Mon Sep 17 00:00:00 2001
From: Rajveer <rajveer.developer at icloud.com>
Date: Mon, 16 Sep 2024 13:51:57 +0530
Subject: [PATCH] [mlir][tensor] Apply `InsertSliceOfTransferWriteOpFolder`
 only when `transfer_write` overwrites all elements of `insert_slice`

Resolves #101708
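
For illustration (shapes mirror the new test below; value names such as
%v, %small, %big and %i are placeholders, not taken verbatim from the
issue), consider a write that covers only part of the inserted slice:

  %w = vector.transfer_write %v, %small[%c0, %c0] {in_bounds = [true, true]}
      : vector<5x6xf32>, tensor<100x100xf32>
  %r = tensor.insert_slice %w into %big[3, %i] [100, 100] [1, 1]
      : tensor<100x100xf32> into tensor<1000x1000xf32>

Folding this pair into a single transfer_write on %big would drop the
elements of %small outside the 5x6 region that the insert_slice still
copies, so the pattern must bail out unless the write covers the whole
slice.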
---
 .../Tensor/Transforms/FoldTensorSubsetOps.cpp | 31 +++++++++++++++++++
 .../Tensor/fold-tensor-subset-ops.mlir        | 16 ++++++++++
 2 files changed, 47 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 5396531922aab3..af3e5e45a81b20 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <type_traits>
@@ -67,6 +68,12 @@ class InsertSliceOfTransferWriteOpFolder final
 
   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  static bool
+  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp,
+                                    tensor::InsertSliceOp insertSliceOp,
+                                    MLIRContext *context);
 };
 } // namespace
 
@@ -136,6 +143,11 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   if (failed(preconditionResult))
     return preconditionResult;
 
+  if (!doesTransferWriteCoverInsertSlice(writeOp, insertSliceOp,
+                                         rewriter.getContext()))
+    return rewriter.notifyMatchFailure(
+        insertSliceOp, "transfer_write does not cover insert_slice");
+
   SmallVector<Value> indices(writeOp.getIndices().begin(),
                              writeOp.getIndices().end());
   SmallVector<Value> sourceIndices;
@@ -154,6 +166,25 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   return success();
 }
 
+bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
+    vector::TransferWriteOp writeOp, tensor::InsertSliceOp insertSliceOp,
+    MLIRContext *context) {
+  auto destType = cast<ShapedType>(writeOp.getOperand(0).getType());
+  auto insertSliceType = insertSliceOp.getSourceType();
+
+  if (destType.hasStaticShape() && insertSliceType.hasStaticShape()) {
+    for (int64_t d = 0, e = insertSliceType.getRank(); d < e; ++d) {
+      if (destType.getDimSize(d) != insertSliceType.getDimSize(d))
+        return false;
+    }
+    return true;
+  }
+
+  // TODO: Use ValueBoundsConstraintSet to handle dynamic shapes.
+
+  return true;
+}
+
 template <typename OpTy>
 struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index 1a84e141049325..0c9a8b839284ee 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -226,6 +226,22 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
 
 // -----
 
+// CHECK-LABEL: func @insert_slice_of_transfer_write_overwrite_all(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<1000x1000xf32>, %[[arg1:.*]]: vector<5x6xf32>, %[[arg2:.*]]: index, %[[arg3:.*]]: tensor<100x100xf32>
+func.func @insert_slice_of_transfer_write_overwrite_all(%arg0: tensor<1000x1000xf32>, %arg1: vector<5x6xf32>, %arg2: index, %arg3: tensor<100x100xf32>) -> tensor<1000x1000xf32> {
+  %c0 = arith.constant 0 : index
+
+//       CHECK:   %[[c0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[r1:.*]] = vector.transfer_write %[[arg1]], %[[arg3]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
+//       CHECK:   %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[arg0]][3, %[[arg2]]] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
+//       CHECK:   return %[[r2]] : tensor<1000x1000xf32>
+  %0 = vector.transfer_write %arg1, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
+  %inserted_slice = tensor.insert_slice %0 into %arg0[3, %arg2] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
+  return %inserted_slice : tensor<1000x1000xf32>
+}
+
+// -----
+
 //   CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 
 // CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(


