[Mlir-commits] [mlir] 967578f - Revert "[mlir] Change the pattern for TiledLoopOp bufferization."

Alexander Belyaev llvmlistbot at llvm.org
Wed Aug 11 01:03:39 PDT 2021


Author: Alexander Belyaev
Date: 2021-08-11T10:01:36+02:00
New Revision: 967578f0b8b1bece55de5cdacbe960c5ad87ab4e

URL: https://github.com/llvm/llvm-project/commit/967578f0b8b1bece55de5cdacbe960c5ad87ab4e
DIFF: https://github.com/llvm/llvm-project/commit/967578f0b8b1bece55de5cdacbe960c5ad87ab4e.diff

LOG: Revert "[mlir] Change the pattern for TiledLoopOp bufferization."

This reverts commit 2f946eaa9d2648b883b2a1e567b23fff307f13d9.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/test/Dialect/Linalg/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 04865594c1ad4..b46e58be8349a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -213,8 +213,10 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
     Location loc = op.getLoc();
     SmallVector<Value, 2> newOutputBuffers;
 
-    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
-                                         newOutputBuffers, rewriter))) {
+    if (op->getParentOfType<TiledLoopOp>()) {
+      newOutputBuffers = adaptor.outputs();
+    } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
+                                                newOutputBuffers, rewriter))) {
       return op.emitOpError()
              << "Failed to allocate buffers for tensor results.";
     }
@@ -231,6 +233,14 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
   }
 };
 
+bool IsBlockArgOfTiledLoop(Value tensor) {
+  if (auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>())
+    if (auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>())
+      if (isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp()))
+        return true;
+  return false;
+}
+
 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
 /// alloc + copy pattern.
 /// ```
@@ -253,6 +263,15 @@ class ExtractSliceOpConverter
     Value sourceMemref = adaptor.source();
     assert(sourceMemref.getType().isa<MemRefType>());
 
+    // Block arguments of the tiled_loop can be bufferized inplace.
+    if (IsBlockArgOfTiledLoop(op.source())) {
+      Value subView = rewriter.create<memref::SubViewOp>(
+          op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
+          op.getMixedStrides());
+      rewriter.replaceOp(op, subView);
+      return success();
+    }
+
     MemRefType subviewMemRefType =
         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
     // op.sizes() capture exactly the dynamic alloc operands matching the
@@ -296,7 +315,12 @@ class InsertSliceOpConverter
     // For now, be conservative and copy the converted input memref.
     // In general, the converted input memref here could be aliased or could
     // point into constant memory, so mutating it would lead to miscompilations.
-    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+    // Block arguments of the tiled_loop can be bufferized inplace.
+    Value destMemRef;
+    if (IsBlockArgOfTiledLoop(op.dest()))
+      destMemRef = adaptor.dest();
+    else
+      destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
     assert(destMemRef.getType().isa<MemRefType>());
 
     // Take a subview to copy the small memref.
@@ -310,115 +334,60 @@ class InsertSliceOpConverter
   }
 };
 
-bool isBlockArgOfTiledLoop(Value tensor) {
-  if (auto blockArgument = tensor.dyn_cast<BlockArgument>())
-    return isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp());
-  return false;
-}
-
-SmallVector<Value, 3> convertOperands(ValueRange operands,
-                                      BlockAndValueMapping &bvm) {
-  SmallVector<Value, 3> newOperands;
-  newOperands.reserve(operands.size());
-  for (auto operand : operands)
-    newOperands.push_back(bvm.lookupOrDefault(operand));
-  return newOperands;
-}
-
 class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
 public:
   using OpConversionPattern<TiledLoopOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(TiledLoopOp loop, ArrayRef<Value> operands,
+  matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    TiledLoopOp::Adaptor adaptor(operands, loop->getAttrDictionary());
-    if (loop.getNumResults() == 0)
+    TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
+    Location loc = tiledLoop.getLoc();
+    if (tiledLoop.getNumResults() == 0)
       return failure();
-
-    Location loc = loop.getLoc();
-    auto newLoop = rewriter.create<TiledLoopOp>(
+    auto newTiledLoop = rewriter.create<TiledLoopOp>(
         loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
         adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
         adaptor.distribution_types());
-
     // Clone the region.
     BlockAndValueMapping bvm;
-    bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
-    bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs());
-    bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs());
+    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
 
     OpBuilder innerBuilder =
-        OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
-
-    for (auto &op : loop.getBody()->getOperations()) {
-      Location loc = op.getLoc();
-      if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
-        if (isBlockArgOfTiledLoop(extractSlice.source())) {
-          auto newOperands = convertOperands(extractSlice.getOperands(), bvm);
-          auto srcMemRefType =
-              bvm.lookup(extractSlice.source()).getType().cast<MemRefType>();
-          auto dstMemRefType =
-              memref::SubViewOp::inferResultType(
-                  srcMemRefType,
-                  extractFromI64ArrayAttr(extractSlice.static_offsets()),
-                  extractFromI64ArrayAttr(extractSlice.static_sizes()),
-                  extractFromI64ArrayAttr(extractSlice.static_strides()))
-                  .cast<MemRefType>();
-
-          Value subView = innerBuilder.create<memref::SubViewOp>(
-              loc, TypeRange{dstMemRefType}, newOperands,
-              extractSlice->getAttrs());
-          bvm.map(extractSlice.getResult(), subView);
-          continue;
-        }
-      }
-      if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
-        if (isBlockArgOfTiledLoop(insertSlice.dest())) {
-          continue;
-        }
+        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+
+    // Remap input block arguments.
+    SmallVector<Value, 2> inputs;
+    for (auto en : llvm::zip(newTiledLoop.getRegionInputArgs(),
+                             tiledLoop.getRegionInputArgs())) {
+      auto &newInputArg = std::get<0>(en);
+      if (!newInputArg.getType().isa<ShapedType>()) {
+        inputs.push_back(std::get<0>(en));
+        continue;
       }
-      if (auto yield = dyn_cast<linalg::YieldOp>(op)) {
-        for (OpOperand &operand : yield->getOpOperands()) {
-          if (auto insert =
-                  operand.get().getDefiningOp<tensor::InsertSliceOp>()) {
-
-            auto dstMemRefType = memref::SubViewOp::inferResultType(
-                getTypeConverter()
-                    ->convertType(insert.source().getType())
-                    .cast<MemRefType>(),
-                extractFromI64ArrayAttr(insert.static_offsets()),
-                extractFromI64ArrayAttr(insert.static_sizes()),
-                extractFromI64ArrayAttr(insert.static_strides()));
-
-            Value subView = innerBuilder.create<memref::SubViewOp>(
-                loc, dstMemRefType, bvm.lookup(insert.dest()),
-                convertOperands(insert.offsets(), bvm),
-                convertOperands(insert.sizes(), bvm),
-                convertOperands(insert.strides(), bvm), insert.static_offsets(),
-                insert.static_sizes(), insert.static_strides());
-
-            Value cast = innerBuilder.create<memref::BufferCastOp>(
-                loc,
-                getTypeConverter()
-                    ->convertType(insert.source().getType())
-                    .cast<MemRefType>(),
-                bvm.lookup(insert.source()));
-
-            innerBuilder.create<linalg::CopyOp>(loc, cast, subView);
-            continue;
-          }
-          auto dst = newLoop.getRegionOutputArgs()[operand.getOperandNumber()];
-          Value cast = innerBuilder.create<memref::BufferCastOp>(
-              loc, dst.getType(), bvm.lookup(operand.get()));
-          innerBuilder.create<linalg::CopyOp>(loc, cast, dst);
-        }
+      inputs.push_back(
+          innerBuilder.create<memref::TensorLoadOp>(loc, newInputArg));
+    }
+    bvm.map(tiledLoop.getRegionInputArgs(), inputs);
+
+    // Remap output block arguments.
+    SmallVector<Value, 2> outputs;
+    for (auto en : llvm::zip(newTiledLoop.getRegionOutputArgs(),
+                             tiledLoop.getRegionOutputArgs())) {
+      auto &newOutputArg = std::get<0>(en);
+      if (!newOutputArg.getType().isa<ShapedType>()) {
+        outputs.push_back(std::get<0>(en));
         continue;
       }
-      innerBuilder.clone(op, bvm);
+      outputs.push_back(
+          innerBuilder.create<memref::TensorLoadOp>(loc, newOutputArg));
     }
+    bvm.map(tiledLoop.getRegionOutputArgs(), outputs);
+
+    for (auto &op : tiledLoop.getBody()->without_terminator())
+      innerBuilder.clone(op, bvm);
     innerBuilder.create<linalg::YieldOp>(loc);
-    rewriter.replaceOp(loop, newLoop.outputs());
+    rewriter.replaceOp(tiledLoop, newTiledLoop.outputs());
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 29f23c10e095c..dbb2bb713e773 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -339,66 +339,13 @@ func @tiled_dot(%A: tensor<10xf32>, %B: tensor<10xf32>,
     linalg.yield %dot_sub : tensor<f32>
   }
   // CHECK: linalg.tiled_loop
-  // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
-  // CHECK-SAME:      %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
-  // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<f32>)
-
-  // CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]]
-  // CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]]
-  // CHECK-NEXT: %[[TMP:.*]] = memref.alloc
-  // CHECK-NEXT: linalg.copy(%[[C]], %[[TMP]])
-  // CHECK-NEXT: linalg.dot ins(%[[SV_A]], %[[SV_B]]
-  // CHECK-SAME:            outs(%[[TMP]] : memref<f32>)
-  // CHECK-NEXT: linalg.copy(%[[TMP]], %[[C]])
-  // CHECK-NEXT: linalg.yield
+  // CHECK-SAME: ins (%[[A:.*]] = %{{.*}}: memref<10xf32>, %[[B:.*]] = %{{.*}}: memref<10xf32>)
+  // CHECK-SAME: outs (%[[C:.*]] = %{{.*}}: memref<f32>)
+  //   CHECK-NOT:   alloc
+  //   CHECK:       %[[SV_A:.*]] = memref.subview %[[A]]
+  //   CHECK:       %[[SV_B:.*]] = memref.subview %[[B]]
+  //   CHECK:       linalg.dot ins(%[[SV_A]], %[[SV_B]]
+  //   CHECK-SAME:             outs(%[[C]] : memref<f32>)
+  //   CHECK:   linalg.yield
   return %dot : tensor<f32>
 }
-
-// -----
-
-#map0 = affine_map<(d0) -> (d0)>
-
-func @tiled_add(%A: tensor<10xf32>, %B: tensor<10xf32>,
-                  %C: tensor<10xf32>) -> tensor<10xf32> {
-  %c0 = constant 0 : index
-  %c2 = constant 2 : index
-  %c10 = constant 10 : index
-
-  %sum = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
-       ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
-       outs (%C_ = %C: tensor<10xf32>) {
-    %A_sub = tensor.extract_slice %A_[%i] [%c2] [1]
-      : tensor<10xf32> to tensor<?xf32>
-    %B_sub = tensor.extract_slice %B_[%i] [%c2] [1]
-      : tensor<10xf32> to tensor<?xf32>
-    %C_sub = tensor.extract_slice %C_[%i] [%c2] [1]
-      : tensor<10xf32> to tensor<?xf32>
-    %sum_sub = linalg.generic {
-      indexing_maps = [#map0, #map0, #map0],
-      iterator_types = ["parallel"]
-    } ins(%A_sub, %B_sub : tensor<?xf32>, tensor<?xf32>)
-      outs(%C_sub : tensor<?xf32>) {
-      ^bb0(%a: f32, %b: f32, %c: f32):
-        %0 = std.addf %a, %b : f32
-        linalg.yield %0 : f32
-    } -> tensor<?xf32>
-    %update = tensor.insert_slice %sum_sub into %C_[%i] [%c2] [1]
-      : tensor<?xf32> into tensor<10xf32>
-    linalg.yield %update : tensor<10xf32>
-  }
-  // CHECK: linalg.tiled_loop
-  // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
-  // CHECK-SAME:      %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
-  // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>)
-
-  // CHECK-NEXT:  %[[SV_A:.*]] = memref.subview %[[A]]
-  // CHECK-NEXT:  %[[SV_B:.*]] = memref.subview %[[B]]
-  // CHECK-NEXT:  %[[TMP:.*]] = memref.alloc
-  // CHECK-NEXT:  linalg.generic
-  // CHECK-SAME:    ins(%[[SV_A]], %[[SV_B]]
-  // CHECK-SAME:    outs(%[[TMP]] : memref<2xf32>)
-  // CHECK:  %[[SV_C:.*]] = memref.subview %[[C]]
-  // CHECK-NEXT:  linalg.copy(%[[TMP]], %[[SV_C]])
-  // CHECK-NEXT:  linalg.yield
-  return %sum : tensor<10xf32>
-}


        


More information about the Mlir-commits mailing list