[Mlir-commits] [mlir] 6d7c9c3 - [mlir][Linalg] Bufferize the region of LinalgOps as well.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 8 22:36:15 PST 2021


Author: MaheshRavishankar
Date: 2021-12-08T22:36:01-08:00
New Revision: 6d7c9c3d0e78815fc54246dd2d233d3900d22e1c

URL: https://github.com/llvm/llvm-project/commit/6d7c9c3d0e78815fc54246dd2d233d3900d22e1c
DIFF: https://github.com/llvm/llvm-project/commit/6d7c9c3d0e78815fc54246dd2d233d3900d22e1c.diff

LOG: [mlir][Linalg] Bufferize the region of LinalgOps as well.

The region of `linalg.generic` might contain `tensor` operations. For
example, the current lowering of `gather` uses a `tensor.extract` in the
body of the `LinalgOp`. Bufferize the ops within a `LinalgOp` region
as well to catch such cases.

Differential Revision: https://reviews.llvm.org/D115322

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 3ac95dbe18c4f..81d865042d7f5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -86,15 +86,15 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
 
   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(op);
-  op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
+  auto bufferizedOp = cast<LinalgOp>(
+      op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands));
 
   // Replace the results of the old op with the new output buffers.
   if (op->getNumResults())
     state.mapBuffer(op->getResults(), newOutputBuffers);
 
   // The original op will be DCE'd away later.
-
-  return success();
+  return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state);
 }
 
 template <typename OpTy>

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 7f908b9b8d924..bd5f73f29df7e 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1008,3 +1008,31 @@ func private @private_func(tensor<?xf32>) -> ()
 func @empty_func() -> () {
   return
 }
+
+// -----
+
+func @gather_like(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xi32>,
+    %arg2 : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg1 : tensor<?xi32>) outs(%arg2 : tensor<?x?xf32>) {
+      ^bb0(%arg3: i32, %arg4 : f32):
+        %iv1 = linalg.index 1 : index
+	%1 = arith.index_cast %arg3: i32 to index
+	%2 = tensor.extract %arg0[%1, %iv1] : tensor<?x?xf32>
+	linalg.yield %2 : f32
+      } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @gather_like(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32,
+//  CHECK-SAME:     %[[ARG1:.+]]: memref<?xi32
+//  CHECK-SAME:     %[[ARG2:.+]]: memref<?x?xf32
+//  CHECK-SAME:   ) {
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       ins(%[[ARG1]] :
+//  CHECK-SAME:       outs(%[[ARG2]] :
+//       CHECK:     %[[YIELD:.+]] = memref.load %[[ARG0]]
+//       CHECK:     linalg.yield %[[YIELD]]


        


More information about the Mlir-commits mailing list