[Mlir-commits] [mlir] 7f2236c - [mlir][linalg] Add output tensor args folding for linalg.tiled_loop.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Mar 25 10:12:07 PDT 2021
Author: Alexander Belyaev
Date: 2021-03-25T18:11:05+01:00
New Revision: 7f2236cf581e6d666e4c3eb512a76f1608fe0bf7
URL: https://github.com/llvm/llvm-project/commit/7f2236cf581e6d666e4c3eb512a76f1608fe0bf7
DIFF: https://github.com/llvm/llvm-project/commit/7f2236cf581e6d666e4c3eb512a76f1608fe0bf7.diff
LOG: [mlir][linalg] Add output tensor args folding for linalg.tiled_loop.
Folds away a `linalg.tiled_loop` output tensor when both of the following conditions are met:
* the corresponding result of `linalg.tiled_loop` has no uses
* the output tensor is forwarded unchanged to the terminating `linalg.yield`
Example:
```
%0 = linalg.tiled_loop ... outs (%out, %out_buf:tensor<...>, memref<...>) {
...
linalg.yield %out : tensor ...
}
```
Becomes
```
linalg.tiled_loop ... outs (%out_buf:memref<...>) {
...
linalg.yield
}
```
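For context, the pattern is picked up automatically by the standard canonicalizer (e.g. `mlir-opt -canonicalize`). A minimal sketch of driving it by hand, assuming the pattern-driver API as of this revision (`funcOp` and `canonicalizeTiledLoops` are hypothetical; the exact driver signatures were in flux around this time):
```
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult canonicalizeTiledLoops(FuncOp funcOp) {
  MLIRContext *context = funcOp.getContext();
  // Collect TiledLoopOp's canonicalization patterns, including the
  // output-tensor folding pattern added by this commit.
  OwningRewritePatternList patterns;
  linalg::TiledLoopOp::getCanonicalizationPatterns(patterns, context);
  // Apply the patterns greedily until a fixed point is reached.
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```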
Differential Revision: https://reviews.llvm.org/D99333
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index d54efbe37a57..fe6720761abe 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,6 +584,9 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
}
unsigned getNumLoops() { return step().size(); }
}];
+
+ let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
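These two TableGen bits make ODS declare the canonicalization and folding hooks on the generated `TiledLoopOp` class; their definitions follow in LinalgOps.cpp below. Roughly (sketched from those definitions; the exact spelling depends on the ODS version), the generated declarations are:
```
// Declared by `hasCanonicalizer = 1`.
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context);
// Declared by `hasFolder = 1`; ops with results get the variadic form.
LogicalResult fold(ArrayRef<Attribute> operands,
                   SmallVectorImpl<OpFoldResult> &results);
```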
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fdb2e4f4603e..744f0276daa8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1943,6 +1943,87 @@ bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
static LogicalResult verify(TiledLoopOp op) { return success(); }
+namespace {
+
+// Folds away a TiledLoopOp output tensor when both conditions are met:
+// * the corresponding result of `linalg.tiled_loop` has no uses
+// * the output tensor is forwarded unchanged to the terminating `linalg.yield`
+//
+// Example:
+//
+// %0 = linalg.tiled_loop ... outs (%out, %out_buf:tensor<...>, memref<...>) {
+// ...
+// linalg.yield %out : tensor ...
+// }
+//
+// Becomes
+//
+// linalg.tiled_loop ... outs (%out_buf:memref<...>) {
+// ...
+// linalg.yield
+// }
+struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
+ using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
+ PatternRewriter &rewriter) const final {
+ if (tiledLoop.getNumResults() == 0)
+ return failure();
+
+ Block *block = tiledLoop.getBody();
+ auto yieldOp = cast<linalg::YieldOp>(block->getTerminator());
+
+    // Match the pattern: collect the output operands that must be kept and,
+    // for the tensor outputs among them, the corresponding `linalg.yield`
+    // arguments.
+ SmallVector<Value, 2> newOutputOperands, newYieldArgs;
+ int resultId = 0;
+ for (Value out : tiledLoop.outputs()) {
+ if (!out.getType().isa<RankedTensorType>()) {
+ newOutputOperands.push_back(out);
+ continue;
+ }
+ Value result = tiledLoop.getResult(resultId);
+ Value yieldArg = yieldOp.getOperand(resultId);
+ if (yieldArg != out || !result.use_empty()) {
+ newOutputOperands.push_back(out);
+ newYieldArgs.push_back(yieldArg);
+ }
+ ++resultId;
+ }
+ if (newOutputOperands.size() == tiledLoop.outputs().size())
+ return failure();
+
+ Location loc = tiledLoop.getLoc();
+ auto newTiledLoop = rewriter.create<TiledLoopOp>(
+ loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
+ tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types());
+
+    // Clone the loop body. Ops that only fed the dropped `linalg.yield`
+    // arguments (typically `subtensor_insert`, `tensor_load` and `cast`)
+    // become dead and are cleaned up by later canonicalization.
+ BlockAndValueMapping bvm;
+ bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
+ OpBuilder innerBuilder =
+ OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
+ for (auto &op : tiledLoop.getBody()->without_terminator())
+ innerBuilder.clone(op, bvm);
+ innerBuilder.create<linalg::YieldOp>(loc, newYieldArgs);
+ rewriter.eraseOp(tiledLoop);
+
+ return success();
+ }
+};
+} // namespace
+
+void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<TiledLoopResultsFolder>(context);
+}
+
+LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ return foldMemRefCast(*this);
+}
+
/////// Operations corresponding to library calls defined with Tablegen ////////
template <typename LinalgPoolingOp>
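The new `fold` hook reuses the file-local `foldMemRefCast` helper that other Linalg ops already call. It is not part of this diff; a hedged sketch of its shape, from the surrounding file rather than this patch (the exact guard may differ):
```
// Pre-existing helper in LinalgOps.cpp (not added by this commit): replace
// any operand produced by a memref.cast with the cast's source when the
// cast only refines the type, and report whether anything changed.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
```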
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5ec93dda59d0..44f9dbd49cd5 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -818,3 +818,46 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
// CHECK: return %[[FILL]] : tensor<6x4xf32>
return %reshape : tensor<6x4xf32>
}
+
+// -----
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+ %C: memref<192x192xf32>) -> ()
+
+func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+ %C: memref<192x192xf32>,
+ %C_tensor: tensor<192x192xf32>) {
+ %cst = constant 0.000000e+00 : f32
+ %c24 = constant 24 : index
+ %c16 = constant 16 : index
+ %c0 = constant 0 : index
+ %c192 = constant 192 : index
+ %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
+ step (%c24, %c16)
+ ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
+ outs (%C_tensor, %C :tensor<192x192xf32>, memref<192x192xf32>) {
+ call @foo(%A, %B, %C) : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
+ linalg.yield %C_tensor : tensor<192x192xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @fold_tiled_loop_results(
+// CHECK-SAME: %[[A:.*]]: [[TY:.*]], %[[B:.*]]: [[TY]], %[[C:.*]]: [[TY]],
+// CHECK-SAME: %[[C_TENSOR:.*]]: tensor<{{.*}}>) {
+// CHECK: %[[C24:.*]] = constant 24 : index
+// CHECK: %[[C16:.*]] = constant 16 : index
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C192:.*]] = constant 192 : index
+
+// CHECK-NOT: %{{.*}} = linalg.tiled_loop
+// CHECK: linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
+// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
+// CHECK-SAME: ins (%[[A]], %[[B]]: memref<192x192xf32>, memref<192x192xf32>)
+// CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
+// CHECK-NEXT: call @foo(%[[A]], %[[B]], %[[C]])
+// CHECK-NEXT: linalg.yield
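The test file's RUN line (unchanged, so not shown in this diff) pipes the input through `mlir-opt -canonicalize`. The `CHECK-NOT` above verifies that no result-producing `linalg.tiled_loop` survives: the loop is re-created with only the memref output, and the trailing `linalg.yield` is left without operands.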