[Mlir-commits] [mlir] 2f946ea - [mlir] Change the pattern for TiledLoopOp bufferization.
Alexander Belyaev
llvmlistbot at llvm.org
Tue Aug 10 12:27:24 PDT 2021
Author: Alexander Belyaev
Date: 2021-08-10T21:27:02+02:00
New Revision: 2f946eaa9d2648b883b2a1e567b23fff307f13d9
URL: https://github.com/llvm/llvm-project/commit/2f946eaa9d2648b883b2a1e567b23fff307f13d9
DIFF: https://github.com/llvm/llvm-project/commit/2f946eaa9d2648b883b2a1e567b23fff307f13d9.diff
LOG: [mlir] Change the pattern for TiledLoopOp bufferization.
This version does not affect the patterns for Extract/InsertSliceOp and
LinalgOps.
Differential Revision: https://reviews.llvm.org/D107858
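
A condensed sketch of the new lowering, following the tiled_add test added
below (the generic body is elided here). On tensors:

    %sum = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
        ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
        outs (%C_ = %C: tensor<10xf32>) {
      %A_sub = tensor.extract_slice %A_[%i] [%c2] [1]
          : tensor<10xf32> to tensor<?xf32>
      // ... elementwise linalg.generic on the slices ...
      %update = tensor.insert_slice %sum_sub into %C_[%i] [%c2] [1]
          : tensor<?xf32> into tensor<10xf32>
      linalg.yield %update : tensor<10xf32>
    }

After bufferization, the loop arguments become memrefs: an extract_slice of
a loop block argument turns into memref.subview, the insert_slice feeding
linalg.yield turns into a linalg.copy into a subview of the output buffer,
and the loop yields nothing.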
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/test/Dialect/Linalg/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index b46e58be8349a..04865594c1ad4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -213,10 +213,8 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
Location loc = op.getLoc();
SmallVector<Value, 2> newOutputBuffers;
- if (op->getParentOfType<TiledLoopOp>()) {
- newOutputBuffers = adaptor.outputs();
- } else if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
- newOutputBuffers, rewriter))) {
+ if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
+ newOutputBuffers, rewriter))) {
return op.emitOpError()
<< "Failed to allocate buffers for tensor results.";
}
@@ -233,14 +231,6 @@ class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
}
};
-bool IsBlockArgOfTiledLoop(Value tensor) {
- if (auto tensorLoad = tensor.getDefiningOp<memref::TensorLoadOp>())
- if (auto blockArgument = tensorLoad.memref().dyn_cast<BlockArgument>())
- if (isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp()))
- return true;
- return false;
-}
-
/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
@@ -263,15 +253,6 @@ class ExtractSliceOpConverter
Value sourceMemref = adaptor.source();
assert(sourceMemref.getType().isa<MemRefType>());
- // Block arguments of the tiled_loop can be bufferized inplace.
- if (IsBlockArgOfTiledLoop(op.source())) {
- Value subView = rewriter.create<memref::SubViewOp>(
- op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
- op.getMixedStrides());
- rewriter.replaceOp(op, subView);
- return success();
- }
-
MemRefType subviewMemRefType =
getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
// op.sizes() capture exactly the dynamic alloc operands matching the
@@ -315,12 +296,7 @@ class InsertSliceOpConverter
// For now, be conservative and copy the converted input memref.
// In general, the converted input memref here could be aliased or could
// point into constant memory, so mutating it would lead to miscompilations.
- // Block arguments of the tiled_loop can be bufferized inplace.
- Value destMemRef;
- if (IsBlockArgOfTiledLoop(op.dest()))
- destMemRef = adaptor.dest();
- else
- destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+ Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
assert(destMemRef.getType().isa<MemRefType>());
// Take a subview to copy the small memref.
@@ -334,60 +310,115 @@ class InsertSliceOpConverter
}
};
+bool isBlockArgOfTiledLoop(Value tensor) {
+ if (auto blockArgument = tensor.dyn_cast<BlockArgument>())
+ return isa<TiledLoopOp>(blockArgument.getOwner()->getParentOp());
+ return false;
+}
+
+SmallVector<Value, 3> convertOperands(ValueRange operands,
+ BlockAndValueMapping &bvm) {
+ SmallVector<Value, 3> newOperands;
+ newOperands.reserve(operands.size());
+ for (auto operand : operands)
+ newOperands.push_back(bvm.lookupOrDefault(operand));
+ return newOperands;
+}
+
class TiledLoopOpConverter : public OpConversionPattern<TiledLoopOp> {
public:
using OpConversionPattern<TiledLoopOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(TiledLoopOp tiledLoop, ArrayRef<Value> operands,
+ matchAndRewrite(TiledLoopOp loop, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- TiledLoopOp::Adaptor adaptor(operands, tiledLoop->getAttrDictionary());
- Location loc = tiledLoop.getLoc();
- if (tiledLoop.getNumResults() == 0)
+ TiledLoopOp::Adaptor adaptor(operands, loop->getAttrDictionary());
+ if (loop.getNumResults() == 0)
return failure();
- auto newTiledLoop = rewriter.create<TiledLoopOp>(
+
+ Location loc = loop.getLoc();
+ auto newLoop = rewriter.create<TiledLoopOp>(
loc, adaptor.lowerBound(), adaptor.upperBound(), adaptor.step(),
adaptor.inputs(), adaptor.outputs(), adaptor.iterator_types(),
adaptor.distribution_types());
+
// Clone the region.
BlockAndValueMapping bvm;
- bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+ bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
+ bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs());
+ bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs());
OpBuilder innerBuilder =
- OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
-
- // Remap input block arguments.
- SmallVector<Value, 2> inputs;
- for (auto en : llvm::zip(newTiledLoop.getRegionInputArgs(),
- tiledLoop.getRegionInputArgs())) {
- auto &newInputArg = std::get<0>(en);
- if (!newInputArg.getType().isa<ShapedType>()) {
- inputs.push_back(std::get<0>(en));
- continue;
+ OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
+
+ for (auto &op : loop.getBody()->getOperations()) {
+ Location loc = op.getLoc();
+ if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
+ if (isBlockArgOfTiledLoop(extractSlice.source())) {
+ auto newOperands = convertOperands(extractSlice.getOperands(), bvm);
+ auto srcMemRefType =
+ bvm.lookup(extractSlice.source()).getType().cast<MemRefType>();
+ auto dstMemRefType =
+ memref::SubViewOp::inferResultType(
+ srcMemRefType,
+ extractFromI64ArrayAttr(extractSlice.static_offsets()),
+ extractFromI64ArrayAttr(extractSlice.static_sizes()),
+ extractFromI64ArrayAttr(extractSlice.static_strides()))
+ .cast<MemRefType>();
+
+ Value subView = innerBuilder.create<memref::SubViewOp>(
+ loc, TypeRange{dstMemRefType}, newOperands,
+ extractSlice->getAttrs());
+ bvm.map(extractSlice.getResult(), subView);
+ continue;
+ }
}
- inputs.push_back(
- innerBuilder.create<memref::TensorLoadOp>(loc, newInputArg));
- }
- bvm.map(tiledLoop.getRegionInputArgs(), inputs);
-
- // Remap output block arguments.
- SmallVector<Value, 2> outputs;
- for (auto en : llvm::zip(newTiledLoop.getRegionOutputArgs(),
- tiledLoop.getRegionOutputArgs())) {
- auto &newOutputArg = std::get<0>(en);
- if (!newOutputArg.getType().isa<ShapedType>()) {
- outputs.push_back(std::get<0>(en));
+ if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
+ if (isBlockArgOfTiledLoop(insertSlice.dest())) {
+ continue;
+ }
+ }
+ if (auto yield = dyn_cast<linalg::YieldOp>(op)) {
+ for (OpOperand &operand : yield->getOpOperands()) {
+ if (auto insert =
+ operand.get().getDefiningOp<tensor::InsertSliceOp>()) {
+
+ auto dstMemRefType = memref::SubViewOp::inferResultType(
+ getTypeConverter()
+ ->convertType(insert.source().getType())
+ .cast<MemRefType>(),
+ extractFromI64ArrayAttr(insert.static_offsets()),
+ extractFromI64ArrayAttr(insert.static_sizes()),
+ extractFromI64ArrayAttr(insert.static_strides()));
+
+ Value subView = innerBuilder.create<memref::SubViewOp>(
+ loc, dstMemRefType, bvm.lookup(insert.dest()),
+ convertOperands(insert.offsets(), bvm),
+ convertOperands(insert.sizes(), bvm),
+ convertOperands(insert.strides(), bvm), insert.static_offsets(),
+ insert.static_sizes(), insert.static_strides());
+
+ Value cast = innerBuilder.create<memref::BufferCastOp>(
+ loc,
+ getTypeConverter()
+ ->convertType(insert.source().getType())
+ .cast<MemRefType>(),
+ bvm.lookup(insert.source()));
+
+ innerBuilder.create<linalg::CopyOp>(loc, cast, subView);
+ continue;
+ }
+ auto dst = newLoop.getRegionOutputArgs()[operand.getOperandNumber()];
+ Value cast = innerBuilder.create<memref::BufferCastOp>(
+ loc, dst.getType(), bvm.lookup(operand.get()));
+ innerBuilder.create<linalg::CopyOp>(loc, cast, dst);
+ }
continue;
}
- outputs.push_back(
- innerBuilder.create<memref::TensorLoadOp>(loc, newOutputArg));
- }
- bvm.map(tiledLoop.getRegionOutputArgs(), outputs);
-
- for (auto &op : tiledLoop.getBody()->without_terminator())
innerBuilder.clone(op, bvm);
+ }
innerBuilder.create<linalg::YieldOp>(loc);
- rewriter.replaceOp(tiledLoop, newTiledLoop.outputs());
+ rewriter.replaceOp(loop, newLoop.outputs());
return success();
}
};
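
A condensed sketch of the buffer-side body the converter produces for the
yielded insert_slice case (types and layout maps abbreviated; this mirrors
the CHECK lines in the test update below):

    // extract_slice of a loop block argument is rebuilt as a subview:
    %sv_c = memref.subview %C[%i] [%c2] [1]
        : memref<10xf32> to memref<?xf32, #map>
    // At linalg.yield, the inserted source is buffer_cast'ed and copied
    // into a subview of the output argument; the yield itself is emptied:
    %cast = memref.buffer_cast %sum_sub : memref<?xf32>
    linalg.copy(%cast, %sv_c) : memref<?xf32>, memref<?xf32, #map>
    linalg.yield

insert_slice ops whose destination is a loop output argument are otherwise
dropped, since the copy emitted at the yield already writes the result in
place.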
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index dbb2bb713e773..29f23c10e095c 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -339,13 +339,66 @@ func @tiled_dot(%A: tensor<10xf32>, %B: tensor<10xf32>,
linalg.yield %dot_sub : tensor<f32>
}
// CHECK: linalg.tiled_loop
- // CHECK-SAME: ins (%[[A:.*]] = %{{.*}}: memref<10xf32>, %[[B:.*]] = %{{.*}}: memref<10xf32>)
- // CHECK-SAME: outs (%[[C:.*]] = %{{.*}}: memref<f32>)
- // CHECK-NOT: alloc
- // CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
- // CHECK: %[[SV_B:.*]] = memref.subview %[[B]]
- // CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]]
- // CHECK-SAME: outs(%[[C]] : memref<f32>)
- // CHECK: linalg.yield
+ // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
+ // CHECK-SAME: %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
+ // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<f32>)
+
+ // CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]]
+ // CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]]
+ // CHECK-NEXT: %[[TMP:.*]] = memref.alloc
+ // CHECK-NEXT: linalg.copy(%[[C]], %[[TMP]])
+ // CHECK-NEXT: linalg.dot ins(%[[SV_A]], %[[SV_B]]
+ // CHECK-SAME: outs(%[[TMP]] : memref<f32>)
+ // CHECK-NEXT: linalg.copy(%[[TMP]], %[[C]])
+ // CHECK-NEXT: linalg.yield
return %dot : tensor<f32>
}
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0)>
+
+func @tiled_add(%A: tensor<10xf32>, %B: tensor<10xf32>,
+ %C: tensor<10xf32>) -> tensor<10xf32> {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c10 = constant 10 : index
+
+ %sum = linalg.tiled_loop (%i) = (%c0) to (%c10) step (%c2)
+ ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>)
+ outs (%C_ = %C: tensor<10xf32>) {
+ %A_sub = tensor.extract_slice %A_[%i] [%c2] [1]
+ : tensor<10xf32> to tensor<?xf32>
+ %B_sub = tensor.extract_slice %B_[%i] [%c2] [1]
+ : tensor<10xf32> to tensor<?xf32>
+ %C_sub = tensor.extract_slice %C_[%i] [%c2] [1]
+ : tensor<10xf32> to tensor<?xf32>
+ %sum_sub = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel"]
+ } ins(%A_sub, %B_sub : tensor<?xf32>, tensor<?xf32>)
+ outs(%C_sub : tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %0 = std.addf %a, %b : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ %update = tensor.insert_slice %sum_sub into %C_[%i] [%c2] [1]
+ : tensor<?xf32> into tensor<10xf32>
+ linalg.yield %update : tensor<10xf32>
+ }
+ // CHECK: linalg.tiled_loop
+ // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>,
+ // CHECK-SAME: %[[B:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>
+ // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{[0-9]}}: memref<10xf32>)
+
+ // CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]]
+ // CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]]
+ // CHECK-NEXT: %[[TMP:.*]] = memref.alloc
+ // CHECK-NEXT: linalg.generic
+ // CHECK-SAME: ins(%[[SV_A]], %[[SV_B]]
+ // CHECK-SAME: outs(%[[TMP]] : memref<2xf32>)
+ // CHECK: %[[SV_C:.*]] = memref.subview %[[C]]
+ // CHECK-NEXT: linalg.copy(%[[TMP]], %[[SV_C]])
+ // CHECK-NEXT: linalg.yield
+ return %sum : tensor<10xf32>
+}
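
The new checks run under the file's existing RUN line; roughly (the exact
pipeline lives at the top of bufferize.mlir and is not part of this diff):

    mlir-opt -linalg-bufferize -split-input-file \
        mlir/test/Dialect/Linalg/bufferize.mlir | \
        FileCheck mlir/test/Dialect/Linalg/bufferize.mlir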