[Mlir-commits] [mlir] [MLIR][Affine] Allow integer constant operands in affine symbol pure-op check (PR #188778)
Mehdi Amini
llvmlistbot at llvm.org
Thu Mar 26 08:59:36 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/188778
The `isValidSymbol` function requires values to be index-typed. When checking the operands of a `Pure` op (e.g., `arith.index_cast`), it recursively called `isValidSymbol` on each operand. This rejected non-index-typed operands such as `i64` constants, even though they are compile-time constants and the index value computed from them is a valid affine symbol.
For example, the following was previously rejected:
affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
  %bound = arith.constant 2 : i64
  %idx = arith.index_cast %bound : i64 to index
  affine.for %j = 0 to %idx { ... }  // error: operand cannot be used as a symbol
  ...
}
Introduce `isValidSymbolOrPureIntegerValue` as a helper that extends the symbol validity check to accept non-index integer values that are constants, top-level values, or results of pure ops with valid operands. This helper is used when checking operands of pure ops in `isValidSymbol`.
Fixes #129901
Assisted-by: Claude Code
From 08fc8783184e6a9749209d2cf88931f6ba49a30b Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 07:28:11 -0700
Subject: [PATCH] [MLIR][Affine] Allow integer constant operands in affine symbol pure-op check
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 43 ++++++++++++++++++++++--
mlir/test/Dialect/Affine/ops.mlir | 19 +++++++++++
2 files changed, 60 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 839d34b41cbd4..c3c5b5ce586ae 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -442,6 +442,43 @@ static bool isTopLevelValueOrAbove(Value value, Region *region) {
   return false;
 }

+/// Returns true if `value` is a valid operand to a `Pure` operation that
+/// produces a valid affine symbol. Unlike `isValidSymbol`, this check does not
+/// require the value to be index-typed; it also accepts integer-typed values
+/// that are constants, top-level values, or results of pure ops with valid
+/// operands. This allows operations like `arith.index_cast` applied to an
+/// integer constant to be recognized as valid affine symbols.
+static bool isValidSymbolOrPureIntegerValue(Value value, Region *region) {
+  // Index-typed values use the standard symbol check.
+  if (value.getType().isIndex())
+    return affine::isValidSymbol(value, region);
+
+  // Only integer-typed values are considered further.
+  if (!value.getType().isSignlessInteger() &&
+      !value.getType().isSignedInteger() &&
+      !value.getType().isUnsignedInteger())
+    return false;
+
+  // Top-level integer values are valid.
+  if (region && isTopLevelValueOrAbove(value, region))
+    return true;
+
+  auto *defOp = value.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // Constant integer is valid.
+  Attribute cst;
+  if (matchPattern(defOp, m_Constant(&cst)))
+    return true;
+
+  // Pure op whose operands are all valid.
+  return isPure(defOp) &&
+         llvm::all_of(defOp->getOperands(), [&](Value operand) {
+           return isValidSymbolOrPureIntegerValue(operand, region);
+         });
+}
+
 /// A value can be used as a symbol for `region` iff it meets one of the
 /// following conditions:
 /// *) It is a constant.
@@ -472,9 +509,11 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
   if (matchPattern(defOp, m_Constant(&operandCst)))
     return true;

-  // `Pure` operation that whose operands are valid symbolic identifiers.
+  // `Pure` operation whose operands are valid symbolic identifiers. Non-index
+  // operands (e.g., the integer operand of `arith.index_cast`) are accepted if
+  // they are constants, top-level values, or pure-op results on such values.
   if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
-        return affine::isValidSymbol(operand, region);
+        return isValidSymbolOrPureIntegerValue(operand, region);
       })) {
     return true;
   }
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 1562f5b1693c0..0dc41b588dff9 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -466,3 +466,22 @@ func.func @parallel_minnumf_reduce() {
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @affine_for_index_cast_symbol
+// Check that integer constants cast to index via arith.index_cast are valid
+// affine symbols even when not at the top level of the affine scope.
+func.func @affine_for_index_cast_symbol(%x: i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %result:1 = affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
+    // CHECK: affine.for
+    %sum = arith.addi %acc, %x : i64
+    %bound = arith.constant 2 : i64
+    %idx = arith.index_cast %bound : i64 to index
+    affine.for %j = 0 to %idx {
+    }
+    affine.yield %sum : i64
+  }
+  return %result : i64
+}