[Mlir-commits] [mlir] 2f946ea - [mlir] Change the pattern for TiledLoopOp bufferization.

Alexander Belyaev llvmlistbot at llvm.org
Tue Aug 10 12:27:24 PDT 2021


Author: Alexander Belyaev
Date: 2021-08-10T21:27:02+02:00
New Revision: 2f946eaa9d2648b883b2a1e567b23fff307f13d9

URL: https://github.com/llvm/llvm-project/commit/2f946eaa9d2648b883b2a1e567b23fff307f13d9
DIFF: https://github.com/llvm/llvm-project/commit/2f946eaa9d2648b883b2a1e567b23fff307f13d9.diff

LOG: [mlir] Change the pattern for TiledLoopOp bufferization.

This version does not affect the patterns for Extract/InsertSliceOp and
LinalgOps.

Differential Revision: https://reviews.llvm.org/D107858
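
This moves all tiled_loop-specific logic into TiledLoopOpConverter: inside
the loop body, `tensor.extract_slice` of a bufferized loop argument is
rewritten to `memref.subview`, `tensor.insert_slice` into a loop output is
dropped, and every operand of the terminating `linalg.yield` is materialized
as a `linalg.copy` into the corresponding output buffer. Schematically, the
`tiled_dot` test below now bufferizes to something like this (a sketch only;
the strided subview result types are abbreviated as a `#map`-annotated
memref):

    linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
        ins (%A_ = %A: memref<10xf32>, %B_ = %B: memref<10xf32>)
        outs (%C_ = %C: memref<f32>) {
      %sv_A = memref.subview %A_[%i] [%c2] [1]
        : memref<10xf32> to memref<?xf32, #map>
      %sv_B = memref.subview %B_[%i] [%c2] [1]
        : memref<10xf32> to memref<?xf32, #map>
      // The alloc and copy-in come from bufferizing linalg.dot's init
      // operand; the copy-out comes from lowering the linalg.yield.
      %tmp = memref.alloc() : memref<f32>
      linalg.copy(%C_, %tmp) : memref<f32>, memref<f32>
      linalg.dot ins(%sv_A, %sv_B : memref<?xf32, #map>, memref<?xf32, #map>)
                 outs(%tmp : memref<f32>)
      linalg.copy(%tmp, %C_) : memref<f32>, memref<f32>
      linalg.yield
    }

The updated FileCheck expectations can be exercised with
`mlir-opt -linalg-bufferize` on mlir/test/Dialect/Linalg/bufferize.mlir
(assuming the file's existing RUN line, which this patch does not touch).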

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/test/Dialect/Linalg/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index b46e58be8349a..04865594c1ad4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -213,10 +213,8 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
     Location loc = op.getLoc();
     SmallVector<Value, 2> newOutputBuffers;
 
-    if (op->getParentOfType<TiledLoopOp>()) {
-      newOutputBuffers = adaptor.outputs();
-    } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
-                                                newOutputBuffers, rewriter))) {
+    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
+                                         newOutputBuffers, rewriter))) {
       return op.emitOpError()
              << "Failed to allocate buffers for tensor results.";
     }
@@ -233,14 +231,6 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
   }
 };
 
-bool IsBlockArgOfTiledLoop(Value tensor) {
-  if (auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>())
-    if (auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>())
-      if (isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp()))
-        return true;
-  return false;
-}
-
 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
 /// alloc + copy pattern.
 /// ```
@@ -263,15 +253,6 @@ class ExtractSliceOpConverter
     Value sourceMemref = adaptor.source();
     assert(sourceMemref.getType().isa<MemRefType>());
 
-    // Block arguments of the tiled_loop can be bufferized inplace.
-    if (IsBlockArgOfTiledLoop(op.source())) {
-      Value subView = rewriter.create<memref::SubViewOp>(
-          op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
-          op.getMixedStrides());
-      rewriter.replaceOp(op, subView);
-      return success();
-    }
-
     MemRefType subviewMemRefType =
         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
     // op.sizes() capture exactly the dynamic alloc operands matching the
@@ -315,12 +296,7 @@ class InsertSliceOpConverter
     // For now, be conservative and copy the converted input memref.
     // In general, the converted input memref here could be aliased or could
     // point into constant memory, so mutating it would lead to miscompilations.
-    // Block arguments of the tiled_loop can be bufferized inplace.
-    Value destMemRef;
-    if (IsBlockArgOfTiledLoop(op.dest()))
-      destMemRef = adaptor.dest();
-    else
-      destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
     assert(destMemRef.getType().isa<MemRefType>());
 
     // Take a subview to copy the small memref.
@@ -334,60 +310,115 @@ class InsertSliceOpConverter
   }
 };
 
+bool isBlockArgOfTiledLoop(Value tensor) {
+  if (auto blockArgument = tensor.dyn_cast<BlockArgument>())
+    return isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp());
+  return false;
+}
+
+SmallVector<Value, 3> convertOperands(ValueRange operands,
+                                      BlockAndValueMapping &bvm) {
+  SmallVector<Value, 3> newOperands;
+  newOperands.reserve(operands.size());
+  for (auto operand : operands)
+    newOperands.push_back(bvm.lookupOrDefault(operand));
+  return newOperands;
+}
+
 class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
 public:
   using OpConversionPattern<TiledLoopOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
+  matchAndRewrite(TiledLoopOp loop, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
-    Location loc = tiledLoop.getLoc();
-    if (tiledLoop.getNumResults() == 0)
+    TiledLoopOp::Adaptor adaptor(operands, loop->getAttrDictionary());
+    if (loop.getNumResults() == 0)
       return failure();
-    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+
+    Location loc = loop.getLoc();
+    auto newLoop = rewriter.create<TiledLoopOp>(
         loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
         adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
         adaptor.distribution_types());
+
     // Clone the region.
     BlockAndValueMapping bvm;
-    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+    bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
+    bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs());
+    bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs());
 
     OpBuilder innerBuilder =
-        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
-
-    // Remap input block arguments.
-    SmallVector<Value, 2> inputs;
-    for (auto en : llvm::zip(newTiledLoop.getRegionInputArgs(),
-                             tiledLoop.getRegionInputArgs())) {
-      auto &newInputArg = std::get<0>(en);
-      if (!newInputArg.getType().isa<ShapedType>()) {
-        inputs.push_back(std::get<0>(en));
-        continue;
+        OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
+
+    for (auto &op : loop.getBody()->getOperations()) {
+      Location loc = op.getLoc();
+      if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
+        if (isBlockArgOfTiledLoop(extractSlice.source())) {
+          auto newOperands = convertOperands(extractSlice.getOperands(), bvm);
+          auto srcMemRefType =
+              bvm.lookup(extractSlice.source()).getType().cast<MemRefType>();
+          auto dstMemRefType =
+              memref::SubViewOp::inferResultType(
+                  srcMemRefType,
+                  extractFromI64ArrayAttr(extractSlice.static_offsets()),
+                  extractFromI64ArrayAttr(extractSlice.static_sizes()),
+                  extractFromI64ArrayAttr(extractSlice.static_strides()))
+                  .cast<MemRefType>();
+
+          Value subView = innerBuilder.create<memref::SubViewOp>(
+              loc, TypeRange{dstMemRefType}, newOperands,
+              extractSlice->getAttrs());
+          bvm.map(extractSlice.getResult(), subView);
+          continue;
+        }
       }
-      inputs.push_back(
-          innerBuilder.create<memref::TensorLoadOp>(loc, newInputArg));
-    }
-    bvm.map(tiledLoop.getRegionInputArgs(), inputs);
-
-    // Remap output block arguments.
-    SmallVector<Value, 2> outputs;
-    for (auto en : llvm::zip(newTiledLoop.getRegionOutputArgs(),
-                             tiledLoop.getRegionOutputArgs())) {
-      auto &newOutputArg = std::get<0>(en);
-      if (!newOutputArg.getType().isa<ShapedType>()) {
-        outputs.push_back(std::get<0>(en));
+      if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
+        if (isBlockArgOfTiledLoop(insertSlice.dest())) {
+          continue;
+        }
+      }
+      if (auto yield = dyn_cast<linalg::YieldOp>(op)) {
+        for (OpOperand &operand : yield->getOpOperands()) {
+          if (auto insert =
+                  operand.get().getDefiningOp<tensor::InsertSliceOp>()) {
+
+            auto dstMemRefType = memref::SubViewOp::inferResultType(
+                getTypeConverter()
+                    ->convertType(insert.source().getType())
+                    .cast<MemRefType>(),
+                extractFromI64ArrayAttr(insert.static_offsets()),
+                extractFromI64ArrayAttr(insert.static_sizes()),
+                extractFromI64ArrayAttr(insert.static_strides()));
+
+            Value subView = innerBuilder.create<memref::SubViewOp>(
+                loc, dstMemRefType, bvm.lookup(insert.dest()),
+                convertOperands(insert.offsets(), bvm),
+                convertOperands(insert.sizes(), bvm),
+                convertOperands(insert.strides(), bvm), insert.static_offsets(),
+                insert.static_sizes(), insert.static_strides());
+
+            Value cast = innerBuilder.create<memref::BufferCastOp>(
+                loc,
+                getTypeConverter()
+                    ->convertType(insert.source().getType())
+                    .cast<MemRefType>(),
+                bvm.lookup(insert.source()));
+
+            innerBuilder.create<linalg::CopyOp>(loc, cast, subView);
+            continue;
+          }
+          auto dst = newLoop.getRegionOutputArgs()[operand.getOperandNumber()];
+          Value cast = innerBuilder.create<memref::BufferCastOp>(
+              loc, dst.getType(), bvm.lookup(operand.get()));
+          innerBuilder.create<linalg::CopyOp>(loc, cast, dst);
+        }
         continue;
       }
-      outputs.push_back(
-          innerBuilder.create<memref::TensorLoadOp>(loc, newOutputArg));
-    }
-    bvm.map(tiledLoop.getRegionOutputArgs(), outputs);
-
-    for (auto &op : tiledLoop.getBody()->without_terminator())
       innerBuilder.clone(op, bvm);
+    }
     innerBuilder.create<linalg::YieldOp>(loc);
-    rewriter.replaceOp(tiledLoop, newTiledLoop.outputs());
+    rewriter.replaceOp(loop, newLoop.outputs());
     return success();
   }
 };

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index dbb2bb713e773..29f23c10e095c 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -339,13 +339,66 @@ func @tiled_dot(%A: tensor<10xf32>, %B: tensor<10xf32>,
     linalg.yield %dot_sub : tensor<f32>
   }
   // CHECK: linalg.tiled_loop
-  // CHECK-SAME: ins (%[[A:.*]] = %{{.*}}: memref<10xf32>, %[[B:.*]] = %{{.*}}: memref<10xf32>)
-  // CHECK-SAME: outs (%[[C:.*]] = %{{.*}}: memref<f32>)
-  //   CHECK-NOT:   alloc
-  //   CHECK:       %[[SV_A:.*]] = memref.subview %[[A]]
-  //   CHECK:       %[[SV_B:.*]] = memref.subview %[[B]]
-  //   CHECK:       linalg.dot ins(%[[SV_A]], %[[SV_B]]
-  //   CHECK-SAME:             outs(%[[C]] : memref<f32>)
-  //   CHECK:   linalg.yield
+  // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
+  // CHECK-SAME:      %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
+  // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<f32>)
+
+  // CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]]
+  // CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]]
+  // CHECK-NEXT: %[[TMP:.*]] = memref.alloc
+  // CHECK-NEXT: linalg.copy(%[[C]], %[[TMP]])
+  // CHECK-NEXT: linalg.dot ins(%[[SV_A]], %[[SV_B]]
+  // CHECK-SAME:            outs(%[[TMP]] : memref<f32>)
+  // CHECK-NEXT: linalg.copy(%[[TMP]], %[[C]])
+  // CHECK-NEXT: linalg.yield
   return %dot : tensor<f32>
 }
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0)>
+
+func @tiled_add(%A: tensor<10xf32>, %B: tensor<10xf32>,
+                  %C: tensor<10xf32>) -> tensor<10xf32> {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c10 = constant 10 : index
+
+  %sum = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
+       ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
+       outs (%C_ = %C: tensor<10xf32>) {
+    %A_sub = tensor.extract_slice %A_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %B_sub = tensor.extract_slice %B_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %C_sub = tensor.extract_slice %C_[%i] [%c2] [1]
+      : tensor<10xf32> to tensor<?xf32>
+    %sum_sub = linalg.generic {
+      indexing_maps = [#map0, #map0, #map0],
+      iterator_types = ["parallel"]
+    } ins(%A_sub, %B_sub : tensor<?xf32>, tensor<?xf32>)
+      outs(%C_sub : tensor<?xf32>) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %0 = std.addf %a, %b : f32
+        linalg.yield %0 : f32
+    } -> tensor<?xf32>
+    %update = tensor.insert_slice %sum_sub into %C_[%i] [%c2] [1]
+      : tensor<?xf32> into tensor<10xf32>
+    linalg.yield %update : tensor<10xf32>
+  }
+  // CHECK: linalg.tiled_loop
+  // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
+  // CHECK-SAME:      %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
+  // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>)
+
+  // CHECK-NEXT:  %[[SV_A:.*]] = memref.subview %[[A]]
+  // CHECK-NEXT:  %[[SV_B:.*]] = memref.subview %[[B]]
+  // CHECK-NEXT:  %[[TMP:.*]] = memref.alloc
+  // CHECK-NEXT:  linalg.generic
+  // CHECK-SAME:    ins(%[[SV_A]], %[[SV_B]]
+  // CHECK-SAME:    outs(%[[TMP]] : memref<2xf32>)
+  // CHECK:  %[[SV_C:.*]] = memref.subview %[[C]]
+  // CHECK-NEXT:  linalg.copy(%[[TMP]], %[[SV_C]])
+  // CHECK-NEXT:  linalg.yield
+  return %sum : tensor<10xf32>
+}
