[Mlir-commits] [mlir] c7d569b - [mlir][scf] Fold dim(scf.for) to dim(iter_arg)

Matthias Springer llvmlistbot at llvm.org
Wed Sep 8 21:52:48 PDT 2021


Author: Matthias Springer
Date: 2021-09-09T13:47:13+09:00
New Revision: c7d569b8f73d5f1ff03a65fb2b25d966d98c5a5f

URL: https://github.com/llvm/llvm-project/commit/c7d569b8f73d5f1ff03a65fb2b25d966d98c5a5f
DIFF: https://github.com/llvm/llvm-project/commit/c7d569b8f73d5f1ff03a65fb2b25d966d98c5a5f.diff

LOG: [mlir][scf] Fold dim(scf.for) to dim(iter_arg)

Fold dim ops of scf.for results to dim ops of the respective iter args if the loop is shape preserving.

Differential Revision: https://reviews.llvm.org/D109430

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
    mlir/test/Dialect/SCF/for-loop-canonicalization.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 3f2cc70bf7061..a65ffc2d1b2c9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -25,6 +25,37 @@
 using namespace mlir;
 using namespace mlir::scf;
 
+/// A simple, conservative analysis to determine if the loop is shape
+/// conserving. I.e., the type of the arg-th yielded value is the same as the
+/// type of the corresponding basic block argument of the loop.
+/// Note: This function handles only simple cases. Expand as needed.
+static bool isShapePreserving(ForOp forOp, int64_t arg) {
+  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+  assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
+         "arg is out of bounds");
+  Value value = yieldOp.results()[arg];
+  while (value) {
+    if (value == forOp.getRegionIterArgs()[arg])
+      return true;
+    OpResult opResult = value.dyn_cast<OpResult>();
+    if (!opResult)
+      return false;
+
+    using tensor::InsertSliceOp;
+    value =
+        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+            .template Case<InsertSliceOp>(
+                [&](InsertSliceOp op) { return op.dest(); })
+            .template Case<ForOp>([&](ForOp forOp) {
+              return isShapePreserving(forOp, opResult.getResultNumber())
+                         ? forOp.getIterOperands()[opResult.getResultNumber()]
+                         : Value();
+            })
+            .Default([&](auto op) { return Value(); });
+  }
+  return false;
+}
+
 namespace {
 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
 ///
@@ -52,37 +83,6 @@ template <typename OpTy>
 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  /// A simple, conservative analysis to determine if the loop is shape
-  /// conserving. I.e., the type of the arg-th yielded value is the same as the
-  /// type of the corresponding basic block argument of the loop.
-  /// Note: This function handles only simple cases. Expand as needed.
-  static bool isShapePreserving(ForOp forOp, int64_t arg) {
-    auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
-    assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
-           "arg is out of bounds");
-    Value value = yieldOp.results()[arg];
-    while (value) {
-      if (value == forOp.getRegionIterArgs()[arg])
-        return true;
-      OpResult opResult = value.dyn_cast<OpResult>();
-      if (!opResult)
-        return false;
-
-      using tensor::InsertSliceOp;
-      value =
-          llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
-              .template Case<InsertSliceOp>(
-                  [&](InsertSliceOp op) { return op.dest(); })
-              .template Case<ForOp>([&](ForOp forOp) {
-                return isShapePreserving(forOp, opResult.getResultNumber())
-                           ? forOp.getIterOperands()[opResult.getResultNumber()]
-                           : Value();
-              })
-              .Default([&](auto op) { return Value(); });
-    }
-    return false;
-  }
-
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
     auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
@@ -102,6 +102,48 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
   };
 };
 
+/// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   ...
+/// }
+/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   ...
+/// }
+/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+/// ```
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the iter arg does not change with loop iterations.
+template <typename OpTy>
+struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto forOp = dimOp.source().template getDefiningOp<scf::ForOp>();
+    if (!forOp)
+      return failure();
+    auto opResult = dimOp.source().template cast<OpResult>();
+    unsigned resultNumber = opResult.getResultNumber();
+    if (!isShapePreserving(forOp, resultNumber))
+      return failure();
+    rewriter.updateRootInPlace(dimOp, [&](){
+      dimOp.sourceMutable().assign(forOp.getIterOperands()[resultNumber]);
+    });
+    return success();
+  }
+};
+
 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
 /// and scf.parallel loops with a known range.
 template <typename OpTy, bool IsMin>
@@ -156,7 +198,9 @@ void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
       .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
               AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
               DimOfIterArgFolder<tensor::DimOp>,
-              DimOfIterArgFolder<memref::DimOp>>(ctx);
+              DimOfIterArgFolder<memref::DimOp>,
+              DimOfLoopResultFolder<tensor::DimOp>,
+              DimOfLoopResultFolder<memref::DimOp>>(ctx);
 }
 
 std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {

diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
index 60004d53e240e..813cb6a848320 100644
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -313,3 +313,38 @@ func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
   }
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_loop_result(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//       CHECK:   tensor.dim %[[t]]
+func @tensor_dim_of_loop_result(%t : tensor<?x?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t)
+      -> (tensor<?x?xf32>) {
+    scf.yield %arg0 : tensor<?x?xf32>
+  }
+  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
+  return %dim : index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_loop_result_no_canonicalize(
+//       CHECK:   %[[loop:.*]]:2 = scf.for
+//       CHECK:   tensor.dim %[[loop]]#1
+func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?x?xf32>,
+                                                %u : tensor<?x?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %u)
+      -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+    scf.yield %arg0, %u : tensor<?x?xf32>, tensor<?x?xf32>
+  }
+  %dim = tensor.dim %1, %c0 : tensor<?x?xf32>
+  return %dim : index
+}


        


More information about the Mlir-commits mailing list