[Mlir-commits] [mlir] [MLIR][Affine] Allow integer constant operands in affine symbol pure-op check (PR #188778)
Mehdi Amini
llvmlistbot at llvm.org
Wed Apr 1 01:58:58 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/188778
>From 2f7f948ac584cbc3ac71007683b12568414941fb Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 07:28:11 -0700
Subject: [PATCH] [MLIR][Affine] Allow integer constant operands in affine
symbol pure-op check
The `isValidSymbol` function requires values to be index-typed. When
checking operands of a `Pure` op (e.g., `arith.index_cast`) it
recursively called `isValidSymbol` on each operand. This rejected
non-index typed values like `i64` constants even though they are
compile-time constants that make the resulting index value a valid
affine symbol.
For example, the following was previously rejected:
affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
%bound = arith.constant 2 : i64
%idx = arith.index_cast %bound : i64 to index
affine.for %j = 0 to %idx { ... } // error: operand cannot be used as a symbol
...
}
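Introduce `isValidAffineSymbolOperand` as a helper that extends the
symbol validity check to non-index operands: any constant-like value
(integer, float, etc.) is accepted, and non-index integer values are also
accepted when they are top-level values or results of pure ops whose
operands are themselves valid. This helper is used when checking operands
of pure ops in `isValidSymbol`. The constant check in `isValidSymbol` now
tests the `ConstantLike` trait directly instead of matching with `m_Constant`.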
Fixes #129901
Assisted-by: Claude Code
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 48 ++++++++++++++++++++++--
mlir/test/Dialect/Affine/ops.mlir | 40 ++++++++++++++++++++
2 files changed, 84 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 839d34b41cbd4..43ade54147c26 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -442,6 +442,45 @@ static bool isTopLevelValueOrAbove(Value value, Region *region) {
return false;
}
+/// Returns true if `value` can appear as an operand in a pure-op chain that
+/// ultimately produces a valid affine symbol. Index-typed values delegate to
+/// `isValidSymbol`. Any compile-time constant (integer, float, etc.) is
+/// accepted. Non-constant, non-index values are further restricted to integers
+/// that are either top-level or results of pure ops with valid operands.
+static bool isValidAffineSymbolOperand(Value value, Region *region) {
+  // Index-typed values use the standard symbol check.
+  if (value.getType().isIndex())
+    return affine::isValidSymbol(value, region);
+
+  // Only integer-typed non-index values are considered for the top-level and
+  // pure-op recursion cases. Float (and other) types must bottom out in
+  // constants below. `IntegerType` covers signless, signed and unsigned
+  // integers alike; `index` is a separate builtin type handled above.
+  const bool isIntegerTyped = isa<IntegerType>(value.getType());
+
+  // Top-level integer values are valid. When `region` is null this rule is
+  // skipped and only the constant / pure-op rules below can apply.
+  if (isIntegerTyped && region && isTopLevelValueOrAbove(value, region))
+    return true;
+
+  auto *defOp = value.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // Any compile-time constant (integer, float, etc.) is valid.
+  if (defOp->hasTrait<OpTrait::ConstantLike>())
+    return true;
+
+  if (!isIntegerTyped)
+    return false;
+
+  // Pure op whose operands are all valid (checked recursively).
+  return isPure(defOp) &&
+         llvm::all_of(defOp->getOperands(), [&](Value operand) {
+           return isValidAffineSymbolOperand(operand, region);
+         });
+}
+
/// A value can be used as a symbol for `region` iff it meets one of the
/// following conditions:
/// *) It is a constant.
@@ -468,13 +507,14 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
return false;
// Constant operation is ok.
- Attribute operandCst;
- if (matchPattern(defOp, m_Constant(&operandCst)))
+ if (defOp->hasTrait<OpTrait::ConstantLike>())
return true;
- // `Pure` operation that whose operands are valid symbolic identifiers.
+ // `Pure` operation whose operands are valid symbolic identifiers. Non-index
+ // typed operands (e.g., integer operands to `arith.index_cast`) are also
+ // accepted if they are constants or top-level values.
if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
- return affine::isValidSymbol(operand, region);
+ return isValidAffineSymbolOperand(operand, region);
})) {
return true;
}
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 1562f5b1693c0..d59393a5143e5 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -466,3 +466,43 @@ func.func @parallel_minnumf_reduce() {
return
}
+
+// -----
+
+// CHECK-LABEL: func @affine_for_index_cast_symbol
+// Check that integer constants cast to index via arith.index_cast are valid
+// affine symbols even when not at the top level of the affine scope.
+func.func @affine_for_index_cast_symbol(%x: i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %result:1 = affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
+    %sum = arith.addi %acc, %x : i64
+    %bound = arith.constant 2 : i64
+    %idx = arith.index_cast %bound : i64 to index
+    // CHECK: affine.for %{{.*}} = 0 to %{{.*}} {
+    affine.for %j = 0 to %idx {
+    }
+    affine.yield %sum : i64
+  }
+  return %result : i64
+}
+
+// -----
+
+// CHECK-LABEL: func @affine_for_index_cast_fptosi_symbol
+// Check that a float constant converted to integer via fptosi and then cast to
+// index via arith.index_cast is a valid affine symbol.
+func.func @affine_for_index_cast_fptosi_symbol(%x: i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %result:1 = affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
+    %sum = arith.addi %acc, %x : i64
+    %fbound = arith.constant 2.0 : f64
+    %ibound = arith.fptosi %fbound : f64 to i64
+    // CHECK: %[[IDX:.*]] = arith.index_cast %{{.*}} : i64 to index
+    %idx = arith.index_cast %ibound : i64 to index
+    // CHECK: affine.for %{{.*}} = 0 to %[[IDX]] {
+    affine.for %j = 0 to %idx {
+    }
+    affine.yield %sum : i64
+  }
+  return %result : i64
+}
More information about the Mlir-commits
mailing list