[Mlir-commits] [mlir] 18e08fb - [mlir][linalg][bufferize] Fix tiled_loop bufferization
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 6 00:56:24 PST 2022
Author: Matthias Springer
Date: 2022-01-06T17:51:33+09:00
New Revision: 18e08fbd01bfc1efeccbdb0278660487c20eccba
URL: https://github.com/llvm/llvm-project/commit/18e08fbd01bfc1efeccbdb0278660487c20eccba
DIFF: https://github.com/llvm/llvm-project/commit/18e08fbd01bfc1efeccbdb0278660487c20eccba.diff
LOG: [mlir][linalg][bufferize] Fix tiled_loop bufferization
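Until now, bufferization assumed that the yielded tensor of a linalg.tiled_loop is an output tensor. This is not necessarily the case. Bufferization now copies the buffer of each yielded tensor into the corresponding output buffer; if everything bufferizes in place, this copy folds away.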
Differential Revision: https://reviews.llvm.org/D116685
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index dd9f12311754..536664a6dfb7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -323,9 +323,23 @@ struct TiledLoopOpInterface
newBlockArgs);
// Replace previous terminator with a new one that does not yield anything.
- Operation *oldTerminator = newTiledLoopOp.getBody()->getTerminator();
+ auto oldTerminator =
+ cast<linalg::YieldOp>(newTiledLoopOp.getBody()->getTerminator());
rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody());
- rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());
+ auto newTerminator =
+ rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());
+
+ // Copy buffer of yielded tensor to output buffer. If everything bufferized
+ // inplace, this copy will fold away.
+ rewriter.setInsertionPoint(newTerminator);
+ for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) {
+ Value output = std::get<1>(it);
+ Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+ newTerminator.getLoc(), output.getType(), std::get<0>(it));
+ state.createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp, output);
+ }
+
+ // Erase old terminator.
rewriter.eraseOp(oldTerminator);
// Replace results and delete old op.
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 8f08e37c6774..30fad7a2b928 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -576,6 +576,7 @@ func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.in
%0 = tensor.dim %A, %c0 : tensor<?xf32>
// CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]]
+ // CHECK-NOT: copy
%1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
ins (%arg4 = %A: tensor<?xf32>, %use = %effecting : memref<?xf32>, %arg5 = %B: tensor<?xf32>)
outs (%arg6 = %c: tensor<f32>)
@@ -655,6 +656,40 @@ func @tiled_fill(%A: tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32>
// -----
+// CHECK: func @tiled_loop_yield_out_of_place(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #{{.*}}>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #{{.*}}>
+func @tiled_loop_yield_out_of_place(
+ %A: tensor<?xf32> {linalg.inplaceable = true},
+ %B: tensor<?xf32> {linalg.inplaceable = true})
+ -> tensor<?xf32>
+{
+ %c3 = arith.constant 3 : index
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$DYN_MAP:.*]]>
+ %0 = tensor.dim %A, %c0 : tensor<?xf32>
+
+ // CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} outs{{.*}}%[[A]]
+ %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
+ outs (%arg1 = %A: tensor<?xf32>)
+ iterators["parallel"]
+ {
+ // CHECK-NOT: alloc
+ // CHECK: linalg.copy(%[[B]], %[[A]])
+ linalg.yield %B : tensor<?xf32>
+ // CHECK: linalg.yield
+ // CHECK-NOT: tensor
+ }
+
+ // CHECK: return
+ // CHECK-NOT: tensor
+ return %1 : tensor<?xf32>
+}
+
+// -----
+
// CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// CHECK: func private @external_func(memref<?xf32, #[[$DYNAMIC]]>)
More information about the Mlir-commits
mailing list