[Mlir-commits] [mlir] 18e08fb - [mlir][linalg][bufferize] Fix tiled_loop bufferization

Matthias Springer llvmlistbot at llvm.org
Thu Jan 6 00:56:24 PST 2022


Author: Matthias Springer
Date: 2022-01-06T17:51:33+09:00
New Revision: 18e08fbd01bfc1efeccbdb0278660487c20eccba

URL: https://github.com/llvm/llvm-project/commit/18e08fbd01bfc1efeccbdb0278660487c20eccba
DIFF: https://github.com/llvm/llvm-project/commit/18e08fbd01bfc1efeccbdb0278660487c20eccba.diff

LOG: [mlir][linalg][bufferize] Fix tiled_loop bufferization

Until now, bufferization assumed that the yielded tensor of a linalg.tiled_loop is an output tensor. This is not necessarily the case.

Differential Revision: https://reviews.llvm.org/D116685

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index dd9f12311754..536664a6dfb7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -323,9 +323,23 @@ struct TiledLoopOpInterface
                          newBlockArgs);
 
     // Replace previous terminator with a new one that does not yield anything.
-    Operation *oldTerminator = newTiledLoopOp.getBody()->getTerminator();
+    auto oldTerminator =
+        cast<linalg::YieldOp>(newTiledLoopOp.getBody()->getTerminator());
     rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody());
-    rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());
+    auto newTerminator =
+        rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());
+
+    // Copy buffer of yielded tensor to output buffer. If everything bufferized
+    // inplace, this copy will fold away.
+    rewriter.setInsertionPoint(newTerminator);
+    for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) {
+      Value output = std::get<1>(it);
+      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+          newTerminator.getLoc(), output.getType(), std::get<0>(it));
+      state.createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp, output);
+    }
+
+    // Erase old terminator.
     rewriter.eraseOp(oldTerminator);
 
     // Replace results and delete old op.

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 8f08e37c6774..30fad7a2b928 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -576,6 +576,7 @@ func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.in
   %0 = tensor.dim %A, %c0 : tensor<?xf32>
 
   //     CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]]
+  // CHECK-NOT: copy
   %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
        ins (%arg4 = %A: tensor<?xf32>, %use = %effecting : memref<?xf32>, %arg5 = %B: tensor<?xf32>)
       outs (%arg6 = %c: tensor<f32>)
@@ -655,6 +656,40 @@ func @tiled_fill(%A: tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32>
 
 // -----
 
+//      CHECK:  func @tiled_loop_yield_out_of_place(
+// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #{{.*}}>,
+// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #{{.*}}>
+func @tiled_loop_yield_out_of_place(
+    %A: tensor<?xf32> {linalg.inplaceable = true},
+    %B: tensor<?xf32> {linalg.inplaceable = true})
+  -> tensor<?xf32>
+{
+  %c3 = arith.constant 3 : index
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  //     CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref<?xf32, #[[$DYN_MAP:.*]]>
+  %0 = tensor.dim %A, %c0 : tensor<?xf32>
+
+  //     CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} outs{{.*}}%[[A]]
+  %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3)
+      outs (%arg1 = %A: tensor<?xf32>)
+      iterators["parallel"]
+  {
+    // CHECK-NOT:   alloc
+    //     CHECK:   linalg.copy(%[[B]], %[[A]])
+    linalg.yield %B : tensor<?xf32>
+    //     CHECK:   linalg.yield
+    // CHECK-NOT:   tensor
+  }
+
+  //     CHECK: return
+  // CHECK-NOT: tensor
+  return %1 : tensor<?xf32>
+}
+
+// -----
+
 // CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 
 // CHECK: func private @external_func(memref<?xf32, #[[$DYNAMIC]]>)


        


More information about the Mlir-commits mailing list