[Mlir-commits] [mlir] 4fa6c27 - [mlir][scf] Allow runtime type of iter_args to change

Matthias Springer llvmlistbot at llvm.org
Thu Sep 2 18:07:57 PDT 2021


Author: Matthias Springer
Date: 2021-09-03T10:03:05+09:00
New Revision: 4fa6c2734c484ad7299257b317b75f9bc5482b7c

URL: https://github.com/llvm/llvm-project/commit/4fa6c2734c484ad7299257b317b75f9bc5482b7c
DIFF: https://github.com/llvm/llvm-project/commit/4fa6c2734c484ad7299257b317b75f9bc5482b7c.diff

LOG: [mlir][scf] Allow runtime type of iter_args to change

The limitation on iter_args introduced in D108806 is too restrictive. Changes to the runtime type should be allowed.

Extends the dim op canonicalization with a simple analysis to determine when it is safe to canonicalize.

Differential Revision: https://reviews.llvm.org/D109125

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
    mlir/test/Dialect/SCF/for-loop-canonicalization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index d04266a70cddc..f63186b2c6ec3 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -147,17 +147,16 @@ def ForOp : SCF_Op<"for",
     passed as additional SSA operands to the "scf.for" following the 3 loop
     control SSA values mentioned above (lower bound, upper bound and step). The
     operation region has an argument for the induction variable, followed by
-    one argument for each loop-carried variable, representing he value of the
+    one argument for each loop-carried variable, representing the value of the
     variable at the current iteration.
 
     The region must terminate with a "scf.yield" that passes the current
-    values of loop-carried variables to the next iteration, or to the "scf.for"
-    result, if at the last iteration. The type (static or dynamic) of a
-    loop-carried variable may not change with iterations. E.g., it is illegal
-    to pass a tensor of larger size to the next iteration; even if the tensor's
-    dimensions are dynamic (i.e., same static type). Note, that when the
-    loop-carried variables are present, calling ForOp::build will not insert the
-    terminator implicitly. The caller must insert "scf.yield" in that case.
+    values of all loop-carried variables to the next iteration, or to the
+    "scf.for" result, if at the last iteration. The static type of a
+    loop-carried variable may not change with iterations; its runtime type
+    (e.g., the sizes of dynamic tensor dimensions) may change. Note that when
+    loop-carried variables are present, ForOp::build does not insert the
+    terminator implicitly; the caller must insert "scf.yield" in that case.
 
     "scf.for" results hold the final values after the last iteration.
     For example, to sum-reduce a memref:

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 3e6fe03486881..3f2cc70bf7061 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::scf;
@@ -44,10 +45,44 @@ namespace {
 ///   ...
 /// }
 /// ```
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the iter arg does not change with loop iterations.
 template <typename OpTy>
 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
+  /// A simple, conservative analysis to determine if the loop is shape
+  /// preserving for the `arg`-th iter_arg, i.e., the runtime type of the
+  /// `arg`-th yielded value equals that of the corresponding region iter_arg.
+  /// Note: Only handles insert_slice and nested scf.for; expand as needed.
+  static bool isShapePreserving(ForOp forOp, int64_t arg) {
+    auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+    assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
+           "arg is out of bounds");
+    Value value = yieldOp.results()[arg];
+    while (value) {
+      if (value == forOp.getRegionIterArgs()[arg]) // Reached the iter_arg.
+        return true;
+      OpResult opResult = value.dyn_cast<OpResult>();
+      if (!opResult)
+        return false; // Not an op result (a block argument): give up.
+
+      using tensor::InsertSliceOp;
+      value =
+          llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+              .template Case<InsertSliceOp>(
+                  [&](InsertSliceOp op) { return op.dest(); })
+              .template Case<ForOp>([&](ForOp forOp) { // Nested loop: recurse.
+                return isShapePreserving(forOp, opResult.getResultNumber())
+                           ? forOp.getIterOperands()[opResult.getResultNumber()]
+                           : Value();
+              })
+              .Default([&](auto op) { return Value(); }); // Unhandled op.
+    }
+    return false; // Walk stopped at an unhandled or non-preserving producer.
+  }
+
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
     auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
@@ -56,6 +91,8 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
     auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
     if (!forOp)
       return failure();
+    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
+      return failure();
 
     Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
     rewriter.updateRootInPlace(

diff  --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
index 9f80b31a2542e..60004d53e240e 100644
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -242,3 +242,74 @@ func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
   }
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_insertslice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
+//       CHECK:   scf.for
+//       CHECK:     tensor.dim %[[t]]
+func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
+                                         %t2 : tensor<?x?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+      -> (tensor<?x?xf32>, index) {
+    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>  // Expected to fold to %t.
+    %2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1]
+        : tensor<?x?xf32> into tensor<?x?xf32>
+    %3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1]
+        : tensor<?x?xf32> into tensor<?x?xf32>
+    scf.yield %3, %dim : tensor<?x?xf32>, index  // %3 reaches %arg0 via insert_slice dests.
+  }
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_nested_for(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
+//       CHECK:   scf.for
+//       CHECK:     scf.for
+//       CHECK:       tensor.dim %[[t]]
+func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
+                                        %t2 : tensor<?x?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+      -> (tensor<?x?xf32>, index) {
+    %2, %3 = scf.for %j = %c0 to %c10 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1)
+        -> (tensor<?x?xf32>, index) {
+      %dim = tensor.dim %arg2, %c0 : tensor<?x?xf32>  // Expected to fold to %t.
+      %4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1]
+          : tensor<?x?xf32> into tensor<?x?xf32>
+      scf.yield %4, %dim : tensor<?x?xf32>, index
+    }
+    scf.yield %2, %3 : tensor<?x?xf32>, index  // %2 comes from a shape-preserving inner loop.
+  }
+  return %1 : index
+}
+
+// -----
+
+// A test case that should not canonicalize because the loop is not shape
+// preserving.
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_no_canonicalize(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>,
+//       CHECK:   scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[t]]
+//       CHECK:     tensor.dim %[[arg0]]
+func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
+                                             %t2 : tensor<?x?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+      -> (tensor<?x?xf32>, index) {
+    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>  // Must not be folded.
+    scf.yield %t2, %dim : tensor<?x?xf32>, index  // %t2 is unrelated to %arg0.
+  }
+  return %1 : index
+}


        


More information about the Mlir-commits mailing list