[Mlir-commits] [mlir] edaffeb - Cloned from CL 389610703 by 'g4 patch'.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Aug 9 12:57:19 PDT 2021
Author: Alexander Belyaev
Date: 2021-08-09T21:57:06+02:00
New Revision: edaffebcb2a62b0195e23fe7d4ead005822865c3
URL: https://github.com/llvm/llvm-project/commit/edaffebcb2a62b0195e23fe7d4ead005822865c3
DIFF: https://github.com/llvm/llvm-project/commit/edaffebcb2a62b0195e23fe7d4ead005822865c3.diff
LOG: Cloned from CL 389610703 by 'g4 patch'.
Original change by pifon at pifon:tfrt_clean:6896:citc on 2021/08/09 05:30:17.
Ad b
Differential Revision: https://reviews.llvm.org/D107762
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/test/Dialect/Linalg/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index cd0445f1ff87d..b46e58be8349a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -213,8 +213,10 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
Location loc = op.getLoc();
SmallVector<Value, 2> newOutputBuffers;
- if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
- newOutputBuffers, rewriter))) {
+ if (op->getParentOfType<TiledLoopOp>()) {
+ newOutputBuffers = adaptor.outputs();
+ } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
+ newOutputBuffers, rewriter))) {
return op.emitOpError()
<< "Failed to allocate buffers for tensor results.";
}
@@ -231,6 +233,14 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
}
};
+/// Returns true if `tensor` is produced by a `memref.tensor_load` whose operand
+/// is a block argument of a block owned directly by a `linalg.tiled_loop`.
+/// The extract_slice / insert_slice converters use this to bufferize such
+/// values in place instead of allocating and copying.
+bool IsBlockArgOfTiledLoop(Value tensor) {
+  auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>();
+  if (!tensorLoad)
+    return false;
+  auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>();
+  if (!blockArgument)
+    return false;
+  // An unlinked block has no parent op (getParentOp() returns null) and plain
+  // `isa` asserts on null, so use the null-tolerant variant.
+  return llvm::isa_and_nonnull<TiledLoopOp>(
+      blockArgument.getOwner()->getParentOp());
+}
+
/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
@@ -253,6 +263,15 @@ class ExtractSliceOpConverter
Value sourceMemref = adaptor.source();
assert(sourceMemref.getType().isa<MemRefType>());
+ // Block arguments of the tiled_loop can be bufferized inplace.
+ if (IsBlockArgOfTiledLoop(op.source())) {
+ Value subView = rewriter.create<memref::SubViewOp>(
+ op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
+ op.getMixedStrides());
+ rewriter.replaceOp(op, subView);
+ return success();
+ }
+
MemRefType subviewMemRefType =
getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
// op.sizes() capture exactly the dynamic alloc operands matching the
@@ -296,7 +315,12 @@ class InsertSliceOpConverter
// For now, be conservative and copy the converted input memref.
// In general, the converted input memref here could be aliased or could
// point into constant memory, so mutating it would lead to miscompilations.
- Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+ // Block arguments of the tiled_loop can be bufferized inplace.
+ Value destMemRef;
+ if (IsBlockArgOfTiledLoop(op.dest()))
+ destMemRef = adaptor.dest();
+ else
+ destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
assert(destMemRef.getType().isa<MemRefType>());
// Take a subview to copy the small memref.
@@ -310,6 +334,64 @@ class InsertSliceOpConverter
}
};
+/// Bufferizes a `linalg.tiled_loop` that has tensor results.
+///
+/// A new tiled_loop is created on the converted operands and the old body is
+/// cloned into it. Region arguments that were tensors are remapped to a
+/// `memref.tensor_load` of the corresponding new (memref) region argument, so
+/// the cloned ops keep their tensor semantics and are bufferized by the other
+/// patterns afterwards. The results of the old loop are replaced with the
+/// buffers that were passed for its ranked-tensor outputs.
+class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
+public:
+  using OpConversionPattern<TiledLoopOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
+    Location loc = tiledLoop.getLoc();
+    // Only loops with tensor results are handled by this pattern.
+    if (tiledLoop.getNumResults() == 0)
+      return failure();
+    auto newTiledLoop = rewriter.create<TiledLoopOp>(
+        loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
+        adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
+        adaptor.distribution_types());
+
+    // Clone the region.
+    BlockAndValueMapping bvm;
+    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+
+    OpBuilder innerBuilder =
+        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+
+    // Maps every old region argument to the value the cloned body should use.
+    // The decision is made on the OLD argument type: only arguments that were
+    // tensors get a tensor_load. Arguments that already were memrefs (a
+    // tiled_loop may mix tensor and memref operands) or are not tensors at all
+    // (e.g. vectors, scalars) are forwarded unchanged. Wrapping those would
+    // hand a tensor to ops that expect a memref, or build an invalid
+    // tensor_load.
+    auto remapRegionArgs = [&](auto newArgs, auto oldArgs) {
+      SmallVector<Value, 2> remapped;
+      for (auto en : llvm::zip(newArgs, oldArgs)) {
+        Value newArg = std::get<0>(en);
+        Value oldArg = std::get<1>(en);
+        if (oldArg.getType().isa<TensorType>())
+          remapped.push_back(
+              innerBuilder.create<memref::TensorLoadOp>(loc, newArg));
+        else
+          remapped.push_back(newArg);
+      }
+      bvm.map(oldArgs, remapped);
+    };
+    remapRegionArgs(newTiledLoop.getRegionInputArgs(),
+                    tiledLoop.getRegionInputArgs());
+    remapRegionArgs(newTiledLoop.getRegionOutputArgs(),
+                    tiledLoop.getRegionOutputArgs());
+
+    for (auto &op : tiledLoop.getBody()->without_terminator())
+      innerBuilder.clone(op, bvm);
+    innerBuilder.create<linalg::YieldOp>(loc);
+
+    // A tiled_loop has one result per ranked-tensor output, in order, so only
+    // the buffers of those outputs replace its results. Passing every output
+    // would give `replaceOp` the wrong number of values as soon as any output
+    // is a memref.
+    SmallVector<Value, 2> resultBuffers;
+    for (auto en : llvm::zip(tiledLoop.outputs(), adaptor.outputs()))
+      if (std::get<0>(en).getType().isa<RankedTensorType>())
+        resultBuffers.push_back(std::get<1>(en));
+    rewriter.replaceOp(tiledLoop, resultBuffers);
+    return success();
+  }
+};
+
class VectorTransferReadOpConverter
: public OpConversionPattern<vector::TransferReadOp> {
public:
@@ -352,14 +434,66 @@ class VectorTransferWriteOpConverter
};
} // namespace
+/// Materializes a value of tensor type `type` from the single buffer in
+/// `inputs` by emitting a `memref.tensor_load` at `loc`.
+static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
+                                   ValueRange inputs, Location loc) {
+  assert(inputs.size() == 1);
+  Value buffer = inputs.front();
+  assert(buffer.getType().isa<BaseMemRefType>());
+  auto tensorLoad = builder.create<memref::TensorLoadOp>(loc, type, buffer);
+  return tensorLoad;
+}
+
namespace {
+
+/// A helper type converter class that automatically populates the relevant
+/// materializations and type conversions for bufferization.
+//
+// The default BufferizeTypeConverter defined in "Transforms/Bufferize.h" does
+// not properly support memrefs with non-default layout. Whenever the layout of
+// a memref changes during bufferization, the target materialization callback
+// would assert that the non-matching type is a tensor.
+// There was an attempt to fix this behavior of dialect conversion in a more
+// principled way in https://reviews.llvm.org/D93126, but it had to be reverted
+// due to test failures outside of MLIR Core. It might make sense to revive
+// that revision.
+//
+// NOTE(review): this class derives from BufferizeTypeConverter yet registers
+// its own full set of type conversions and materializations. If the base
+// constructor already registers equivalent ones, everything except the target
+// materialization may be redundant. Confirm before simplifying, since the
+// order in which callbacks are registered affects which one is consulted.
+class CustomBufferizeTypeConverter : public BufferizeTypeConverter {
+public:
+ CustomBufferizeTypeConverter() {
+ // Keep all types unchanged.
+ addConversion([](Type type) { return type; });
+ // Convert RankedTensorType to MemRefType.
+ addConversion([](RankedTensorType type) -> Type {
+ return MemRefType::get(type.getShape(), type.getElementType());
+ });
+ // Convert UnrankedTensorType to UnrankedMemRefType.
+ addConversion([](UnrankedTensorType type) -> Type {
+ return UnrankedMemRefType::get(type.getElementType(), 0);
+ });
+ // Tensors are rebuilt from buffers with memref.tensor_load, both for block
+ // arguments and for source materializations.
+ addArgumentMaterialization(materializeTensorLoad);
+ addSourceMaterialization(materializeTensorLoad);
+ // Buffers are produced from tensors with memref.buffer_cast. A value that
+ // is already a (ranked) memref is returned unchanged instead.
+ addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
+ ValueRange inputs, Location loc) -> Value {
+ assert(inputs.size() == 1);
+ // Target materialization is invoked if the new operand type does not
+ // match the expected type. A special case is when the new operand type is
+ // a memref with a specified layout, i.e. a non-empty affine map.
+ // TODO(pifon): Change how target materialization is invoked in dialect
+ // conversion.
+ // NOTE(review): the value returned from this branch keeps its own type,
+ // which can differ from `type`. The assert describing the expected
+ // difference (a non-empty layout map) is compiled out in release builds.
+ // An unranked memref input is not a MemRefType, so it would not take this
+ // branch and would hit the TensorType assert below.
+ if (auto memrefType = inputs[0].getType().dyn_cast<MemRefType>()) {
+ assert(!memrefType.getAffineMaps().empty());
+ return inputs[0];
+ }
+ assert(inputs[0].getType().isa<TensorType>());
+ return builder.create<memref::BufferCastOp>(loc, type, inputs[0]);
+ });
+ }
+};
+
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
void runOnOperation() override {
MLIRContext &context = getContext();
ConversionTarget target(context);
- BufferizeTypeConverter typeConverter;
+ CustomBufferizeTypeConverter typeConverter;
// Mark all Standard operations legal.
target.addLegalDialect<AffineDialect, math::MathDialect,
@@ -401,6 +535,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
ExtractSliceOpConverter,
InsertSliceOpConverter,
+ TiledLoopOpConverter,
VectorTransferReadOpConverter,
VectorTransferWriteOpConverter
>(typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 93dbf4a563675..dbb2bb713e773 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -316,3 +316,36 @@ func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) {
// CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32>
}
+
+// -----
+
+// Bufferization of linalg.tiled_loop with tensor operands: the loop operands
+// become memrefs, tensor.extract_slice of the loop's region arguments becomes
+// memref.subview (no new buffer), and linalg.dot writes directly into the
+// output region argument.
+// CHECK: func @tiled_dot
+func @tiled_dot(%A: tensor<10xf32>, %B: tensor<10xf32>,
+ %C: tensor<f32>) -> tensor<f32> {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c10 = constant 10 : index
+
+ %dot = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
+ ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
+ outs (%C_ = %C: tensor<f32>)
+ iterators["reduction"] {
+ %A_sub = tensor.extract_slice %A_[%i] [%c2] [1]
+ : tensor<10xf32> to tensor<?xf32>
+ %B_sub = tensor.extract_slice %B_[%i] [%c2] [1]
+ : tensor<10xf32> to tensor<?xf32>
+ %dot_sub = linalg.dot ins(%A_sub, %B_sub : tensor<?xf32>, tensor<?xf32>)
+ outs(%C_ : tensor<f32>) -> tensor<f32>
+ linalg.yield %dot_sub : tensor<f32>
+ }
+ // Expected output: memref region arguments, subviews of the inputs, and the
+ // dot accumulating into the output region argument.
+ // CHECK: linalg.tiled_loop
+ // CHECK-SAME: ins (%[[A:.*]] = %{{.*}}: memref<10xf32>, %[[B:.*]] = %{{.*}}: memref<10xf32>)
+ // CHECK-SAME: outs (%[[C:.*]] = %{{.*}}: memref<f32>)
+ // CHECK-NOT: alloc
+ // CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
+ // CHECK: %[[SV_B:.*]] = memref.subview %[[B]]
+ // CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]]
+ // CHECK-SAME: outs(%[[C]] : memref<f32>)
+ // CHECK: linalg.yield
+ return %dot : tensor<f32>
+}
More information about the Mlir-commits
mailing list