[Mlir-commits] [mlir] 08dbed8 - [mlir][linalg] Canonicalize dim ops of tiled_loop block args
Matthias Springer
llvmlistbot at llvm.org
Wed Aug 18 19:27:26 PDT 2021
Author: Matthias Springer
Date: 2021-08-19T11:24:33+09:00
New Revision: 08dbed8a5725087a7be71c293e4d86ca0c4c0510
URL: https://github.com/llvm/llvm-project/commit/08dbed8a5725087a7be71c293e4d86ca0c4c0510
DIFF: https://github.com/llvm/llvm-project/commit/08dbed8a5725087a7be71c293e4d86ca0c4c0510.diff
LOG: [mlir][linalg] Canonicalize dim ops of tiled_loop block args
E.g.:
```
%y = ... : tensor<...>
linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
  tensor.dim %x, %c0 : tensor<...>
}
```
is rewritten to:
```
%y = ... : tensor<...>
linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
  tensor.dim %y, %c0 : tensor<...>
}
```
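The same rewrite applies to output arguments, and the pattern is instantiated for both `tensor.dim` and `memref.dim`, so buffer operands are covered as well. A schematic sketch of the `outs` case, with placeholder names and elided types as above:
```
%init = ... : tensor<...>
linalg.tiled_loop ... outs(%out = %init : tensor<...>) {
  // Becomes `tensor.dim %init, %c0`.
  tensor.dim %out, %c0 : tensor<...>
}
```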
Differential Revision: https://reviews.llvm.org/D108272
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a58d91e98b71..6b65a9ecd9e5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2064,6 +2064,57 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
   }
 };

+/// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
+/// to dim(y) where `y` is the initial input/output value of the argument.
+///
+/// E.g.:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
+///   tensor.dim %x, %c0 : tensor<...>
+/// }
+///
+/// is folded to:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
+///   tensor.dim %y, %c0 : tensor<...>
+/// }
+template <typename OpTy>
+struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto src = dimOp.source().template dyn_cast<BlockArgument>();
+    if (!src)
+      return failure();
+    auto loopOp = dyn_cast<TiledLoopOp>(
+        src.getOwner()->getParent()->getParentOp());
+    if (!loopOp)
+      return failure();
+
+    auto inputArgs = loopOp.getRegionInputArgs();
+    auto it1 = llvm::find(inputArgs, src);
+    if (it1 != inputArgs.end()) {
+      rewriter.updateRootInPlace(dimOp, [&] {
+        dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]);
+      });
+      return success();
+    }
+
+    auto outputArgs = loopOp.getRegionOutputArgs();
+    auto it2 = llvm::find(outputArgs, src);
+    if (it2 != outputArgs.end()) {
+      rewriter.updateRootInPlace(dimOp, [&] {
+        dimOp.sourceMutable().assign(
+            loopOp.outputs()[it2 - outputArgs.begin()]);
+      });
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 // Folds away TiledLoopOp output tensors when the following conditions are met:
 // * result of `linalg.tiled_loop` has no uses
 // * output tensor is the argument of `linalg.yield`
@@ -2167,7 +2218,9 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context);
+  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
+                 DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
+                 DimOfTiledLoopInsOutsFolder<memref::DimOp>>(context);
 }

 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
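Note that the fold only applies when the `dim` source is literally a block argument of a `linalg.tiled_loop` region; values merely derived from such an argument are left untouched. A schematic example (placeholder names, elided slice parameters and types) that this pattern does not rewrite:
```
linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
  %slice = tensor.extract_slice %x[...] : tensor<...> to tensor<...>
  // Not rewritten: the dim source is %slice, not a tiled_loop block argument.
  tensor.dim %slice, %c0 : tensor<...>
}
```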
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 34dacd58a51e..41a4bfe9c980 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -919,3 +919,30 @@ func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
   %r = tensor.dim %0, %c0 : tensor<?x?xf32>
   return %r : index
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_input(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: linalg.tiled_loop
+// CHECK: %[[dim:.*]] = tensor.dim %[[arg1]], %[[c0]]
+// CHECK: index_cast %[[dim]]
+func @dim_of_tiled_loop_input(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %inner_dim = tensor.dim %in1, %c0 : tensor<?x?xf32>
+    %cast1 = std.index_cast %inner_dim : index to i32
+    %cast2 = std.sitofp %cast1 : i32 to f32
+    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+    linalg.yield %fill : tensor<?x?xf32>
+  }
+  return %r : tensor<?x?xf32>
+}
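The new test relies on the canonicalizer (the test file is presumably run through `mlir-opt -canonicalize`) to rewrite the inner `tensor.dim` so that it reads the loop operand `%arg1` instead of the block argument `%in1`, which is what the CHECK lines above capture. A sketch of the loop after this fold alone; other canonicalization patterns may simplify the loop further:
```
%r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0) to (%d0, %d1) step (%c1, %c1)
    ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
    outs (%out1 = %arg2 : tensor<?x?xf32>) {
  // The dim no longer depends on a value defined inside the loop region.
  %inner_dim = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  ...
}
```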
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index de7c0b7c4820..01dab480d89c 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -121,8 +121,8 @@ module {
// TLOOP: %[[AB_SUB:.*]] = linalg.matmul
// TLOOP-SAME: ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
-// TLOOP: %[[DIM_B_1:.*]] = tensor.dim %[[B_]], %[[C1]] : [[TY]]
-// TLOOP: %[[DIM_C_1:.*]] = tensor.dim %[[C_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_C_1:.*]] = tensor.dim %[[C]], %[[C1]] : [[TY]]
// TLOOP: %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) =
// TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
@@ -300,7 +300,7 @@ module {
// TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]]
// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
-// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
// TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
// TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
// TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
@@ -371,7 +371,7 @@ module {
// TLOOP-SAME: %[[C0_F32_:.*]] = %[[C0_F32]]
// TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
-// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP: %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
// TLOOP: %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
// TLOOP: %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
// TLOOP: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
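These CHECK updates follow directly from the new fold: inside the enclosing `linalg.tiled_loop`, dims of the region arguments (`%B_`, `%C_`, `%A_` in the FileCheck captures) are rewritten to dims of the corresponding loop operands (`%B`, `%C`, `%A`). Schematically, using the capture names and an elided type:
```
// Before the fold, inside the outer tiled_loop:
%dim_b_1 = tensor.dim %B_, %c1 : tensor<...>
// After the fold, the loop operand is read directly:
%dim_b_1 = tensor.dim %B, %c1 : tensor<...>
```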