[Mlir-commits] [mlir] d518fc2 - [mlir][scf] Support simple symbolic expression when simplify loops

Peiming Liu llvmlistbot at llvm.org
Mon Sep 19 14:50:10 PDT 2022


Author: Peiming Liu
Date: 2022-09-19T21:50:01Z
New Revision: d518fc28b6dce9c2ab5db8575d401e7eaa203612

URL: https://github.com/llvm/llvm-project/commit/d518fc28b6dce9c2ab5db8575d401e7eaa203612
DIFF: https://github.com/llvm/llvm-project/commit/d518fc28b6dce9c2ab5db8575d401e7eaa203612.diff

LOG: [mlir][scf] Support simple symbolic expression when simplify loops

Reviewed By: aartbik, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D134204

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/CMakeLists.txt
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 9dad732618de3..0b0714c05a2fd 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRSCFOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffineDialect
   MLIRArithmeticDialect
   MLIRBufferizationDialect
   MLIRControlFlowDialect

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 83a2381b94711..e9f89fa7d7483 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -715,6 +717,24 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
   }
 };
 
+/// Util function that tries to compute a constant diff between u and l.
+/// Returns llvm::None when the difference between two AffineValueMap is
+/// dynamic.
+static Optional<int64_t> computeConstDiff(Value l, Value u) {
+  auto alb = l.getDefiningOp<AffineApplyOp>();
+  auto aub = u.getDefiningOp<AffineApplyOp>();
+  // ID map: (d0)->d0
+  auto id = AffineMap::getMultiDimIdentityMap(1, l.getContext());
+  auto lb = alb ? alb.getAffineValueMap() : AffineValueMap(id, l);
+  auto ub = aub ? aub.getAffineValueMap() : AffineValueMap(id, u);
+
+  AffineValueMap diffMap;
+  AffineValueMap::difference(ub, lb, &diffMap);
+  if (auto constDiff = diffMap.getResult(0).dyn_cast<AffineConstantExpr>())
+    return constDiff.getValue();
+  return llvm::None;
+}
+
 /// Rewriting pattern that erases loops that are known not to iterate, replaces
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
@@ -730,15 +750,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
       return success();
     }
 
-    auto lb = op.getLowerBound().getDefiningOp<arith::ConstantOp>();
-    auto ub = op.getUpperBound().getDefiningOp<arith::ConstantOp>();
-    if (!lb || !ub)
+    Optional<int64_t> diff =
+        computeConstDiff(op.getLowerBound(), op.getUpperBound());
+    if (!diff)
       return failure();
 
     // If the loop is known to have 0 iterations, remove it.
-    llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
-    llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
-    if (lbValue.sge(ubValue)) {
+    if (*diff <= 0) {
       rewriter.replaceOp(op, op.getIterOperands());
       return success();
     }
@@ -750,7 +768,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
     // If the loop is known to have 1 iteration, inline its body and remove the
     // loop.
     llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
-    if ((lbValue + stepValue).sge(ubValue)) {
+    if (stepValue.sge(*diff)) {
       SmallVector<Value, 4> blockArgs;
       blockArgs.reserve(op.getNumIterOperands() + 1);
       blockArgs.push_back(op.getLowerBound());

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8e087fc0f38a4..b5616407d55c1 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -324,7 +324,7 @@ func.func @to_select_with_body(%cond: i1) -> index {
 // CHECK:           [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
 // CHECK:           scf.if {{.*}} {
 // CHECK:             "test.op"() : () -> ()
-// CHECK:           } 
+// CHECK:           }
 // CHECK:           return [[V0]] : index
 // -----
 
@@ -547,7 +547,7 @@ func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
       scf.yield %0, %1 : i32, f32
     }
   }
-  return 
+  return
 }
 
 // CHECK-LABEL: @merge_yielding_nested_if_nv2
@@ -557,7 +557,7 @@ func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
 // CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
 // CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
 // CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
-// CHECK: scf.if %[[COND]] 
+// CHECK: scf.if %[[COND]]
 // CHECK:   "test.run"() : () -> ()
 // CHECK: }
 // CHECK: return %[[RES]]
@@ -719,6 +719,27 @@ func.func @replace_single_iteration_loop_non_unit_step() {
   return
 }
 
+
+// -----
+
+#map = affine_map<(d0) -> (d0 + 1)>
+// CHECK-LABEL: func @replace_single_iteration_const_diff(
+//  CHECK-SAME: %[[A0:.*]]: index)
+func.func @replace_single_iteration_const_diff(%arg0 : index) {
+  // CHECK-NEXT: %[[CST:.*]] = arith.constant 2
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %5 = affine.apply #map(%arg0)
+  // CHECK-NOT: scf.for
+  scf.for %arg2 = %arg0 to %5 step %c1 {
+    // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[A0]], %[[CST]]
+    %7 = arith.muli %c2, %arg2 : index
+    // CHECK-NEXT: "test.consume"(%[[MUL]])
+    "test.consume"(%7) : (index) -> ()
+  }
+  return
+}
+
 // -----
 
 // CHECK-LABEL: @remove_empty_parallel_loop
@@ -986,7 +1007,7 @@ func.func @while_cond_true() -> i1 {
 // CHECK-NEXT:           %[[cmp:.+]] = "test.condition"() : () -> i1
 // CHECK-NEXT:           scf.condition(%[[cmp]]) %[[cmp]] : i1
 // CHECK-NEXT:         } do {
-// CHECK-NEXT:         ^bb0(%arg0: i1):  
+// CHECK-NEXT:         ^bb0(%arg0: i1):
 // CHECK-NEXT:           "test.use"(%[[true]]) : (i1) -> ()
 // CHECK-NEXT:           scf.yield
 // CHECK-NEXT:         }
@@ -1009,7 +1030,7 @@ func.func @while_unused_arg(%x : i32, %y : f64) -> i32 {
 // CHECK-NEXT:           %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1
 // CHECK-NEXT:           scf.condition(%[[cmp]]) %[[arg2]] : i32
 // CHECK-NEXT:         } do {
-// CHECK-NEXT:         ^bb0(%[[post:.+]]: i32):  
+// CHECK-NEXT:         ^bb0(%[[post:.+]]: i32):
 // CHECK-NEXT:           %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32
 // CHECK-NEXT:           scf.yield %[[next]] : i32
 // CHECK-NEXT:         }
@@ -1105,7 +1126,7 @@ func.func @while_unused_result() -> i32 {
 // CHECK-NEXT:           %{{.*}} = "test.get_some_value"() : () -> i64
 // CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
 // CHECK-NEXT:         } do {
-// CHECK-NEXT:         ^bb0(%[[arg:.*]]: i32):  
+// CHECK-NEXT:         ^bb0(%[[arg:.*]]: i32):
 // CHECK-NEXT:           "test.use"(%[[arg]]) : (i32) -> ()
 // CHECK-NEXT:           scf.yield
 // CHECK-NEXT:         }
@@ -1133,7 +1154,7 @@ func.func @while_cmp_lhs(%arg0 : i32) {
 // CHECK-NEXT:         %[[cmp:.+]] = arith.cmpi ne, %[[val]], %arg0 : i32
 // CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
 // CHECK-NEXT:         } do {
-// CHECK-NEXT:         ^bb0(%arg1: i32):  
+// CHECK-NEXT:         ^bb0(%arg1: i32):
 // CHECK-NEXT:           "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
 // CHECK-NEXT:           scf.yield
 // CHECK-NEXT:         }
@@ -1160,7 +1181,7 @@ func.func @while_cmp_rhs(%arg0 : i32) {
 // CHECK-NEXT:         %[[cmp:.+]] = arith.cmpi ne, %arg0, %[[val]] : i32
 // CHECK-NEXT:           scf.condition(%[[cmp]]) %[[val]] : i32
 // CHECK-NEXT:         } do {
-// CHECK-NEXT:         ^bb0(%arg1: i32):  
+// CHECK-NEXT:         ^bb0(%arg1: i32):
 // CHECK-NEXT:           "test.use"(%[[true]], %[[false]], %arg1) : (i1, i1, i32) -> ()
 // CHECK-NEXT:           scf.yield
 // CHECK-NEXT:         }

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f7af617c4a4d3..e502a2842f2d4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2678,6 +2678,7 @@ cc_library(
     ),
     includes = ["include"],
     deps = [
+        ":AffineDialect",
         ":ArithmeticDialect",
         ":ArithmeticUtils",
         ":BufferizationDialect",


        


More information about the Mlir-commits mailing list