[Mlir-commits] [mlir] c7d569b - [mlir][scf] Fold dim(scf.for) to dim(iter_arg)
Matthias Springer
llvmlistbot at llvm.org
Wed Sep 8 21:52:48 PDT 2021
Author: Matthias Springer
Date: 2021-09-09T13:47:13+09:00
New Revision: c7d569b8f73d5f1ff03a65fb2b25d966d98c5a5f
URL: https://github.com/llvm/llvm-project/commit/c7d569b8f73d5f1ff03a65fb2b25d966d98c5a5f
DIFF: https://github.com/llvm/llvm-project/commit/c7d569b8f73d5f1ff03a65fb2b25d966d98c5a5f.diff
LOG: [mlir][scf] Fold dim(scf.for) to dim(iter_arg)
Fold dim ops of scf.for results to dim ops of the respective iter_arg init values (i.e., the loop's iter operands) if the loop is shape preserving.
Differential Revision: https://reviews.llvm.org/D109430
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 3f2cc70bf7061..a65ffc2d1b2c9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -25,6 +25,37 @@
using namespace mlir;
using namespace mlir::scf;
+/// A simple, conservative analysis to determine if the loop is shape
+/// preserving for the arg-th iter_arg: the yielded value is traced through
+/// tensor.insert_slice dests and nested shape-preserving scf.for results and
+/// must end at the same region iter_arg. Handles simple cases only.
+static bool isShapePreserving(ForOp forOp, int64_t arg) {
+  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+  assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
+         "arg is out of bounds");
+  Value value = yieldOp.results()[arg];
+  while (value) {
+    if (value == forOp.getRegionIterArgs()[arg])
+      return true;
+    OpResult opResult = value.dyn_cast<OpResult>();
+    if (!opResult)
+      return false;
+
+    using tensor::InsertSliceOp;
+    value =
+        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+            .Case<InsertSliceOp>([](InsertSliceOp op) { return op.dest(); })
+            .Case<ForOp>([&](ForOp innerLoop) {
+              unsigned idx = opResult.getResultNumber();
+              return isShapePreserving(innerLoop, idx)
+                         ? innerLoop.getIterOperands()[idx]
+                         : Value();
+            })
+            .Default([](Operation *) { return Value(); });
+  }
+  return false;
+}
+
namespace {
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
@@ -52,37 +83,6 @@ template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
- /// A simple, conservative analysis to determine if the loop is shape
- /// conserving. I.e., the type of the arg-th yielded value is the same as the
- /// type of the corresponding basic block argument of the loop.
- /// Note: This function handles only simple cases. Expand as needed.
- static bool isShapePreserving(ForOp forOp, int64_t arg) {
- auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
- assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
- "arg is out of bounds");
- Value value = yieldOp.results()[arg];
- while (value) {
- if (value == forOp.getRegionIterArgs()[arg])
- return true;
- OpResult opResult = value.dyn_cast<OpResult>();
- if (!opResult)
- return false;
-
- using tensor::InsertSliceOp;
- value =
- llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
- .template Case<InsertSliceOp>(
- [&](InsertSliceOp op) { return op.dest(); })
- .template Case<ForOp>([&](ForOp forOp) {
- return isShapePreserving(forOp, opResult.getResultNumber())
- ? forOp.getIterOperands()[opResult.getResultNumber()]
- : Value();
- })
- .Default([&](auto op) { return Value(); });
- }
- return false;
- }
-
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
@@ -102,6 +102,48 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
};
};
+/// Redirect dim ops on scf.for results to the loop's matching init operands:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
+/// ...
+/// }
+/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
+/// ```
+///
+/// becomes:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
+/// ...
+/// }
+/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+/// ```
+///
+/// Applied only when isShapePreserving proves that the runtime shape of the
+/// iter arg is the same in every loop iteration.
+template <typename OpTy>
+struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value loopResult = dimOp.source();
+    auto loop = loopResult.getDefiningOp<scf::ForOp>();
+    if (!loop)
+      return failure();
+    unsigned idx = loopResult.cast<OpResult>().getResultNumber();
+    if (!isShapePreserving(loop, idx))
+      return failure();
+    Value initArg = loop.getIterOperands()[idx];
+    rewriter.updateRootInPlace(
+        dimOp, [&]() { dimOp.sourceMutable().assign(initArg); });
+    return success();
+  }
+};
+
/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
template <typename OpTy, bool IsMin>
@@ -156,7 +198,9 @@ void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
.insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
DimOfIterArgFolder<tensor::DimOp>,
- DimOfIterArgFolder<memref::DimOp>>(ctx);
+ DimOfIterArgFolder<memref::DimOp>,
+ DimOfLoopResultFolder<tensor::DimOp>,
+ DimOfLoopResultFolder<memref::DimOp>>(ctx);
}
std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
index 60004d53e240e..813cb6a848320 100644
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -313,3 +313,38 @@ func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
}
return %1 : index
}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_loop_result(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK: tensor.dim %[[t]]
+func @tensor_dim_of_loop_result(%t : tensor<?x?xf32>) -> index {
+  %c10 = constant 10 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %res = scf.for %iv = %c0 to %c10 step %c1 iter_args(%iter = %t)
+      -> (tensor<?x?xf32>) {
+    scf.yield %iter : tensor<?x?xf32>
+  }
+  %size = tensor.dim %res, %c0 : tensor<?x?xf32>
+  return %size : index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_loop_result_no_canonicalize(
+// CHECK: %[[loop:.*]]:2 = scf.for
+// CHECK: tensor.dim %[[loop]]#1
+func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?x?xf32>,
+                                                %u : tensor<?x?xf32>) -> index {
+  %c10 = constant 10 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %first, %second = scf.for %iv = %c0 to %c10 step %c1 iter_args(%x = %t, %y = %u)
+      -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    scf.yield %x, %u : tensor<?x?xf32>, tensor<?x?xf32>
+  }
+  %size = tensor.dim %second, %c0 : tensor<?x?xf32>
+  return %size : index
+}
More information about the Mlir-commits mailing list