[Mlir-commits] [mlir] 4fa6c27 - [mlir][scf] Allow runtime type of iter_args to change
Matthias Springer
llvmlistbot at llvm.org
Thu Sep 2 18:07:57 PDT 2021
Author: Matthias Springer
Date: 2021-09-03T10:03:05+09:00
New Revision: 4fa6c2734c484ad7299257b317b75f9bc5482b7c
URL: https://github.com/llvm/llvm-project/commit/4fa6c2734c484ad7299257b317b75f9bc5482b7c
DIFF: https://github.com/llvm/llvm-project/commit/4fa6c2734c484ad7299257b317b75f9bc5482b7c.diff
LOG: [mlir][scf] Allow runtime type of iter_args to change
The limitation on iter_args introduced with D108806 is too restrictive. Changes of the runtime type should be allowed.
Extends the dim op canonicalization with a simple analysis to determine when it is safe to canonicalize.
Differential Revision: https://reviews.llvm.org/D109125
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index d04266a70cddc..f63186b2c6ec3 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -147,17 +147,16 @@ def ForOp : SCF_Op<"for",
passed as additional SSA operands to the "scf.for" following the 3 loop
control SSA values mentioned above (lower bound, upper bound and step). The
operation region has an argument for the induction variable, followed by
- one argument for each loop-carried variable, representing he value of the
+ one argument for each loop-carried variable, representing the value of the
variable at the current iteration.
The region must terminate with a "scf.yield" that passes the current
- values of loop-carried variables to the next iteration, or to the "scf.for"
- result, if at the last iteration. The type (static or dynamic) of a
- loop-carried variable may not change with iterations. E.g., it is illegal
- to pass a tensor of larger size to the next iteration; even if the tensor's
- dimensions are dynamic (i.e., same static type). Note, that when the
- loop-carried variables are present, calling ForOp::build will not insert the
- terminator implicitly. The caller must insert "scf.yield" in that case.
+ values of all loop-carried variables to the next iteration, or to the
+ "scf.for" result, if at the last iteration. The static type of a
+ loop-carried variable may not change with iterations; its runtime type is
+ allowed to change. Note, that when the loop-carried variables are present,
+ calling ForOp::build will not insert the terminator implicitly. The caller
+ must insert "scf.yield" in that case.
"scf.for" results hold the final values after the last iteration.
For example, to sum-reduce a memref:
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 3e6fe03486881..3f2cc70bf7061 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::scf;
@@ -44,10 +45,44 @@ namespace {
/// ...
/// }
/// ```
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
+ /// A simple, conservative analysis to determine if the loop is shape
+ /// preserving, i.e., the runtime type of the arg-th yielded value is the same
+ /// as that of the corresponding basic block argument of the loop.
+ /// Note: This function handles only simple cases. Expand as needed.
+ static bool isShapePreserving(ForOp forOp, int64_t arg) {
+ auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
+ assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
+ "arg is out of bounds");
+ Value value = yieldOp.results()[arg];
+ while (value) {
+ if (value == forOp.getRegionIterArgs()[arg])
+ return true;
+ OpResult opResult = value.dyn_cast<OpResult>();
+ if (!opResult)
+ return false;
+
+ using tensor::InsertSliceOp;
+ value =
+ llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+ .template Case<InsertSliceOp>(
+ [&](InsertSliceOp op) { return op.dest(); })
+ .template Case<ForOp>([&](ForOp forOp) {
+ return isShapePreserving(forOp, opResult.getResultNumber())
+ ? forOp.getIterOperands()[opResult.getResultNumber()]
+ : Value();
+ })
+ .Default([&](auto op) { return Value(); });
+ }
+ return false;
+ }
+
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
@@ -56,6 +91,8 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
if (!forOp)
return failure();
+ if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
+ return failure();
Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
rewriter.updateRootInPlace(
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
index 9f80b31a2542e..60004d53e240e 100644
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -242,3 +242,74 @@ func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
}
return %1 : index
}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_insertslice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>,
+// CHECK: scf.for
+// CHECK: tensor.dim %[[t]]
+func @tensor_dim_of_iter_arg_insertslice(%t : tensor<?x?xf32>,
+ %t2 : tensor<?x?xf32>) -> index {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c10 = constant 10 : index
+ %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+ -> (tensor<?x?xf32>, index) {
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %2 = tensor.insert_slice %t2 into %arg0[0, 0] [10, 10] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ %3 = tensor.insert_slice %t2 into %2[1, 1] [10, 10] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ scf.yield %3, %dim : tensor<?x?xf32>, index
+ }
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_nested_for(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>,
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: tensor.dim %[[t]]
+func @tensor_dim_of_iter_arg_nested_for(%t : tensor<?x?xf32>,
+ %t2 : tensor<?x?xf32>) -> index {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c10 = constant 10 : index
+ %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+ -> (tensor<?x?xf32>, index) {
+ %2, %3 = scf.for %j = %c0 to %c10 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1)
+ -> (tensor<?x?xf32>, index) {
+ %dim = tensor.dim %arg2, %c0 : tensor<?x?xf32>
+ %4 = tensor.insert_slice %t2 into %arg2[0, 0] [10, 10] [1, 1]
+ : tensor<?x?xf32> into tensor<?x?xf32>
+ scf.yield %4, %dim : tensor<?x?xf32>, index
+ }
+ scf.yield %2, %3 : tensor<?x?xf32>, index
+ }
+ return %1 : index
+}
+
+// -----
+
+// A test case that should not canonicalize because the loop is not shape
+// preserving.
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_no_canonicalize(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>,
+// CHECK: scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[t]]
+// CHECK: tensor.dim %[[arg0]]
+func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
+ %t2 : tensor<?x?xf32>) -> index {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c10 = constant 10 : index
+ %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+ -> (tensor<?x?xf32>, index) {
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ scf.yield %t2, %dim : tensor<?x?xf32>, index
+ }
+ return %1 : index
+}
More information about the Mlir-commits
mailing list