[Mlir-commits] [mlir] [mlir] [linalg] Fix bufferize error in tensor.parallel_insert_slice op (PR #98312)
donald chen
llvmlistbot at llvm.org
Wed Jul 10 06:07:18 PDT 2024
https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/98312
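tensor.parallel_insert_slice has implicit in-place semantics: it writes directly into its destination, which is a shared_outs block argument of the enclosing scf.forall. In "copy-before-write" bufferization mode, the default resolveConflicts implementation inserts a copy of that destination. The slice is then written into the copy instead of the shared output, and the result of the scf.forall is incorrect. This patch fixes the issue by overriding resolveConflicts in ParallelInsertSliceOpInterface so that no copy is inserted.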
>From e8658600932a4fc55461b41177cdf378b53f0024 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Wed, 10 Jul 2024 12:55:56 +0000
Subject: [PATCH] [mlir] [linalg] Fix bufferize error in
tensor.parallel_insert_slice op
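tensor.parallel_insert_slice has implicit in-place semantics. In the
"copy-before-write" bufferization mode, the default resolveConflicts
implementation inserts a copy of the destination, so the slice is written
into the copy and the result is incorrect. This patch overrides
resolveConflicts for tensor.parallel_insert_slice so that no copy is
inserted.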
---
.../BufferizableOpInterfaceImpl.cpp | 5 +++++
mlir/test/Dialect/Tensor/bufferize.mlir | 20 +++++++++++++++++++
2 files changed, 25 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..eabcff33df98a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -997,6 +997,11 @@ struct ParallelInsertSliceOpInterface
     rewriter.eraseOp(op);
     return success();
   }
+
+  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
+                                 const AnalysisState &state) const {
+    return success();
+  }
 };
/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index e85d9e740adf4..3a3c8af15e6e4 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -626,3 +626,23 @@ func.func @tensor.splat_dynamic(%f: f32, %m: index, %n: index) -> tensor<?x3x?xf
return %0 : tensor<?x3x?xf32>
}
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_copy_before_write
+func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: tensor<4xf32>) {
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 4 : index
+
+  // CHECK: scf.forall {{.*}} {
+  %result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<4xf32> {
+    %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<4xf32> to tensor<1xf32>
+    scf.forall.in_parallel {
+      // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
+      // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<4xf32> to memref<1xf32, strided<[1], offset: ?>>
+      tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
+        tensor<1xf32> into tensor<4xf32>
+    }
+  }
+  // CHECK: }
+  return
+}