[Mlir-commits] [mlir] e2c8fcb - [mlir][linalg] Fold dim(linalg.tiled_loop) to dim(output_arg)
Matthias Springer
llvmlistbot at llvm.org
Wed Sep 8 21:44:44 PDT 2021
Author: Matthias Springer
Date: 2021-09-09T13:37:28+09:00
New Revision: e2c8fcb9d0bd33fda481f7e27cf0d6ebdde2b5b0
URL: https://github.com/llvm/llvm-project/commit/e2c8fcb9d0bd33fda481f7e27cf0d6ebdde2b5b0
DIFF: https://github.com/llvm/llvm-project/commit/e2c8fcb9d0bd33fda481f7e27cf0d6ebdde2b5b0.diff
LOG: [mlir][linalg] Fold dim(linalg.tiled_loop) to dim(output_arg)
Fold dim ops of linalg.tiled_loop results to dim ops of the respective iter args if the loop is shape preserving.
Differential Revision: https://reviews.llvm.org/D109431
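For illustration, the shape of the new rewrite (this mirrors the doc comment and tests added in the diff below; %y, %i, %r and %0 are placeholder names):

  %y = ... : tensor<...>
  %r = linalg.tiled_loop ... outs (%i = %y : tensor<...>) {
    ...
  }
  %0 = tensor.dim %r, %c0 : tensor<...>

is folded to:

  %y = ... : tensor<...>
  linalg.tiled_loop ... outs (%i = %y : tensor<...>) {
    ...
  }
  %0 = tensor.dim %y, %c0 : tensor<...>

The fold is only applied when the loop can be shown to be shape preserving for that result, i.e. the yielded value provably has the same shape as the corresponding output argument on every iteration.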
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a1c0c996e332c..688f241c36f5a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2286,6 +2286,44 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
}
};
+} // namespace
+
+/// A simple, conservative analysis to determine if the loop is shape
+/// conserving. I.e., the type of the arg-th yielded value is the same as the
+/// type of the corresponding basic block argument of the loop.
+/// Note: This function handles only simple cases. Expand as needed.
+static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
+  auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
+  if (yieldOp.values().empty())
+    // Tiled loop either has no outputs or is a "memref-based version". In
+    // either case, the loop is shape conserving.
+    return true;
+  assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
+         "arg is out of bounds");
+  Value value = yieldOp.values()[arg];
+  while (value) {
+    if (value == loopOp.getRegionOutputArgs()[arg])
+      return true;
+    OpResult opResult = value.dyn_cast<OpResult>();
+    if (!opResult)
+      return false;
+
+    using tensor::InsertSliceOp;
+    value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+                .template Case<InsertSliceOp>(
+                    [&](InsertSliceOp op) { return op.dest(); })
+                .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
+                  return isShapePreserving(loopOp, opResult.getResultNumber())
+                             ? loopOp.outputs()[opResult.getResultNumber()]
+                             : Value();
+                })
+                .Default([&](auto op) { return Value(); });
+  }
+  return false;
+}
+
+namespace {
+
/// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
/// to dim(y) where `y` is the initial input/output value of the argument.
///
@@ -2307,40 +2345,6 @@ template <typename OpTy>
struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
-  /// A simple, conservative analysis to determine if the loop is shape
-  /// conserving. I.e., the type of the arg-th yielded value is the same as the
-  /// type of the corresponding basic block argument of the loop.
-  /// Note: This function handles only simple cases. Expand as needed.
-  static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
-    auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
-    if (yieldOp.values().empty())
-      // Tiled loop either has no outputs or is a "memref-based version". In
-      // either case, the loop is shape conserving.
-      return true;
-    assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
-           "arg is out of bounds");
-    Value value = yieldOp.values()[arg];
-    while (value) {
-      if (value == loopOp.getRegionOutputArgs()[arg])
-        return true;
-      OpResult opResult = value.dyn_cast<OpResult>();
-      if (!opResult)
-        return false;
-
-      using tensor::InsertSliceOp;
-      value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
-                  .template Case<InsertSliceOp>(
-                      [&](InsertSliceOp op) { return op.dest(); })
-                  .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
-                    return isShapePreserving(loopOp, opResult.getResultNumber())
-                               ? loopOp.outputs()[opResult.getResultNumber()]
-                               : Value();
-                  })
-                  .Default([&](auto op) { return Value(); });
-    }
-    return false;
-  }
-
  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const final {
    auto src = dimOp.source().template dyn_cast<BlockArgument>();
@@ -2380,6 +2384,45 @@ struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
}
};
+/// Fold dim(r) where `r` is the result of a TiledLoopOp to dim(y) where `y`
+/// is the initial output value of the loop.
+///
+/// E.g.:
+/// %y = ... : tensor<...>
+/// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
+/// ...
+/// }
+/// %0 = tensor.dim %r, %c0 : tensor<...>
+///
+/// is folded to:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
+/// ...
+/// }
+/// %0 = tensor.dim %y, %c0 : tensor<...>
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the yielded value (in case of outputs) does not change with loop iterations.
+template <typename OpTy>
+struct DimOfTiledLoopResultFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto loopOp = dimOp.source().template getDefiningOp<TiledLoopOp>();
+    if (!loopOp)
+      return failure();
+    auto opResult = dimOp.source().template cast<OpResult>();
+    unsigned resultNumber = opResult.getResultNumber();
+    if (!isShapePreserving(loopOp, resultNumber))
+      return failure();
+    rewriter.updateRootInPlace(dimOp, [&]() {
+      dimOp.sourceMutable().assign(loopOp.outputs()[resultNumber]);
+    });
+    return success();
+  }
+};
+
// Folds away TiledLoopOp output tensors when the following conditions are met:
// * result of `linalg.tiled_loop` has no uses
// * output tensor is the argument of `linalg.yield`
@@ -2485,7 +2528,9 @@ void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
                 DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
-                 DimOfTiledLoopInsOutsFolder<memref::DimOp>>(context);
+                 DimOfTiledLoopInsOutsFolder<memref::DimOp>,
+                 DimOfTiledLoopResultFolder<tensor::DimOp>,
+                 DimOfTiledLoopResultFolder<memref::DimOp>>(context);
}
LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 14c7a1ac639df..db915f10e7dde 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -956,3 +956,51 @@ func @dim_of_tiled_loop_input(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %a
}
return %r : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_result(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: tensor.dim %[[arg2]], %[[c0]]
+func @dim_of_tiled_loop_result(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+    -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %1 = tensor.insert_slice %arg0 into %out1 [0, 0] [%s, %s] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+    linalg.yield %1 : tensor<?x?xf32>
+  }
+  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
+  return %r2 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_result_no_canonicalize(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: %[[r:.*]] = linalg.tiled_loop
+// CHECK: tensor.dim %[[r]], %[[c0]]
+func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+    -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %1 = tensor.insert_slice %arg0 into %arg1 [0, 0] [%s, %s] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+    linalg.yield %1 : tensor<?x?xf32>
+  }
+  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
+  return %r2 : index
+}
+
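To tie the two new tests back to the isShapePreserving() analysis above: in @dim_of_tiled_loop_result the yielded value is a tensor.insert_slice whose destination is the loop's output block argument %out1, so the fold applies; in @dim_of_tiled_loop_result_no_canonicalize the destination is %arg1 instead, the analysis bails out, and the tensor.dim of the loop result is kept. A minimal sketch of what the positive test is expected to look like after canonicalization (assumed output; only the CHECK lines above are actually verified, and the loop is shown unchanged even though other patterns may simplify it further):

  %c0 = constant 0 : index
  ...
  // %r is now unused; the dim below reads the initial output %arg2 directly.
  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0) to (%d0, %d1) step (%c1, %c1)
      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
      outs (%out1 = %arg2 : tensor<?x?xf32>) {
    ...
  }
  %r2 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
  return %r2 : index

The tests are presumably exercised through mlir-opt's -canonicalize pass, as for the rest of canonicalize.mlir, which picks up the newly registered DimOfTiledLoopResultFolder patterns.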