[Mlir-commits] [mlir] [MLIR][Affine] Allow integer constant operands in affine symbol pure-op check (PR #188778)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 26 09:00:16 PDT 2026


llvmbot wrote:


@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

The `isValidSymbol` function requires values to be index-typed. When checking the operands of a `Pure` op (e.g., `arith.index_cast`), it recursively called `isValidSymbol` on each operand. This rejected non-index-typed operands such as `i64` constants, even though they are compile-time constants and the index value computed from them is therefore a valid affine symbol.

For example, the following was previously rejected:

  affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
    %bound = arith.constant 2 : i64
    %idx = arith.index_cast %bound : i64 to index
    affine.for %j = 0 to %idx { ... }  // error: operand cannot be used as a symbol
    ...
  }

Introduce `isValidSymbolOrPureIntegerValue`, a helper that extends the symbol validity check to also accept non-index integer values that are constants, top-level values, or results of pure ops whose operands are themselves valid. `isValidSymbol` now uses this helper when checking the operands of pure ops.
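
To make the mutual recursion easier to follow, here is a minimal standalone sketch of the rule described above. It is a toy model, not the MLIR implementation: `Ty`, `Node`, and the `strictOperands` flag are hypothetical stand-ins invented for illustration, and the model covers only the cases mentioned in this description (index type, top-level values, constants, pure ops). Any other rules that the real `isValidSymbol` applies are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-ins for MLIR concepts; none of these are MLIR APIs.
enum class Ty { Index, Integer, Other };

struct Node {
  explicit Node(Ty t) : type(t) {}
  Ty type;
  bool isTopLevel = false;    // models isTopLevelValueOrAbove(value, region)
  bool hasDefiningOp = true;  // false models a block argument
  bool isConstant = false;    // models matchPattern(defOp, m_Constant(...))
  bool isPure = false;        // models isPure(defOp)
  std::vector<const Node *> operands;  // operands of the defining op
};

// strictOperands == true models the pre-patch behavior, where every operand
// of a pure op had to be a valid (index-typed) symbol itself.
bool isValidSymbol(const Node *v, bool strictOperands);

// Models the new helper: index-typed values defer to the symbol check, while
// integer-typed values are accepted if they are top-level, constant, or
// produced by a pure op whose operands are all accepted in turn.
bool isValidSymbolOrPureIntegerValue(const Node *v) {
  assert(v && "expected a non-null value");
  if (v->type == Ty::Index)
    return isValidSymbol(v, /*strictOperands=*/false);
  if (v->type != Ty::Integer)
    return false;
  if (v->isTopLevel)
    return true;
  if (!v->hasDefiningOp)
    return false;
  if (v->isConstant)
    return true;
  return v->isPure &&
         std::all_of(v->operands.begin(), v->operands.end(),
                     [](const Node *op) {
                       return isValidSymbolOrPureIntegerValue(op);
                     });
}

bool isValidSymbol(const Node *v, bool strictOperands) {
  assert(v && "expected a non-null value");
  if (v->type != Ty::Index)
    return false;
  if (v->isTopLevel)
    return true;
  if (!v->hasDefiningOp)
    return false;
  if (v->isConstant)
    return true;
  return v->isPure &&
         std::all_of(v->operands.begin(), v->operands.end(),
                     [&](const Node *op) {
                       return strictOperands
                                  ? isValidSymbol(op, true)
                                  : isValidSymbolOrPureIntegerValue(op);
                     });
}
```

In this model, an index-typed node produced by a pure op with a single constant `Integer` operand (the shape of the `arith.index_cast` example above) is rejected with `strictOperands` set to `true` and accepted with it set to `false`. An integer block argument that is not top-level is still rejected in both modes.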

Fixes #129901

Assisted-by: Claude Code

---
Full diff: https://github.com/llvm/llvm-project/pull/188778.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+41-2) 
- (modified) mlir/test/Dialect/Affine/ops.mlir (+19) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 839d34b41cbd4..c3c5b5ce586ae 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -442,6 +442,43 @@ static bool isTopLevelValueOrAbove(Value value, Region *region) {
   return false;
 }
 
+/// Returns true if `value` is a valid operand to a `Pure` operation that
+/// produces a valid affine symbol. Unlike `isValidSymbol`, this check does not
+/// require the value to be index-typed; it also accepts integer-typed values
+/// that are constants, top-level values, or results of pure ops with valid
+/// operands. This allows operations like `arith.index_cast` applied to an
+/// integer constant to be recognized as valid affine symbols.
+static bool isValidSymbolOrPureIntegerValue(Value value, Region *region) {
+  // Index-typed values use the standard symbol check.
+  if (value.getType().isIndex())
+    return affine::isValidSymbol(value, region);
+
+  // Only integer-typed values are considered further.
+  if (!value.getType().isSignlessInteger() &&
+      !value.getType().isSignedInteger() &&
+      !value.getType().isUnsignedInteger())
+    return false;
+
+  // Top-level integer values are valid.
+  if (region && isTopLevelValueOrAbove(value, region))
+    return true;
+
+  auto *defOp = value.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // Constant integer is valid.
+  Attribute cst;
+  if (matchPattern(defOp, m_Constant(&cst)))
+    return true;
+
+  // Pure op whose operands are all valid.
+  return isPure(defOp) &&
+         llvm::all_of(defOp->getOperands(), [&](Value operand) {
+           return isValidSymbolOrPureIntegerValue(operand, region);
+         });
+}
+
 /// A value can be used as a symbol for `region` iff it meets one of the
 /// following conditions:
 /// *) It is a constant.
@@ -472,9 +509,11 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
   if (matchPattern(defOp, m_Constant(&operandCst)))
     return true;
 
-  // `Pure` operation that whose operands are valid symbolic identifiers.
+  // `Pure` operation whose operands are valid symbolic identifiers. Non-index
+  // typed operands (e.g., integer operands to `arith.index_cast`) are also
+  // accepted if they are constants or top-level values.
   if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
-        return affine::isValidSymbol(operand, region);
+        return isValidSymbolOrPureIntegerValue(operand, region);
       })) {
     return true;
   }
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 1562f5b1693c0..0dc41b588dff9 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -466,3 +466,22 @@ func.func @parallel_minnumf_reduce() {
   return
 }
 
+
+// -----
+
+// CHECK-LABEL: func @affine_for_index_cast_symbol
+// Check that integer constants cast to index via arith.index_cast are valid
+// affine symbols even when not at the top level of the affine scope.
+func.func @affine_for_index_cast_symbol(%x: i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %result:1 = affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
+    // CHECK: affine.for
+    %sum = arith.addi %acc, %x : i64
+    %bound = arith.constant 2 : i64
+    %idx = arith.index_cast %bound : i64 to index
+    affine.for %j = 0 to %idx {
+    }
+    affine.yield %sum : i64
+  }
+  return %result : i64
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/188778

