[Mlir-commits] [mlir] 29dbac0 - [mlir] Add folding for tensor inputs and memref.cast in linalg.tiled_loop.
Alexander Belyaev
llvmlistbot at llvm.org
Wed Apr 28 05:36:48 PDT 2021
Author: Alexander Belyaev
Date: 2021-04-28T14:36:07+02:00
New Revision: 29dbac0ae29576176318525c9af65a15429c9466
URL: https://github.com/llvm/llvm-project/commit/29dbac0ae29576176318525c9af65a15429c9466
DIFF: https://github.com/llvm/llvm-project/commit/29dbac0ae29576176318525c9af65a15429c9466.diff
LOG: [mlir] Add folding for tensor inputs and memref.cast in linalg.tiled_loop.
Tensor inputs that are not used in the body of a TiledLoopOp can be removed.
A memref::CastOp that defines one of the loop's operands can be folded into
the TiledLoopOp as well.
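For illustration, a schematic before/after of the input fold (condensed; the
new tests below show the full IR):

  linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
      ins (%t_ = %t: tensor<192xf32>, %m_ = %m: memref<192xf32>) {
    ...  // %t_ has no uses in the body
    linalg.yield
  }
  // canonicalizes to:
  linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
      ins (%m_ = %m: memref<192xf32>) { ... }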
Differential Revision: https://reviews.llvm.org/D101445
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5a6d498a65b49..8a80036688cd9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -136,7 +136,7 @@ static void dispatchIndexOpFoldResult(OpFoldResult ofr,
/// This is a common class used for patterns of the form
/// ```
-/// someop(memrefcast) -> someop
+/// someop(memrefcast(%src)) -> someop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
static LogicalResult foldMemRefCast(Operation *op) {
@@ -151,6 +151,44 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}
+/// This is a specialization of `foldMemRefCast` used for patterns of the form
+/// ```
+/// tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
+/// ```
+/// It folds the source of the memref.cast into the root operation directly.
+static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
+ bool folded = false;
+ Location loc = op->getLoc();
+
+ Block *body = op.getBody();
+ OpBuilder b = OpBuilder::atBlockBegin(body);
+
+ // Update `input` and `output` operands and block arguments if necessary.
+ // Operands list: [lbs, ubs, steps, inputs, outputs].
+ // Block args list: [ivs, inputs, outputs].
+ for (size_t operandIndex = op.getNumControlOperands(),
+ bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
+ operandIndex < e; ++operandIndex, ++bbArgIndex) {
+ OpOperand &operand = op->getOpOperand(operandIndex);
+
+ auto castOp = operand.get().getDefiningOp<memref::CastOp>();
+ if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
+ operand.set(castOp.getOperand());
+ auto newBbArg =
+ body->insertArgument(bbArgIndex, castOp.getOperand().getType());
+ auto oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);
+
+ // Insert memref.cast back to the original type.
+ oldBbArg.replaceAllUsesWith(
+ b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
+ body->eraseArgument(oldBbArg.getArgNumber());
+
+ folded = true;
+ }
+ }
+ return success(folded);
+}
+
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
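Note: `foldMemRefCastInTiledLoopOp` also has to keep the loop body consistent.
A schematic before/after (types taken from the new test below; value names are
illustrative):

  // before
  %0 = memref.cast %buf : memref<192xf32> to memref<192xf32, #map>
  linalg.tiled_loop ... outs (%out = %0: memref<192xf32, #map>) {
    // uses of %out : memref<192xf32, #map>
  }
  // after: the operand is rewired to %buf, the block argument is retyped,
  // and a cast back to the old type is inserted at the top of the body so
  // the existing uses still see memref<192xf32, #map>.
  linalg.tiled_loop ... outs (%out = %buf: memref<192xf32>) {
    %1 = memref.cast %out : memref<192xf32> to memref<192xf32, #map>
    // old uses now go through %1
  }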
@@ -2054,6 +2092,63 @@ static LogicalResult verify(TiledLoopOp op) {
namespace {
+static constexpr int64_t kNoMatch = -1;
+
+// Folds away TiledLoopOp input tensors if they have no uses within the body.
+//
+// Example:
+//
+// %0 = linalg.tiled_loop ... ins (%in_ = %in: tensor<...>,
+// %in_buf_ = %in_buf: memref<...>) {...}
+// Becomes
+//
+// linalg.tiled_loop ... ins (%in_buf_ = %in_buf: memref<...>) {...}
+struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
+ using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
+ PatternRewriter &rewriter) const final {
+ SmallVector<Value, 2> newInputs;
+ // Store ids of the corresponding old and new input operands.
+ SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(),
+ kNoMatch);
+ for (auto en : llvm::enumerate(
+ llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) {
+ Value in, bbArg;
+ size_t index = en.index();
+ std::tie(in, bbArg) = en.value();
+ if (!in.getType().isa<RankedTensorType>() || !bbArg.use_empty()) {
+ oldInputIdToNew[index] = newInputs.size();
+ newInputs.push_back(in);
+ }
+ }
+ if (newInputs.size() == tiledLoop.inputs().size())
+ return failure();
+ Location loc = tiledLoop.getLoc();
+ auto newTiledLoop = rewriter.create<TiledLoopOp>(
+ loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
+ newInputs, tiledLoop.outputs(), tiledLoop.iterator_types());
+
+ // Clone the region.
+ BlockAndValueMapping bvm;
+ bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+ bvm.map(tiledLoop.getRegionOutputArgs(),
+ newTiledLoop.getRegionOutputArgs());
+ for (const auto &en : llvm::enumerate(oldInputIdToNew))
+ if (en.value() != kNoMatch)
+ bvm.map(tiledLoop.getRegionInputArgs()[en.index()],
+ newTiledLoop.getRegionInputArgs()[en.value()]);
+ OpBuilder innerBuilder =
+ OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+ for (auto &op : *tiledLoop.getBody())
+ innerBuilder.clone(op, bvm);
+ rewriter.eraseOp(tiledLoop);
+
+ return success();
+ }
+};
+
// Folds away TiledLoopOp output tensors when the following conditions are met:
// * result of `linalg.tiled_loop` has no uses
// * output tensor is the argument of `linalg.yield`
@@ -2085,27 +2180,26 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
// Match the pattern and collect output buffers that will replace the output
// tensors and also the ops that will be ignored when cloning the body.
- SmallVector<Value, 2> newOutputOperands, newYieldArgs,
- regionOutputTensorArgs;
+ SmallVector<Value, 2> newOutputOperands, newYieldArgs;
int resultId = 0;
// Store ids of the corresponding old and new output operands.
- SmallVector<std::pair<size_t, size_t>, 2> old_out_id_to_new;
- for (auto item : llvm::enumerate(
+ SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
+ kNoMatch);
+ for (auto en : llvm::enumerate(
llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
- size_t index = item.index();
- Value out = std::get<0>(item.value());
- Value outRegionArg = std::get<1>(item.value());
+ size_t index = en.index();
+ Value out = std::get<0>(en.value());
+ Value outRegionArg = std::get<1>(en.value());
if (!out.getType().isa<RankedTensorType>()) {
- old_out_id_to_new.push_back({index, newOutputOperands.size()});
+ oldOutputIdToNew[index] = newOutputOperands.size();
newOutputOperands.push_back(out);
- regionOutputTensorArgs.push_back(outRegionArg);
continue;
}
Value result = tiledLoop.getResult(resultId);
Value yieldArg = yieldOp.getOperand(resultId);
if (yieldArg != outRegionArg || !result.use_empty()) {
- old_out_id_to_new.push_back({index, newOutputOperands.size()});
+ oldOutputIdToNew[index] = newOutputOperands.size();
newOutputOperands.push_back(out);
newYieldArgs.push_back(yieldArg);
}
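The two conditions listed above in IR form (schematic, in the spirit of the
existing @fold_tiled_loop_results test):

  %0 = linalg.tiled_loop ... outs (%t_ = %t: tensor<192x192xf32>,
                                   %m_ = %m: memref<192x192xf32>) {
    ...
    linalg.yield %t_ : tensor<192x192xf32>
  }
  // If %0 has no uses and %t_ is yielded unchanged, the tensor output and
  // the loop result are removed:
  linalg.tiled_loop ... outs (%m_ = %m: memref<192x192xf32>) {
    ...
    linalg.yield
  }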
@@ -2119,14 +2213,18 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
- // Clone the region ignoring the def-chain for linalg.yield args:
- // unnecessary `subtensor_insert`, `tensor_load` and `cast` ops.
+ // Clone the region.
BlockAndValueMapping bvm;
bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
bvm.map(tiledLoop.getRegionInputArgs(), newTiledLoop.getRegionInputArgs());
- for (const auto &item : old_out_id_to_new)
- bvm.map(tiledLoop.getRegionOutputArgs()[item.first],
- newTiledLoop.getRegionOutputArgs()[item.second]);
+ for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
+ if (en.value() != kNoMatch)
+ bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
+ newTiledLoop.getRegionOutputArgs()[en.value()]);
+ else
+ bvm.map(tiledLoop.getRegionOutputArgs()[en.index()],
+ tiledLoop.outputs()[en.index()]);
+ }
OpBuilder innerBuilder =
OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
for (auto &op : tiledLoop.getBody()->without_terminator())
@@ -2141,12 +2239,12 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<TiledLoopResultsFolder>(context);
+ results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context);
}
LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
- return foldMemRefCast(*this);
+ return foldMemRefCastInTiledLoopOp(*this);
}
//===----------------------------------------------------------------------===//
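For reference, these canonicalization patterns are picked up by the generic
-canonicalize pass, and a custom pass can also apply them directly. A minimal
sketch using the greedy driver (assumed boilerplate of this API vintage, not
part of the patch):

  // C++, e.g. in a FunctionPass::runOnFunction(); assumes
  // Transforms/GreedyPatternRewriteDriver.h is included.
  MLIRContext *ctx = &getContext();
  OwningRewritePatternList patterns(ctx);
  linalg::TiledLoopOp::getCanonicalizationPatterns(patterns, ctx);
  (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));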
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index afdfe6fb98a81..e66ee388c65eb 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -18,6 +18,31 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
// -----
+#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK-LABEL: func @memref_cast_into_tiled_loop(
+func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>) {
+ %0 = memref.cast %arg0
+ : memref<192xf32> to memref<192xf32, #map>
+ %cst = constant 0.000000e+00 : f32
+ %c24 = constant 24 : index
+ %c0 = constant 0 : index
+ %c192 = constant 192 : index
+ // CHECK: linalg.tiled_loop
+ // CHECK-SAME: outs (%{{.*}} = %{{.*}}: memref<192xf32>)
+ linalg.tiled_loop (%arg3) = (%c0) to (%c192) step (%c24)
+ outs (%out = %0: memref<192xf32, #map>) {
+ %14 = affine.min affine_map<(d0) -> (-d0 + 192, 24)>(%arg3)
+ %16 = memref.subview %out[%arg3] [%14] [1]
+ : memref<192xf32, #map> to memref<?xf32, #map>
+ linalg.fill(%16, %cst) : memref<?xf32, #map>, f32
+ linalg.yield
+ }
+ return
+}
+
+// -----
+
func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
{
%0 = linalg.tensor_reshape %arg0
@@ -889,6 +914,30 @@ func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
// -----
+
+func private @foo(%A: memref<192xf32>) -> ()
+
+func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>) {
+ %c0 = constant 0 : index
+ %c24 = constant 24 : index
+ %c192 = constant 192 : index
+ linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
+ ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>) {
call @foo(%A_) : (memref<192xf32>) -> ()
+ linalg.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func @fold_tiled_loop_inputs
+// CHECK: linalg.tiled_loop
+// CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)
+
+// -----
+
func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
%arg3: f32) -> (index, index, index)
{
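Both new tests go through the canonicalizer; assuming this file's usual RUN
line, they can be exercised with something like:

  mlir-opt -split-input-file -canonicalize \
      mlir/test/Dialect/Linalg/canonicalize.mlir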