[Mlir-commits] [mlir] 422aaf3 - [mlir][Linalg] Add named Linalg ops on tensor to buffer support.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Oct 12 04:21:12 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-12T11:20:23Z
New Revision: 422aaf31daa520899303adaf82ba76743624ee0a
URL: https://github.com/llvm/llvm-project/commit/422aaf31daa520899303adaf82ba76743624ee0a
DIFF: https://github.com/llvm/llvm-project/commit/422aaf31daa520899303adaf82ba76743624ee0a.diff
LOG: [mlir][Linalg] Add named Linalg ops on tensor to buffer support.
This revision introduces support for buffer allocation for any named Linalg op.
To avoid template-instantiating one pattern per op, a new ConversionPattern is created that dispatches through the LinalgOp interface.
Some APIs are updated to remain consistent with MLIR style:
`OwningRewritePatternList * -> OwningRewritePatternList &`
`BufferAssignmentTypeConverter * -> BufferAssignmentTypeConverter &`
Differential Revision: https://reviews.llvm.org/D89226
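For illustration, a typical call site migrates as follows; this is a minimal
sketch based on the Linalg call sites updated in this diff, with the
surrounding pass boilerplate assumed:

    // Before: converter and pattern list passed by pointer.
    OwningRewritePatternList patterns;
    BufferAssignmentTypeConverter converter;
    populateConvertLinalgOnTensorsToBuffersPatterns(&context, &converter,
                                                    &patterns);

    // After: passed by reference, consistent with MLIR style.
    populateConvertLinalgOnTensorsToBuffersPatterns(&context, converter,
                                                    patterns);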
Added:
mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
Modified:
mlir/docs/Tutorials/QuickstartRewrites.md
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/include/mlir/Transforms/Bufferize.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
mlir/lib/Transforms/BufferPlacement.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Transforms/TestBufferPlacement.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index fbc2406a0f31..447f8a62f91e 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -155,7 +155,7 @@ add_public_tablegen_target(<name-of-the-cmake-target>)
Then you can `#include` the generated file in any C++ implementation file you
like. (You will also need to make sure the library depends on the CMake target
defined in the above.) The generated file will have a `populateWithGenerated(
-MLIRContext *context, OwningRewritePatternList *patterns)` function that you can
+MLIRContext *context, OwningRewritePatternList &patterns)` function that you can
use to collect all the generated patterns inside `patterns` and then use
`patterns` in any pass you would like.
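A usage sketch, assuming a FunctionPass and the greedy pattern driver (the
same idiom as the test pass updated later in this diff):

    void runOnFunction() override {
      mlir::OwningRewritePatternList patterns;
      // Collect all TableGen-generated patterns into `patterns`.
      populateWithGenerated(&getContext(), patterns);
      // Hand the collected patterns to any driver, e.g. the greedy rewriter.
      applyPatternsAndFoldGreedily(getFunction(), patterns);
    }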
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2332f516c44a..55df0bccbb64 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -39,7 +39,7 @@ def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
// Base Tablegen class for Linalg ops.
-// Linalg ops that correspond to library calls operate on linalg::View as their
+// Linalg ops that correspond to library calls operate on ShapedType as their
// first operands. These may be optionally followed by non-view operands
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index dbb89c73954b..845873ff83df 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -628,7 +628,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location and operands. This
- is used to abstract away the optional underlying region creation.
+ is used to abstract away the optional underlying region creation. This
+ does not change the balance between input, output_buffer and
+ init_tensors operands.
}],
/*retTy=*/"Operation *",
/*methodName=*/"clone",
@@ -666,6 +668,23 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
}
return res;
}
+ //========================================================================//
+ // Helper functions to mutate the `operand_segment_sizes` attribute.
+ // These are useful when cloning and changing operand types.
+ //========================================================================//
+ void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
+ void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
+ void setNumInitTensors(unsigned num) { setOperandSegmentAt(2, num); }
+
+ private:
+ void setOperandSegmentAt(unsigned idx, unsigned val) {
+ auto attr = getOperation()->getAttr("operand_segment_sizes")
+ .cast<DenseIntElementsAttr>();
+ unsigned i = 0;
+ auto newAttr = attr.mapValues(IntegerType::get(32, getContext()),
+ [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
+ getOperation()->setAttr("operand_segment_sizes", newAttr);
+ }
}];
}
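These setters are exercised by the named-op bufferization pattern added in
TensorsToBuffers.cpp below; in sketch form:

    // Clone the op with buffer operands, then fix up the operand segments:
    // all former init_tensors have been turned into output buffers.
    LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                                 /*resultTypes=*/ArrayRef<Type>{},
                                                 newOperands));
    res.setNumOutputBuffers(outputs.size());
    res.setNumInitTensors(0);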
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7512f69608a4..a2dee8c3ae65 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/Bufferize.h"
#include "llvm/ADT/SmallBitVector.h"
namespace mlir {
@@ -51,8 +52,8 @@ void populateConvVectorizationPatterns(
/// Populates the given list with patterns to convert Linalg operations on
/// tensors to buffers.
void populateConvertLinalgOnTensorsToBuffersPatterns(
- MLIRContext *context, BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns);
+ MLIRContext *context, BufferAssignmentTypeConverter &converter,
+ OwningRewritePatternList &patterns);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
/// and permute the loop nest according to `interchangeVector`
@@ -797,6 +798,46 @@ class IndexedGenericOpToLibraryCallRewrite
void populateLinalgToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx);
+//===----------------------------------------------------------------------===//
+// Buffer allocation patterns.
+//===----------------------------------------------------------------------===//
+
+/// Generic BufferAssignmentConversionPattern that matches any Operation* and
+/// dispatches internally. This avoids instantiating one pattern template for
+/// each op implementing the LinalgOp interface.
+class LinalgOpConverter : public BufferAssignmentConversionPattern {
+public:
+ LinalgOpConverter(MLIRContext *context,
+ BufferAssignmentTypeConverter &converter)
+ : BufferAssignmentConversionPattern(context, converter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final;
+};
+
+class TensorConstantOpConverter
+ : public BufferAssignmentOpConversionPattern<ConstantOp> {
+public:
+ using BufferAssignmentOpConversionPattern<
+ ConstantOp>::BufferAssignmentOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final;
+};
+
+class TensorCastOpConverter
+ : public BufferAssignmentOpConversionPattern<TensorCastOp> {
+public:
+ using BufferAssignmentOpConversionPattern<
+ TensorCastOp>::BufferAssignmentOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final;
+};
+
//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 72816b72f41e..81f64dc3ccb3 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -41,8 +41,8 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
void populateShapeTypeConversionPatterns(
- MLIRContext *ctx, BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns);
+ MLIRContext *ctx, BufferAssignmentTypeConverter &converter,
+ OwningRewritePatternList &patterns);
// Collects a set of patterns to replace tensors as inputs and outputs to shape
// operations with buffers. This only modifies the shape operations.
std::unique_ptr<FunctionPass> createShapeTensorToMemrefPass();
diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
index e8fffcfe4925..26452e9db513 100644
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -140,14 +140,28 @@ class BufferAssignmentOpConversionPattern
: public OpConversionPattern<SourceOp> {
public:
explicit BufferAssignmentOpConversionPattern(
- MLIRContext *context, BufferAssignmentTypeConverter *converter,
+ MLIRContext *context, BufferAssignmentTypeConverter &converter,
PatternBenefit benefit = 1)
- : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {
- assert(converter && "The type converter has not been defined");
- }
+ : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {}
+
+protected:
+ BufferAssignmentTypeConverter &converter;
+};
+
+/// Helper conversion pattern that encapsulates a BufferAssignmentTypeConverter
+/// instance and that operates on Operation* to be compatible with OpInterfaces.
+/// This makes it possible to avoid instantiating N patterns for ops that can
+/// be subsumed by a single op interface (e.g. Linalg named ops).
+class BufferAssignmentConversionPattern : public ConversionPattern {
+public:
+ explicit BufferAssignmentConversionPattern(
+ MLIRContext *context, BufferAssignmentTypeConverter &converter,
+ PatternBenefit benefit = 1)
+ : ConversionPattern(benefit, converter, MatchAnyOpTypeTag()),
+ converter(converter) {}
protected:
- BufferAssignmentTypeConverter *converter;
+ BufferAssignmentTypeConverter &converter;
};
/// Converts the signature of the function using BufferAssignmentTypeConverter.
@@ -191,15 +205,15 @@ class BufferAssignmentReturnOpConverter
OpBuilder builder(returnOp);
for (auto operand : llvm::enumerate(operands)) {
SmallVector<Value, 2> values;
- this->converter->tryDecomposeValue(
- builder, loc, operand.value().getType(), operand.value(), values);
+ this->converter.tryDecomposeValue(builder, loc, operand.value().getType(),
+ operand.value(), values);
Type type = returnOp.getOperand(operand.index()).getType();
SmallVector<Type, 2> originTypes;
- this->converter->tryDecomposeType(type, originTypes);
+ this->converter.tryDecomposeType(type, originTypes);
for (auto value : llvm::enumerate(values)) {
Type origin = originTypes[value.index()];
Type converted = value.value().getType();
- auto kind = this->converter->getResultConversionKind(origin, converted);
+ auto kind = this->converter.getResultConversionKind(origin, converted);
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
newOperands.push_back(value.value());
else
@@ -247,10 +261,10 @@ class BufferAssignmentCallOpConverter
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
typename CopyOpTy>
static void populateWithBufferAssignmentOpConversionPatterns(
- MLIRContext *context, BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns) {
+ MLIRContext *context, BufferAssignmentTypeConverter &converter,
+ OwningRewritePatternList &patterns) {
// clang-format off
- patterns->insert<
+ patterns.insert<
BufferAssignmentCallOpConverter,
BufferAssignmentFuncOpConverter,
BufferAssignmentReturnOpConverter
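A pattern deriving from the new BufferAssignmentConversionPattern matches on
Operation* and dispatches on an op interface internally. A minimal sketch, with
a hypothetical MyInterfaceConverter standing in for the LinalgOpConverter
defined below:

    class MyInterfaceConverter : public BufferAssignmentConversionPattern {
    public:
      using BufferAssignmentConversionPattern::BufferAssignmentConversionPattern;

      LogicalResult
      matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                      ConversionPatternRewriter &rewriter) const final {
        auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
        if (!linalgOp)
          return failure(); // Not a LinalgOp; let other patterns handle it.
        // ... rewrite using the captured `converter` and `operands` ...
        return success();
      }
    };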
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
new file mode 100644
index 000000000000..2c01a688cfa0
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-linalg-on-tensors-to-buffers -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+ %A = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+ %B = constant dense<[[1.0, 2.0, 3.0, 4.0],
+ [5.0, 6.0, 7.0, 8.0],
+ [9.0, 10.0, 11.0, 12.0]]> : tensor<3x4xf32>
+ %C = constant dense<1000.0> : tensor<2x4xf32>
+
+ %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
+ init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
+
+ %unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
+ call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+ // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+ // CHECK-SAME: rank = 2 offset = 0 sizes = [2, 4] strides = [4, 1] data =
+ // CHECK-NEXT: [1038, 1044, 1050, 1056]
+ // CHECK-NEXT: [1083, 1098, 1113, 1128]
+
+ return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
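The expected values follow directly from the inputs: each element of %D is the
1000.0 from the init tensor %C plus a row-by-column dot product, e.g.
D[0][0] = 1000 + (1*1 + 2*5 + 3*9) = 1038 and
D[1][3] = 1000 + (4*4 + 5*8 + 6*12) = 1128.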
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d11cc51d1d59..69786823dd32 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -152,7 +152,7 @@ struct LowerGpuOpsToNVVMOpsPass
void mlir::populateGpuToNVVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- populateWithGenerated(converter.getDialect()->getContext(), &patterns);
+ populateWithGenerated(converter.getDialect()->getContext(), patterns);
patterns
.insert<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 40cf097c9c5a..e9b44a9fef52 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -86,7 +86,7 @@ struct LowerGpuOpsToROCDLOpsPass
void mlir::populateGpuToROCDLConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
- populateWithGenerated(converter.getDialect()->getContext(), &patterns);
+ populateWithGenerated(converter.getDialect()->getContext(), patterns);
patterns.insert<
GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index f4273a44bb9c..ee1e4131854f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -330,7 +330,7 @@ namespace {
void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- populateWithGenerated(context, &patterns);
+ populateWithGenerated(context, patterns);
patterns.insert<
GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 442117c619de..650c44f9e922 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -22,40 +22,37 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
-namespace {
-
using namespace ::mlir;
using namespace ::mlir::linalg;
-SmallVector<Range, 4>
-computeLoopRanges(Location loc, linalg::GenericOp linalgOp, OpBuilder *b) {
+static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
+ OpBuilder &b) {
auto indexingMaps = llvm::to_vector<4>(
linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
auto inputIndexingMaps =
llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());
- mlir::edsc::ScopedContext scope(*b, loc);
+ mlir::edsc::ScopedContext scope(b, loc);
return emitLoopRanges(scope.getBuilderRef(), loc,
concatAffineMaps(inputIndexingMaps),
- getShape(*b, linalgOp));
+ getShape(b, linalgOp));
}
-Value maybeConvertToIndex(Location loc, Value val, OpBuilder *b) {
+static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
if (val.getType().isIndex())
return val;
- return b->create<IndexCastOp>(loc, val, b->getIndexType());
+ return b.create<IndexCastOp>(loc, val, b.getIndexType());
}
-LogicalResult allocateBuffersForResults(Location loc,
- linalg::GenericOp linalgOp,
- linalg::GenericOpAdaptor &adaptor,
- SmallVectorImpl<Value> *resultBuffers,
- OpBuilder *b) {
+static LogicalResult
+allocateBuffersForResults(Location loc, LinalgOp linalgOp,
+ linalg::GenericOpAdaptor &adaptor,
+ SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
// Lazily compute loopRanges.
SmallVector<Range, 4> loopRanges;
// Allocate a buffer for every tensor result.
- for (auto en : llvm::enumerate(linalgOp.getResultTypes())) {
+ for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
size_t resultIndex = en.index();
Type resultType = en.value();
@@ -79,24 +76,24 @@ LogicalResult allocateBuffersForResults(Location loc,
Value initTensor = linalgOp.getInitTensor(resultIndex);
Value initBuffer = adaptor.init_tensors()[resultIndex];
if (initTensor.hasOneUse()) {
- resultBuffers->push_back(initBuffer);
+ resultBuffers.push_back(initBuffer);
continue;
}
SmallVector<Value, 4> dynOperands;
for (auto dim : llvm::enumerate(tensorShape)) {
if (dim.value() == TensorType::kDynamicSize) {
- dynOperands.push_back(b->create<DimOp>(loc, initTensor, dim.index()));
+ dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
}
}
- auto alloc = b->create<AllocOp>(loc, memrefType, dynOperands);
- b->create<linalg::CopyOp>(loc, initBuffer, alloc);
- resultBuffers->push_back(alloc);
+ auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
+ b.create<linalg::CopyOp>(loc, initBuffer, alloc);
+ resultBuffers.push_back(alloc);
continue;
}
// Allocate buffers for statically-shaped results.
if (memrefType.hasStaticShape()) {
- resultBuffers->push_back(b->create<AllocOp>(loc, memrefType));
+ resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
continue;
}
@@ -123,148 +120,157 @@ LogicalResult allocateBuffersForResults(Location loc,
return failure();
}
}
- resultBuffers->push_back(b->create<AllocOp>(loc, memrefType, dynOperands));
+ resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
}
return success();
}
+// Specialization for `linalg::GenericOp`.
/// A pattern to convert Generic Linalg operations which work on tensors to
/// use buffers. A buffer is allocated using BufferAssignmentPlacer for
/// each operation result. BufferPlacement pass should be later used to move
/// Alloc operations to the correct positions and insert the missing Dealloc
/// operations in the correct places.
-class GenericOpConverter
- : public BufferAssignmentOpConversionPattern<linalg::GenericOp> {
-public:
- using BufferAssignmentOpConversionPattern<
- linalg::GenericOp>::BufferAssignmentOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(linalg::GenericOp linalgOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- linalg::GenericOpAdaptor adaptor(
- operands, linalgOp.getOperation()->getAttrDictionary());
-
- // All inputs need to be turned into buffers first. Until then, bail out.
- if (llvm::any_of(adaptor.inputs(),
- [](Value in) { return !in.getType().isa<MemRefType>(); }))
- return failure();
-
- // All init_tensors need to be turned into buffers first. Until then, bail
- // out.
- if (llvm::any_of(adaptor.init_tensors(),
- [](Value in) { return !in.getType().isa<MemRefType>(); }))
- return failure();
+static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
+ linalg::GenericOp genericOp,
+ ValueRange inputs, ValueRange outputs) {
+ // Generate a new linalg operation that works on buffers.
+ auto newGenericOp = rewriter.create<linalg::GenericOp>(
+ genericOp.getLoc(),
+ /*resultTensorTypes=*/llvm::None,
+ /*inputs=*/inputs,
+ /*outputBuffers=*/outputs,
+ /*initTensors=*/llvm::None, genericOp.indexing_maps(),
+ genericOp.iterator_types(), genericOp.docAttr(),
+ genericOp.library_callAttr(), genericOp.symbol_sourceAttr());
+
+ // Create a new block in the region of the new Generic Op.
+ Block *oldBlock = genericOp.getBody();
+ Region &newRegion = newGenericOp.region();
+ Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
+ oldBlock->getArgumentTypes());
+
+ // Add the result arguments to the new block.
+ for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
+ newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
+
+ // Clone the body of the old block to the new block.
+ BlockAndValueMapping mapping;
+ mapping.map(oldBlock->getArguments(), newBlock->getArguments());
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToEnd(newBlock);
+ for (auto &op : oldBlock->getOperations()) {
+ Operation *clonedOp = rewriter.clone(op, mapping);
+ mapping.map(op.getResults(), clonedOp->getResults());
+ }
- Location loc = linalgOp.getLoc();
- SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
- adaptor.output_buffers().end());
+ // Replace the results of the old op with the new output buffers.
+ rewriter.replaceOp(genericOp, outputs);
+}
- if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
- &newOutputBuffers, &rewriter))) {
- linalgOp.emitOpError()
- << "Failed to allocate buffers for tensor results.";
- return failure();
- }
+// TODO: Specialization for `linalg::IndexedGenericOp`.
+
+// Specialization for all other `linalg::LinalgOp`.
+static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
+ linalg::LinalgOp linalgOp,
+ ValueRange inputs, ValueRange outputs) {
+ assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
+ assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
+ SmallVector<Value, 8> newOperands = inputs;
+ newOperands.append(outputs.begin(), outputs.end());
+ auto otherOperands = linalgOp.getAssumedNonShapedOperands();
+ newOperands.append(otherOperands.begin(), otherOperands.end());
+ LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
+ /*resultTypes=*/ArrayRef<Type>{},
+ newOperands));
+ // Need to mutate the `operand_segment_sizes` attribute in the resulting op.
+ res.setNumOutputBuffers(outputs.size());
+ res.setNumInitTensors(0);
+ // Replace the results of the old op with the new output buffers.
+ rewriter.replaceOp(linalgOp, outputs);
+}
- // Generate a new linalg operation that works on buffers.
- auto newLinalgOp = rewriter.create<linalg::GenericOp>(
- loc,
- /*resultTensorTypes=*/llvm::None,
- /*inputs=*/adaptor.inputs(),
- /*outputBuffers=*/newOutputBuffers,
- /*initTensors=*/llvm::None, linalgOp.indexing_maps(),
- linalgOp.iterator_types(), linalgOp.docAttr(),
- linalgOp.library_callAttr(), linalgOp.symbol_sourceAttr());
-
- // Create a new block in the region of the new Generic Op.
- Block *oldBlock = linalgOp.getBody();
- Region &newRegion = newLinalgOp.region();
- Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
- oldBlock->getArgumentTypes());
-
- // Add the result arguments that do not come from init_tensors to the new
- // block.
- // TODO: update this assumption because the reality is more complex under
- // linalg on tensor based transformations.
- for (Value v :
- ValueRange(newOutputBuffers).drop_front(adaptor.init_tensors().size()))
- newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
-
- // Clone the body of the old block to the new block.
- BlockAndValueMapping mapping;
- mapping.map(oldBlock->getArguments(), newBlock->getArguments());
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToEnd(newBlock);
- for (auto &op : oldBlock->getOperations()) {
- Operation *clonedOp = rewriter.clone(op, mapping);
- mapping.map(op.getResults(), clonedOp->getResults());
- }
+LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite(
+ Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
+ if (!linalgOp)
+ return failure();
+
+ // We abuse the GenericOpAdaptor here.
+ // TODO: Manually create an Adaptor that captures inputs, output_buffers and
+ // init_tensors for all linalg::LinalgOp interface ops.
+ linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
+
+ // All inputs need to be turned into buffers first. Until then, bail out.
+ if (llvm::any_of(adaptor.inputs(),
+ [](Value in) { return !in.getType().isa<MemRefType>(); }))
+ return failure();
+
+ // All init_tensors need to be turned into buffers first. Until then, bail
+ // out.
+ if (llvm::any_of(adaptor.init_tensors(),
+ [](Value in) { return !in.getType().isa<MemRefType>(); }))
+ return failure();
+
+ Location loc = linalgOp.getLoc();
+ SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
+ adaptor.output_buffers().end());
+
+ if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, newOutputBuffers,
+ rewriter))) {
+ linalgOp.emitOpError() << "Failed to allocate buffers for tensor results.";
+ return failure();
+ }
- // Replace the results of the old op with the new output buffers.
- rewriter.replaceOp(linalgOp, newOutputBuffers);
+ // Delegate to the linalg generic pattern.
+ if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+ finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
+ newOutputBuffers);
return success();
}
-};
-// Rewrite a tensor `constant` to a vector constant folloed by a vector store
-// and a vector.type_cast.
-class TensorConstantOpConverter
- : public BufferAssignmentOpConversionPattern<ConstantOp> {
-public:
- using BufferAssignmentOpConversionPattern<
- ConstantOp>::BufferAssignmentOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- if (!op.getType().isa<RankedTensorType>())
- return failure();
- auto attr = op.getValue().cast<DenseElementsAttr>();
-
- Location loc = op.getLoc();
- MemRefType memrefType =
- converter->convertType(op.getType()).cast<MemRefType>();
- VectorType vectorType =
- VectorType::get(memrefType.getShape(), memrefType.getElementType());
-
- // vector constant takes attributes that are compatible with tensor
- // constant.
- Value cstVec =
- rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
-
- // Alloc a memref<vector<...>>, store the constant and typecast the vector
- // away.
- MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
- Value alloc =
- rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
- rewriter.create<StoreOp>(loc, cstVec, alloc);
- rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
+ finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
+ newOutputBuffers);
+ return success();
+}
- return success();
- }
-};
+LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
+ ConstantOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (!op.getType().isa<RankedTensorType>())
+ return failure();
+ auto attr = op.getValue().cast<DenseElementsAttr>();
+
+ Location loc = op.getLoc();
+ MemRefType memrefType =
+ converter.convertType(op.getType()).cast<MemRefType>();
+ VectorType vectorType =
+ VectorType::get(memrefType.getShape(), memrefType.getElementType());
+ Value cstVec =
+ rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
+
+ MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
+ Value alloc = rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
+ rewriter.create<StoreOp>(loc, cstVec, alloc);
+ rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
-// Rewrite a `tensor_cast` as a `memref_cast` with no layout, in the 0-memory
-// space.
-class TensorCastOpConverter
- : public BufferAssignmentOpConversionPattern<TensorCastOp> {
-public:
- using BufferAssignmentOpConversionPattern<
- TensorCastOp>::BufferAssignmentOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final {
- if (op.getType().hasRank())
- return failure();
- Type t = UnrankedMemRefType::get(op.getType().getElementType(),
- /*memorySpace=*/0);
- rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
- return success();
- }
-};
+ return success();
+}
+
+LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite(
+ TensorCastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ if (op.getType().hasRank())
+ return failure();
+ Type t = UnrankedMemRefType::get(op.getType().getElementType(),
+ /*memorySpace=*/0);
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
+ return success();
+}
+
+namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
@@ -326,11 +332,11 @@ struct ConvertLinalgOnTensorsToBuffers
BufferAssignmentTypeConverter::AppendToArgumentsList);
OwningRewritePatternList patterns;
- populateConvertLinalgOnTensorsToBuffersPatterns(&context, &converter,
- &patterns);
+ populateConvertLinalgOnTensorsToBuffersPatterns(&context, converter,
+ patterns);
populateWithBufferAssignmentOpConversionPatterns<
- mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(&context, &converter,
- &patterns);
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(&context, converter,
+ patterns);
if (failed(applyFullConversion(this->getOperation(), target, patterns)))
this->signalPassFailure();
}
@@ -341,13 +347,13 @@ std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertLinalgOnTensorsToBuffersPass() {
return std::make_unique<ConvertLinalgOnTensorsToBuffers>();
}
-
void mlir::linalg::populateConvertLinalgOnTensorsToBuffersPatterns(
- MLIRContext *context, BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns) {
- patterns->insert<
+
+ MLIRContext *context, BufferAssignmentTypeConverter &converter,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<
// clang-format off
- GenericOpConverter,
+ LinalgOpConverter,
TensorCastOpConverter,
TensorConstantOpConverter
// clang-format on
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
index e281cea24e73..83e432855f6e 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
@@ -38,7 +38,7 @@ class TypeConversionAssumingOpConverter
newResultTypes.reserve(assumingOp.getNumResults());
for (auto result : assumingOp.getResults()) {
auto originalType = result.getType();
- Type convertedType = converter->convertType(originalType);
+ Type convertedType = converter.convertType(originalType);
newResultTypes.push_back(convertedType);
}
@@ -60,7 +60,7 @@ struct ShapeTensorToMemrefPass
OwningRewritePatternList patterns;
BufferAssignmentTypeConverter converter;
- populateShapeTypeConversionPatterns(&ctx, &converter, &patterns);
+ populateShapeTypeConversionPatterns(&ctx, converter, patterns);
ConversionTarget target(getContext());
auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
@@ -81,9 +81,9 @@ struct ShapeTensorToMemrefPass
//
// TODO: Change this to work generally with any type conversions.
void mlir::populateShapeTypeConversionPatterns(
- MLIRContext *context, BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns) {
- patterns->insert<TypeConversionAssumingOpConverter>(context, converter);
+ MLIRContext *context, BufferAssignmentTypeConverter &converter,
+ OwningRewritePatternList &patterns) {
+ patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 380c72087bbd..3a0d6ebd962b 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -875,8 +875,8 @@ LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
for (auto argType : llvm::enumerate(funcType.getInputs())) {
SmallVector<Type, 2> decomposedTypes, convertedTypes;
- converter->tryDecomposeType(argType.value(), decomposedTypes);
- converter->convertTypes(decomposedTypes, convertedTypes);
+ converter.tryDecomposeType(argType.value(), decomposedTypes);
+ converter.convertTypes(decomposedTypes, convertedTypes);
conversion.addInputs(argType.index(), convertedTypes);
}
@@ -885,10 +885,10 @@ LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
newResultTypes.reserve(funcOp.getNumResults());
for (Type resultType : funcType.getResults()) {
SmallVector<Type, 2> originTypes;
- converter->tryDecomposeType(resultType, originTypes);
+ converter.tryDecomposeType(resultType, originTypes);
for (auto origin : originTypes) {
- Type converted = converter->convertType(origin);
- auto kind = converter->getResultConversionKind(origin, converted);
+ Type converted = converter.convertType(origin);
+ auto kind = converter.getResultConversionKind(origin, converted);
if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
conversion.addInputs(converted);
else
@@ -897,7 +897,7 @@ LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
}
}
- if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
+ if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
&conversion)))
return failure();
@@ -986,8 +986,8 @@ LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
// values if a decompose callback function has been provided by the user.
for (auto operand : operands) {
SmallVector<Value, 2> values;
- this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand,
- values);
+ this->converter.tryDecomposeValue(builder, loc, operand.getType(), operand,
+ values);
newOperands.append(values.begin(), values.end());
}
@@ -998,11 +998,11 @@ LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
mappings.resize(callOp.getNumResults());
for (auto result : llvm::enumerate(callOp.getResults())) {
SmallVector<Type, 2> originTypes;
- converter->tryDecomposeType(result.value().getType(), originTypes);
+ converter.tryDecomposeType(result.value().getType(), originTypes);
auto &resultMapping = mappings[result.index()];
for (Type origin : originTypes) {
- Type converted = converter->convertType(origin);
- auto kind = converter->getResultConversionKind(origin, converted);
+ Type converted = converter.convertType(origin);
+ auto kind = converter.getResultConversionKind(origin, converted);
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
newResultTypes.push_back(converted);
// The result value is not yet available. Its index is kept and it is
@@ -1039,7 +1039,7 @@ LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
} else {
// Values need to be packed using callback function. The same callback
// that is used for materializeArgumentConversion is used for packing.
- Value packed = converter->materializeArgumentConversion(
+ Value packed = converter.materializeArgumentConversion(
nextBuilder, loc, callOp.getType(i), valuesToPack);
replacedValues.push_back(packed);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 5f2b9e32dac7..32d618d9008e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -77,7 +77,7 @@ struct FoldingPattern : public RewritePattern {
struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
void runOnFunction() override {
mlir::OwningRewritePatternList patterns;
- populateWithGenerated(&getContext(), &patterns);
+ populateWithGenerated(&getContext(), patterns);
// Verify named pattern is generated with expected name.
patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
@@ -547,7 +547,7 @@ struct TestLegalizePatternDriver
void runOnOperation() override {
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
- populateWithGenerated(&getContext(), &patterns);
+ populateWithGenerated(&getContext(), patterns);
patterns.insert<
TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index aecf99f69729..ead4c5d6fb38 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -147,12 +147,12 @@ struct TestBufferPlacementPreparationPass
};
void populateTensorLinalgToBufferLinalgConversionPattern(
- MLIRContext *context, BufferAssignmentTypeConverter *converter,
- OwningRewritePatternList *patterns) {
+ MLIRContext *context, BufferAssignmentTypeConverter &converter,
+ OwningRewritePatternList &patterns) {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter,
patterns);
- patterns->insert<GenericOpConverter>(context, converter);
+ patterns.insert<GenericOpConverter>(context, converter);
}
void getDependentDialects(DialectRegistry &registry) const override {
@@ -230,8 +230,8 @@ struct TestBufferPlacementPreparationPass
});
OwningRewritePatternList patterns;
- populateTensorLinalgToBufferLinalgConversionPattern(&context, &converter,
- &patterns);
+ populateTensorLinalgToBufferLinalgConversionPattern(&context, converter,
+ patterns);
if (failed(applyFullConversion(this->getOperation(), target, patterns)))
this->signalPassFailure();
};
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 495fbe1715e0..ff6138f73914 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1147,9 +1147,9 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
// Emit function to add the generated matchers to the pattern list.
os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext "
- "*context, ::mlir::OwningRewritePatternList *patterns) {\n";
+ "*context, ::mlir::OwningRewritePatternList &patterns) {\n";
for (const auto &name : rewriterNames) {
- os << " patterns->insert<" << name << ">(context);\n";
+ os << " patterns.insert<" << name << ">(context);\n";
}
os << "}\n";
}
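With this change the emitted registration function takes the following shape
(the pattern name GeneratedConvert is hypothetical):

    void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(
        ::mlir::MLIRContext *context,
        ::mlir::OwningRewritePatternList &patterns) {
      patterns.insert<GeneratedConvert>(context);
    }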