[Mlir-commits] [mlir] f8b2794 - [mlir][scf]: Add value bound between scf for loop yield and result (#123200)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 18 22:52:50 PST 2025


Author: Aviad Cohen
Date: 2025-01-19T08:52:46+02:00
New Revision: f8b27949a8c4fa8d8e15f9858e2ed38d7267f7dd

URL: https://github.com/llvm/llvm-project/commit/f8b27949a8c4fa8d8e15f9858e2ed38d7267f7dd
DIFF: https://github.com/llvm/llvm-project/commit/f8b27949a8c4fa8d8e15f9858e2ed38d7267f7dd.diff

LOG: [mlir][scf]: Add value bound between scf for loop yield and result (#123200)

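We can prove that:
%result == %init_arg + trip_count * (%yielded_value - %iter_arg),
where trip_count is ceildiv(ub - lb, step) (assuming lb <= ub). This
equality requires the per-iteration increment
(%yielded_value - %iter_arg) to be the same on every iteration.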

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index fbd236b648cb8a..8a27bf186d1c2a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -70,6 +70,22 @@ struct ForOpInterface
         cstr.bound(value) == cstr.getExpr(initArg);
       }
     }
+
+    if (dim.has_value() || isa<BlockArgument>(value))
+      return;
+
+    // `value` is result of `forOp`, we can prove that:
+    // %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
+    // Where trip_count is (ub - lb) / step.
+    AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
+    AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
+    AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
+    AffineExpr tripCountExpr =
+        AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step
+    AffineExpr oneIterAdvanceExpr =
+        cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg);
+    cstr.bound(value) ==
+        cstr.getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
   }
 
   void populateBoundsForIndexValue(Operation *op, Value value,

diff  --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index 6e0c16a9a2b33f..b48f38f592dc92 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -267,3 +267,74 @@ func.func @compare_scf_for(%a: index, %b: index, %c: index) {
   }
   return
 }
+
+// -----
+
+func.func @scf_for_result_infer() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %c0) -> index {
+    %2 = "test.some_use"() : () -> (i1)
+    %3 = scf.if %2 -> (index) {
+        %5 = arith.addi %arg, %c1 : index
+        scf.yield %5 : index
+    } else {
+        scf.yield %arg : index
+    }
+    scf.yield %3 : index
+  }
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %c10) {cmp = "LE"} : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @scf_for_result_infer_dynamic_init(%i : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg = %i) -> index {
+    %2 = "test.some_use"() : () -> (i1)
+    %3 = scf.if %2 -> (index) {
+        %5 = arith.addi %arg, %c1 : index
+        scf.yield %5 : index
+    } else {
+        scf.yield %arg : index
+    }
+    scf.yield %3 : index
+  }
+  %6 = arith.addi %i, %c10 : index
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @scf_for_result_infer_dynamic_init_big_step(%i : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %c5 = arith.constant 5 : index
+  %c10 = arith.constant 10 : index
+  %0 = scf.for %iv = %c0 to %c10 step %c2 iter_args(%arg = %i) -> index {
+    %2 = "test.some_use"() : () -> (i1)
+    %3 = scf.if %2 -> (index) {
+        %5 = arith.addi %arg, %c1 : index
+        scf.yield %5 : index
+    } else {
+        scf.yield %arg : index
+    }
+    scf.yield %3 : index
+  }
+  %6 = arith.addi %i, %c5 : index
+  %7 = arith.addi %i, %c4 : index
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %6) {cmp = "LE"} : (index, index) -> ()
+  // expected-error @below{{unknown}}
+  "test.compare"(%0, %7) {cmp = "LE"} : (index, index) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list