[Mlir-commits] [mlir] [mlir][tensor] Apply `InsertSliceOfTransferWriteOpFolder` only when `transfer_write` overwrites all elements of `insert_slice` (PR #108803)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 16 01:31:17 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Rajveer Singh Bharadwaj (Rajveer100)
<details>
<summary>Changes</summary>
Resolves #101708
---
Full diff: https://github.com/llvm/llvm-project/pull/108803.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (+27)
- (modified) mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir (+17)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 5396531922aab3..f7a490844e95af 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -67,6 +67,13 @@ class InsertSliceOfTransferWriteOpFolder final
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
PatternRewriter &rewriter) const override;
+
+private:
+ /// Returns true only if `writeOp` provably overwrites every element of the
+ /// tensor it writes into. Otherwise, folding it into the consuming
+ /// tensor.insert_slice would drop the elements the write leaves untouched.
+ static bool
+ doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
};
} // namespace
@@ -136,6 +143,10 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
if (failed(preconditionResult))
return preconditionResult;
+ if (!doesTransferWriteCoverInsertSlice(writeOp))
+ return rewriter.notifyMatchFailure(
+ insertSliceOp, "transfer_write does not cover insert_slice");
+
SmallVector<Value> indices(writeOp.getIndices().begin(),
writeOp.getIndices().end());
SmallVector<Value> sourceIndices;
@@ -154,6 +165,22 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
return success();
}
+bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
+ vector::TransferWriteOp writeOp) {
+ // A masked or possibly out-of-bounds write may leave elements untouched.
+ if (writeOp.getMask() || writeOp.hasOutOfBoundsDim())
+ return false;
+
+ // An unmasked, in-bounds write whose vector shape equals the (static)
+ // tensor shape writes every element of that tensor.
+ if (writeOp.getShapedType().hasStaticShape())
+ return writeOp.getVectorType().getShape() ==
+ writeOp.getShapedType().getShape();
+
+ // TODO: Use ValueBoundsConstraintSet to prove coverage for dynamic shapes.
+ return false;
+}
+
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index 1a84e141049325..7ba24511e96ba5 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -226,6 +226,23 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
// -----
+// CHECK-LABEL: func @insert_slice_of_transfer_write_not_overwrite_all(
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9_]+]]: tensor<1000x1000xf32>
+// CHECK-SAME: %[[V:[a-zA-Z0-9_]+]]: vector<5x6xf32>
+// CHECK-SAME: %[[S:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[T:[a-zA-Z0-9_]+]]: tensor<100x100xf32>
+// CHECK: %[[W:.*]] = vector.transfer_write %[[V]], %[[T]]
+// CHECK: %[[R:.*]] = tensor.insert_slice %[[W]] into %[[DEST]][3, %[[S]]] [100, 100] [1, 1]
+// CHECK: return %[[R]]
+func.func @insert_slice_of_transfer_write_not_overwrite_all(%arg0: tensor<1000x1000xf32>, %arg1: vector<5x6xf32>, %arg2: index, %arg3: tensor<100x100xf32>) -> tensor<1000x1000xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %arg1, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
+ %inserted_slice = tensor.insert_slice %0 into %arg0[3, %arg2] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
+ return %inserted_slice : tensor<1000x1000xf32>
+}
+
+// -----
+
// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(
``````````
</details>
https://github.com/llvm/llvm-project/pull/108803
More information about the Mlir-commits mailing list