[Mlir-commits] [mlir] d518fc2 - [mlir][scf] Support simple symbolic expression when simplify loops
Peiming Liu
llvmlistbot at llvm.org
Mon Sep 19 14:50:10 PDT 2022
Author: Peiming Liu
Date: 2022-09-19T21:50:01Z
New Revision: d518fc28b6dce9c2ab5db8575d401e7eaa203612
URL: https://github.com/llvm/llvm-project/commit/d518fc28b6dce9c2ab5db8575d401e7eaa203612
DIFF: https://github.com/llvm/llvm-project/commit/d518fc28b6dce9c2ab5db8575d401e7eaa203612.diff
LOG: [mlir][scf] Support simple symbolic expression when simplify loops
Reviewed By: aartbik, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D134204
Added:
Modified:
mlir/lib/Dialect/SCF/IR/CMakeLists.txt
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 9dad732618de3..0b0714c05a2fd 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRSCFOpsIncGen
LINK_LIBS PUBLIC
+ MLIRAffineDialect
MLIRArithmeticDialect
MLIRBufferizationDialect
MLIRControlFlowDialect
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 83a2381b94711..e9f89fa7d7483 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -715,6 +717,24 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
}
};
+/// Util function that tries to compute a constant diff between u and l.
+/// Returns llvm::None when the difference between two AffineValueMap is
+/// dynamic.
+static Optional<int64_t> computeConstDiff(Value l, Value u) {
+ auto alb = l.getDefiningOp<AffineApplyOp>();
+ auto aub = u.getDefiningOp<AffineApplyOp>();
+ // ID map: (d0)->d0
+ auto id = AffineMap::getMultiDimIdentityMap(1, l.getContext());
+ auto lb = alb ? alb.getAffineValueMap() : AffineValueMap(id, l);
+ auto ub = aub ? aub.getAffineValueMap() : AffineValueMap(id, u);
+
+ AffineValueMap diffMap;
+ AffineValueMap::difference(ub, lb, &diffMap);
+ if (auto constDiff = diffMap.getResult(0).dyn_cast<AffineConstantExpr>())
+ return constDiff.getValue();
+ return llvm::None;
+}
+
/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
@@ -730,15 +750,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return success();
}
- auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
- auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
- if (!lb || !ub)
+ Optional<int64_t> diff =
+ computeConstDiff(op.getLowerBound(), op.getUpperBound());
+ if (!diff)
return failure();
// If the loop is known to have 0 iterations, remove it.
- llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
- llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
- if (lbValue.sge(ubValue)) {
+ if (*diff <= 0) {
rewriter.replaceOp(op, op.getIterOperands());
return success();
}
@@ -750,7 +768,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
// If the loop is known to have 1 iteration, inline its body and remove the
// loop.
llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
- if ((lbValue + stepValue).sge(ubValue)) {
+ if (stepValue.sge(*diff)) {
SmallVector<Value, 4> blockArgs;
blockArgs.reserve(op.getNumIterOperands() + 1);
blockArgs.push_back(op.getLowerBound());
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8e087fc0f38a4..b5616407d55c1 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -324,7 +324,7 @@ func.func @to_select_with_body(%cond: i1) -> index {
// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
// CHECK: scf.if {{.*}} {
// CHECK: "test.op"() : () -> ()
-// CHECK: }
+// CHECK: }
// CHECK: return [[V0]] : index
// -----
@@ -547,7 +547,7 @@ func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
scf.yield %0, %1 : i32, f32
}
}
- return
+ return
}
// CHECK-LABEL: @merge_yielding_nested_if_nv2
@@ -557,7 +557,7 @@ func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
-// CHECK: scf.if %[[COND]]
+// CHECK: scf.if %[[COND]]
// CHECK: "test.run"() : () -> ()
// CHECK: }
// CHECK: return %[[RES]]
@@ -719,6 +719,27 @@ func.func @replace_single_iteration_loop_non_unit_step() {
return
}
+
+// -----
+
+#map = affine_map<(d0) -> (d0 + 1)>
+// CHECK-LABEL: func @replace_single_iteration_const_diff(
+// CHECK-SAME: %[[A0:.*]]: index)
+func.func @replace_single_iteration_const_diff(%arg0 : index) {
+ // CHECK-NEXT: %[[CST:.*]] = arith.constant 2
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %5 = affine.apply #map(%arg0)
+ // CHECK-NOT: scf.for
+ scf.for %arg2 = %arg0 to %5 step %c1 {
+ // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[A0]], %[[CST]]
+ %7 = arith.muli %c2, %arg2 : index
+ // CHECK-NEXT: "test.consume"(%[[MUL]])
+ "test.consume"(%7) : (index) -> ()
+ }
+ return
+}
+
// -----
// CHECK-LABEL: @remove_empty_parallel_loop
@@ -986,7 +1007,7 @@ func.func @while_cond_true() -> i1 {
// CHECK-NEXT: %[[cmp:.+]] = "test.condition"() : () -> i1
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[cmp]] : i1
// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%arg0: i1):
+// CHECK-NEXT: ^bb0(%arg0: i1):
// CHECK-NEXT: "test.use"(%[[true]]) : (i1) -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
@@ -1009,7 +1030,7 @@ func.func @while_unused_arg(%x : i32, %y : f64) -> i32 {
// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32
// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[post:.+]]: i32):
+// CHECK-NEXT: ^bb0(%[[post:.+]]: i32):
// CHECK-NEXT: %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
// CHECK-NEXT: scf.yield %[[next]] : i32
// CHECK-NEXT: }
@@ -1105,7 +1126,7 @@ func.func @while_unused_result() -> i32 {
// CHECK-NEXT: %{{.*}} = "test.get_some_value"() : () -> i64
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[arg:.*]]: i32):
+// CHECK-NEXT: ^bb0(%[[arg:.*]]: i32):
// CHECK-NEXT: "test.use"(%[[arg]]) : (i32) -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
@@ -1133,7 +1154,7 @@ func.func @while_cmp_lhs(%arg0 : i32) {
// CHECK-NEXT: %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%arg1: i32):
+// CHECK-NEXT: ^bb0(%arg1: i32):
// CHECK-NEXT: "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
@@ -1160,7 +1181,7 @@ func.func @while_cmp_rhs(%arg0 : i32) {
// CHECK-NEXT: %[[cmp:.+]] = arith.cmpi ne, %arg0, %[[val]] : i32
// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%arg1: i32):
+// CHECK-NEXT: ^bb0(%arg1: i32):
// CHECK-NEXT: "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f7af617c4a4d3..e502a2842f2d4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2678,6 +2678,7 @@ cc_library(
),
includes = ["include"],
deps = [
+ ":AffineDialect",
":ArithmeticDialect",
":ArithmeticUtils",
":BufferizationDialect",
More information about the Mlir-commits mailing list