[Mlir-commits] [mlir] 1652871 - [mlir][linalg][bufferize] Reimplementation of TiledLoopOp bufferization

Matthias Springer llvmlistbot at llvm.org
Wed Dec 15 01:48:45 PST 2021


Author: Matthias Springer
Date: 2021-12-15T18:45:29+09:00
New Revision: 1652871473a70e9908d794e858388edc89cd6e88

URL: https://github.com/llvm/llvm-project/commit/1652871473a70e9908d794e858388edc89cd6e88
DIFF: https://github.com/llvm/llvm-project/commit/1652871473a70e9908d794e858388edc89cd6e88.diff

LOG: [mlir][linalg][bufferize] Reimplementation of TiledLoopOp bufferization

Instead of modifying the existing linalg.tiled_loop op in place, create a new op with memref inputs/outputs and delete the old op.

Differential Revision: https://reviews.llvm.org/D115493
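
For illustration, the intended effect on the IR is roughly the following. This is a hand-written sketch, not taken from the commit or its tests; the exact linalg.tiled_loop assembly syntax, types, and value names are approximations, and the loop bodies are elided.

// Before bufferization: the loop operates on tensors, the terminator
// yields the updated output tensor, and the loop has a tensor result.
%res = linalg.tiled_loop (%i) = (%c0) to (%c8) step (%c4)
    ins (%in_ = %in : tensor<8xf32>)
    outs (%out_ = %out : tensor<8xf32>) {
  // ... tiled computation producing %updated ...
  linalg.yield %updated : tensor<8xf32>
}

// After this change: a new tiled_loop is created with memref operands,
// the old body is merged into it (bufferization.to_tensor ops bridge the
// old tensor block arguments until the body itself is bufferized), the
// terminator yields nothing, and the loop has no results.
linalg.tiled_loop (%i) = (%c0) to (%c8) step (%c4)
    ins (%in_ = %in_buf : memref<8xf32>)
    outs (%out_ = %out_buf : memref<8xf32>) {
  // ... loop body, bufferized in a subsequent step ...
  linalg.yield
}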

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 5abfbe559d0a..f9c6281e67c7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -241,115 +242,88 @@ struct TiledLoopOpInterface
                           BufferizationState &state) const {
     auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
 
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-
-    // Allocate output buffers if needed, forward output tensor args to the
-    // terminator.
-    Operation *yieldOp = tiledLoopOp.getBody()->getTerminator();
-    Block *body = tiledLoopOp.getBody();
-
-    // Take copies of the old input and output operands, so we can insert
-    // inplace easily.
-    auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs());
-    auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs());
-
-    int numLoops = tiledLoopOp.getNumLoops();
-    int numControlOperands = tiledLoopOp.getNumControlOperands();
-
-    // Add buffers for outputs and the corresponding block arguments.
-    // Keep separate iterators to increment without further leaking impl.
-    // details. Start with outputs to avoid interference from new input buffers.
-    int numNewOutputBuffers = 0;
-    int resultIndex = 0;
-    int oldOutputBBArgIndex = numLoops + oldInputs.size();
-    int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size();
-    int nextOutputOperandIndex =
-        numControlOperands + oldInputs.size() + oldOutputs.size();
-    for (Value oldOutputTensor : oldOutputs) {
-      if (!oldOutputTensor.getType().isa<TensorType>()) {
-        // Skip and increment the old bbarg index only.
-        ++oldOutputBBArgIndex;
-        // Do not increment resultIndex as only tensors are returned.
-        // TODO: better interface to avoid leaking such impl details.
-        continue;
+    // Use IRRewriter instead of OpBuilder because it has additional helper
+    // functions.
+    IRRewriter rewriter(op->getContext());
+    rewriter.setInsertionPoint(tiledLoopOp);
+
+    // Compute new inputs, outputs and results.
+    SmallVector<Value> newInputs, newOutputs, newResults;
+    for (Value value : tiledLoopOp.inputs()) {
+      if (value.getType().isa<TensorType>()) {
+        newInputs.push_back(state.lookupBuffer(value));
+      } else {
+        newInputs.push_back(value);
       }
-
-      assert(oldOutputTensor.getType().isa<RankedTensorType>() &&
-             "bufferizable output must be a ranked tensor");
-
-      const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
-      OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-      Value resultBuffer = state.getResultBuffer(opResult);
-      if (!resultBuffer)
-        return failure();
-
-      // Insert mapping and aliasing info.
-      state.mapBuffer(opResult, resultBuffer);
-
-      // Insert new operand and bbArg.
-      tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer);
-      BlockArgument newBufferBBArg =
-          body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
-      BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
-      // Insert mapping and aliasing info.
-      state.mapBuffer(oldTensorBBArg, newBufferBBArg);
-
-      // Set operand of `linalg.yield` to the bbArg so it just canonicalizes
-      // away later.
-      yieldOperand.set(oldTensorBBArg);
-
-      // Increment indices.
-      ++numNewOutputBuffers;
-      ++resultIndex;
-      ++oldOutputBBArgIndex;
-      ++nextOutputBBArgIndex;
-      ++nextOutputOperandIndex;
     }
-
-    // Add buffers for inputs and the corresponding block arguments.
-    // Keep separate iterators to increment without further leaking impl.
-    // details.
-    int numNewInputBuffers = 0;
-    int oldInputBBArgIndex = numLoops;
-    int nextInputBBArgIndex = numLoops + oldInputs.size();
-    int nextInputOperandIndex = numControlOperands + oldInputs.size();
-    for (Value oldInputTensor : oldInputs) {
-      if (!oldInputTensor.getType().isa<TensorType>()) {
-        // Skip and increment the old bbarg index only.
-        ++oldInputBBArgIndex;
-        continue;
+    int nextResultNum = 0;
+    for (Value value : tiledLoopOp.outputs()) {
+      if (value.getType().isa<TensorType>()) {
+        Value buffer =
+            state.getResultBuffer(tiledLoopOp->getResult(nextResultNum++));
+        newOutputs.push_back(buffer);
+        newResults.push_back(buffer);
+      } else {
+        newOutputs.push_back(value);
       }
+    }
 
-      Value inputBuffer = state.lookupBuffer(oldInputTensor);
-
-      // Insert new operand and bbArg.
-      tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer);
-      BlockArgument newBufferBBArg =
-          body->insertArgument(nextInputBBArgIndex, inputBuffer.getType());
-      BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
+    // Create new TiledLoopOp.
+    auto newTiledLoopOp = rewriter.create<TiledLoopOp>(
+        tiledLoopOp.getLoc(), tiledLoopOp.lowerBound(),
+        tiledLoopOp.upperBound(), tiledLoopOp.step(), newInputs, newOutputs,
+        tiledLoopOp.iterator_types(), tiledLoopOp.distribution_types());
+
+    // Remove terminator.
+    if (!newTiledLoopOp.getBody()->empty())
+      rewriter.eraseOp(tiledLoopOp.getBody()->getTerminator());
+
+    // Compute new loop body arguments.
+    SmallVector<Value> newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs;
+    ValueRange newInductionVars = newTiledLoopOp.getInductionVars();
+    newBlockArgs.append(newInductionVars.begin(), newInductionVars.end());
+
+    ValueRange newRegionInArgs = newTiledLoopOp.getRegionInputArgs();
+    ValueRange newRegionOutArgs = newTiledLoopOp.getRegionOutputArgs();
+    newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end());
+    newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end());
+
+    ValueRange oldRegionInArgs = tiledLoopOp.getRegionInputArgs();
+    ValueRange oldRegionOutArgs = tiledLoopOp.getRegionOutputArgs();
+    oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end());
+    oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end());
+    assert(newRegionInArgs.size() == oldRegionInArgs.size() &&
+           "expected same number of input args");
+    assert(newRegionOutArgs.size() == oldRegionOutArgs.size() &&
+           "expected same number of output args");
+
+    for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) {
+      Value oldArg = std::get<0>(it);
+      Value newArg = std::get<1>(it);
+      rewriter.setInsertionPointToStart(newTiledLoopOp->getBlock());
+      if (oldArg.getType().isa<TensorType>()) {
+        newBlockArgs.push_back(rewriter.create<bufferization::ToTensorOp>(
+            oldArg.getLoc(), newArg));
+      } else {
+        newBlockArgs.push_back(newArg);
+      }
+    }
 
-      // Insert mapping and aliasing info.
-      state.mapBuffer(oldTensorBBArg, newBufferBBArg);
+    // Move old body into new loop.
+    rewriter.mergeBlocks(tiledLoopOp.getBody(), newTiledLoopOp.getBody(),
+                         newBlockArgs);
 
-      // Increment indices.
-      ++numNewInputBuffers;
-      ++oldInputBBArgIndex;
-      ++nextInputBBArgIndex;
-      ++nextInputOperandIndex;
-    }
+    // Replace previous terminator with a new one that does not yield anything.
+    Operation *oldTerminator = newTiledLoopOp.getBody()->getTerminator();
+    rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody());
+    rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());
+    rewriter.eraseOp(oldTerminator);
 
-    // Update segment sizes.
-    // TODO: Helper method to avoid leaking impl details.
-    tiledLoopOp->setAttr(
-        TiledLoopOp::getOperandSegmentSizeAttr(),
-        b.getI32VectorAttr(
-            {numLoops, numLoops, numLoops,
-             static_cast<int>(oldInputs.size()) + numNewInputBuffers,
-             static_cast<int>(oldOutputs.size()) + numNewOutputBuffers}));
+    // Replace results and delete old op.
+    state.replaceOp(op, newResults);
 
     // Bufferize loop body.
-    return comprehensive_bufferize::bufferize(&tiledLoopOp.region(), state);
+    return comprehensive_bufferize::bufferize(newTiledLoopOp.getBody(), state);
   }
 };
 
@@ -372,14 +346,14 @@ struct YieldOpInterface
                           BufferizationState &state) const {
     auto yieldOp = cast<linalg::YieldOp>(op);
 
-    // No tensors -> success.
-    if (!llvm::any_of(yieldOp.getOperandTypes(),
-                      [](Type t) { return t.isa<TensorType>(); }))
-      return success();
-    // linalg::YieldOp nested under TiledLoop must just canonicalize.
-    if (yieldOp->getParentOfType<TiledLoopOp>())
-      return success();
-    llvm_unreachable("unexpected yieldOp");
+    if (!yieldOp->getParentOfType<TiledLoopOp>())
+      return yieldOp->emitError(
+          "expected that linalg.yield terminates a tiled_loop");
+
+    assert(yieldOp->getOpOperands().empty() &&
+           "expected that linalg.yield was bufferized together with"
+           " tiled_loop");
+    return success();
   }
 };
 

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2a760cc9bfcb..691d14b791cc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6393,6 +6393,7 @@ cc_library(
     includes = ["include"],
     deps = [
         ":BufferizableOpInterface",
+        ":BufferizationDialect",
         ":IR",
         ":LinalgOps",
         ":LinalgStructuredOpsIncGen",


        

