[Mlir-commits] [mlir] 1652871 - [mlir][linalg][bufferize] Reimplementation of TiledLoopOp bufferization
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 15 01:48:45 PST 2021
Author: Matthias Springer
Date: 2021-12-15T18:45:29+09:00
New Revision: 1652871473a70e9908d794e858388edc89cd6e88
URL: https://github.com/llvm/llvm-project/commit/1652871473a70e9908d794e858388edc89cd6e88
DIFF: https://github.com/llvm/llvm-project/commit/1652871473a70e9908d794e858388edc89cd6e88.diff
LOG: [mlir][linalg][bufferize] Reimplementation of TiledLoopOp bufferization
Instead of modifying the existing linalg.tiled_loop op, create a new op with memref inputs/outputs and delete the old op.
Differential Revision: https://reviews.llvm.org/D115493
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 5abfbe559d0a..f9c6281e67c7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -241,115 +242,88 @@ struct TiledLoopOpInterface
BufferizationState &state) const {
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
-
- // Allocate output buffers if needed, forward output tensor args to the
- // terminator.
- Operation *yieldOp = tiledLoopOp.getBody()->getTerminator();
- Block *body = tiledLoopOp.getBody();
-
- // Take copies of the old input and output operands, so we can insert
- // inplace easily.
- auto oldInputs = llvm::to_vector<4>(tiledLoopOp.inputs());
- auto oldOutputs = llvm::to_vector<4>(tiledLoopOp.outputs());
-
- int numLoops = tiledLoopOp.getNumLoops();
- int numControlOperands = tiledLoopOp.getNumControlOperands();
-
- // Add buffers for outputs and the corresponding block arguments.
- // Keep separate iterators to increment without further leaking impl.
- // details. Start with outputs to avoid interference from new input buffers.
- int numNewOutputBuffers = 0;
- int resultIndex = 0;
- int oldOutputBBArgIndex = numLoops + oldInputs.size();
- int nextOutputBBArgIndex = numLoops + oldInputs.size() + oldOutputs.size();
- int nextOutputOperandIndex =
- numControlOperands + oldInputs.size() + oldOutputs.size();
- for (Value oldOutputTensor : oldOutputs) {
- if (!oldOutputTensor.getType().isa<TensorType>()) {
- // Skip and increment the old bbarg index only.
- ++oldOutputBBArgIndex;
- // Do not increment resultIndex as only tensors are returned.
- // TODO: better interface to avoid leaking such impl details.
- continue;
+ // Use IRRewriter instead of OpBuilder because it has additional helper
+ // functions.
+ IRRewriter rewriter(op->getContext());
+ rewriter.setInsertionPoint(tiledLoopOp);
+
+ // Compute new inputs, outputs and results.
+ SmallVector<Value> newInputs, newOutputs, newResults;
+ for (Value value : tiledLoopOp.inputs()) {
+ if (value.getType().isa<TensorType>()) {
+ newInputs.push_back(state.lookupBuffer(value));
+ } else {
+ newInputs.push_back(value);
}
-
- assert(oldOutputTensor.getType().isa<RankedTensorType>() &&
- "bufferizable output must be a ranked tensor");
-
- const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
- OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
- Value resultBuffer = state.getResultBuffer(opResult);
- if (!resultBuffer)
- return failure();
-
- // Insert mapping and aliasing info.
- state.mapBuffer(opResult, resultBuffer);
-
- // Insert new operand and bbArg.
- tiledLoopOp->insertOperands(nextOutputOperandIndex, resultBuffer);
- BlockArgument newBufferBBArg =
- body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
- BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
- // Insert mapping and aliasing info.
- state.mapBuffer(oldTensorBBArg, newBufferBBArg);
-
- // Set operand of `linalg.yield` to the bbArg so it just canonicalizes
- // away later.
- yieldOperand.set(oldTensorBBArg);
-
- // Increment indices.
- ++numNewOutputBuffers;
- ++resultIndex;
- ++oldOutputBBArgIndex;
- ++nextOutputBBArgIndex;
- ++nextOutputOperandIndex;
}
-
- // Add buffers for inputs and the corresponding block arguments.
- // Keep separate iterators to increment without further leaking impl.
- // details.
- int numNewInputBuffers = 0;
- int oldInputBBArgIndex = numLoops;
- int nextInputBBArgIndex = numLoops + oldInputs.size();
- int nextInputOperandIndex = numControlOperands + oldInputs.size();
- for (Value oldInputTensor : oldInputs) {
- if (!oldInputTensor.getType().isa<TensorType>()) {
- // Skip and increment the old bbarg index only.
- ++oldInputBBArgIndex;
- continue;
+ int nextResultNum = 0;
+ for (Value value : tiledLoopOp.outputs()) {
+ if (value.getType().isa<TensorType>()) {
+ Value buffer =
+ state.getResultBuffer(tiledLoopOp->getResult(nextResultNum++));
+ newOutputs.push_back(buffer);
+ newResults.push_back(buffer);
+ } else {
+ newOutputs.push_back(value);
}
+ }
- Value inputBuffer = state.lookupBuffer(oldInputTensor);
-
- // Insert new operand and bbArg.
- tiledLoopOp->insertOperands(nextInputOperandIndex, inputBuffer);
- BlockArgument newBufferBBArg =
- body->insertArgument(nextInputBBArgIndex, inputBuffer.getType());
- BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
+ // Create new TiledLoopOp.
+ auto newTiledLoopOp = rewriter.create<TiledLoopOp>(
+ tiledLoopOp.getLoc(), tiledLoopOp.lowerBound(),
+ tiledLoopOp.upperBound(), tiledLoopOp.step(), newInputs, newOutputs,
+ tiledLoopOp.iterator_types(), tiledLoopOp.distribution_types());
+
+ // Remove terminator.
+ if (!newTiledLoopOp.getBody()->empty())
+ rewriter.eraseOp(tiledLoopOp.getBody()->getTerminator());
+
+ // Compute new loop body arguments.
+ SmallVector<Value> newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs;
+ ValueRange newInductionVars = newTiledLoopOp.getInductionVars();
+ newBlockArgs.append(newInductionVars.begin(), newInductionVars.end());
+
+ ValueRange newRegionInArgs = newTiledLoopOp.getRegionInputArgs();
+ ValueRange newRegionOutArgs = newTiledLoopOp.getRegionOutputArgs();
+ newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end());
+ newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end());
+
+ ValueRange oldRegionInArgs = tiledLoopOp.getRegionInputArgs();
+ ValueRange oldRegionOutArgs = tiledLoopOp.getRegionOutputArgs();
+ oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end());
+ oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end());
+ assert(newRegionInArgs.size() == oldRegionInArgs.size() &&
+ "expected same number of input args");
+ assert(newRegionOutArgs.size() == oldRegionOutArgs.size() &&
+ "expected same number of output args");
+
+ for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) {
+ Value oldArg = std::get<0>(it);
+ Value newArg = std::get<1>(it);
+ rewriter.setInsertionPointToStart(newTiledLoopOp->getBlock());
+ if (oldArg.getType().isa<TensorType>()) {
+ newBlockArgs.push_back(rewriter.create<bufferization::ToTensorOp>(
+ oldArg.getLoc(), newArg));
+ } else {
+ newBlockArgs.push_back(newArg);
+ }
+ }
- // Insert mapping and aliasing info.
- state.mapBuffer(oldTensorBBArg, newBufferBBArg);
+ // Move old body into new loop.
+ rewriter.mergeBlocks(tiledLoopOp.getBody(), newTiledLoopOp.getBody(),
+ newBlockArgs);
- // Increment indices.
- ++numNewInputBuffers;
- ++oldInputBBArgIndex;
- ++nextInputBBArgIndex;
- ++nextInputOperandIndex;
- }
+ // Replace previous terminator with a new one that does not yield anything.
+ Operation *oldTerminator = newTiledLoopOp.getBody()->getTerminator();
+ rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody());
+ rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());
+ rewriter.eraseOp(oldTerminator);
- // Update segment sizes.
- // TODO: Helper method to avoid leaking impl details.
- tiledLoopOp->setAttr(
- TiledLoopOp::getOperandSegmentSizeAttr(),
- b.getI32VectorAttr(
- {numLoops, numLoops, numLoops,
- static_cast<int>(oldInputs.size()) + numNewInputBuffers,
- static_cast<int>(oldOutputs.size()) + numNewOutputBuffers}));
+ // Replace results and delete old op.
+ state.replaceOp(op, newResults);
// Bufferize loop body.
- return comprehensive_bufferize::bufferize(&tiledLoopOp.region(), state);
+ return comprehensive_bufferize::bufferize(newTiledLoopOp.getBody(), state);
}
};
@@ -372,14 +346,14 @@ struct YieldOpInterface
BufferizationState &state) const {
auto yieldOp = cast<linalg::YieldOp>(op);
- // No tensors -> success.
- if (!llvm::any_of(yieldOp.getOperandTypes(),
- [](Type t) { return t.isa<TensorType>(); }))
- return success();
- // linalg::YieldOp nested under TiledLoop must just canonicalize.
- if (yieldOp->getParentOfType<TiledLoopOp>())
- return success();
- llvm_unreachable("unexpected yieldOp");
+ if (!yieldOp->getParentOfType<TiledLoopOp>())
+ return yieldOp->emitError(
+ "expected that linalg.yield terminates a tiled_loop");
+
+ assert(yieldOp->getOpOperands().empty() &&
+ "expected that linalg.yield was bufferized together with"
+ " tiled_loop");
+ return success();
}
};
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2a760cc9bfcb..691d14b791cc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6393,6 +6393,7 @@ cc_library(
includes = ["include"],
deps = [
":BufferizableOpInterface",
+ ":BufferizationDialect",
":IR",
":LinalgOps",
":LinalgStructuredOpsIncGen",
More information about the Mlir-commits mailing list