[Mlir-commits] [mlir] [MLIR][Affine] Allow integer constant operands in affine symbol pure-op check (PR #188778)

Mehdi Amini llvmlistbot at llvm.org
Wed Apr 1 01:58:58 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/188778

>From 2f7f948ac584cbc3ac71007683b12568414941fb Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 07:28:11 -0700
Subject: [PATCH] [MLIR][Affine] Allow integer constant operands in affine
 symbol pure-op check

The `isValidSymbol` function requires values to be index-typed. When
checking the operands of a `Pure` op (e.g., `arith.index_cast`), it
recursively called `isValidSymbol` on each operand. This rejected
non-index values such as `i64` constants, even though casting a
compile-time constant to `index` yields a valid affine symbol.

For example, the following was previously rejected:

  affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
    %bound = arith.constant 2 : i64
    %idx = arith.index_cast %bound : i64 to index
    affine.for %j = 0 to %idx { ... }  // error: operand cannot be used as a symbol
    ...
  }

Introduce `isValidAffineSymbolOperand`, a helper that `isValidSymbol`
now uses when checking the operands of pure ops. Index-typed operands
are still checked with `isValidSymbol`. Any other operand is accepted if
it is produced by a constant-like op (integer, float, etc.), or if it is
an integer value that is top-level or the result of a pure op whose
operands are themselves valid.

Fixes #129901

Assisted-by: Claude Code
---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 48 ++++++++++++++++++++++--
 mlir/test/Dialect/Affine/ops.mlir        | 40 ++++++++++++++++++++
 2 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 839d34b41cbd4..43ade54147c26 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -442,6 +442,45 @@ static bool isTopLevelValueOrAbove(Value value, Region *region) {
   return false;
 }
 
+/// Returns true if `value` can appear as an operand in a pure-op chain that
+/// ultimately produces a valid affine symbol for `region`:
+///   - index-typed values must satisfy `isValidSymbol`;
+///   - any result of a constant-like op (integer, float, etc.) is accepted;
+///   - other integer values are accepted if they are top-level for `region`
+///     or results of pure ops whose operands recursively satisfy this
+///     predicate. Everything else is rejected.
+static bool isValidAffineSymbolOperand(Value value, Region *region) {
+  // Index-typed values use the standard symbol check.
+  if (value.getType().isIndex())
+    return affine::isValidSymbol(value, region);
+
+  // Only `IntegerType` values (any signedness, not index) take part in the
+  // top-level and pure-op cases below; other types must be constants.
+  bool isIntegerTyped = isa<IntegerType>(value.getType());
+
+  // Top-level integer values are valid.
+  if (isIntegerTyped && region && isTopLevelValueOrAbove(value, region))
+    return true;
+
+  // Block arguments not accepted above (e.g. nested iter_args) are rejected.
+  auto *defOp = value.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // Any compile-time constant (integer, float, etc.) is valid.
+  if (defOp->hasTrait<OpTrait::ConstantLike>())
+    return true;
+
+  if (!isIntegerTyped)
+    return false;
+
+  // Pure op whose operands are all, recursively, valid.
+  return isPure(defOp) &&
+         llvm::all_of(defOp->getOperands(), [&](Value operand) {
+           return isValidAffineSymbolOperand(operand, region);
+         });
+}
+
 /// A value can be used as a symbol for `region` iff it meets one of the
 /// following conditions:
 /// *) It is a constant.
@@ -468,13 +507,14 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) {
     return false;
 
   // Constant operation is ok.
-  Attribute operandCst;
-  if (matchPattern(defOp, m_Constant(&operandCst)))
+  if (defOp->hasTrait<OpTrait::ConstantLike>())
     return true;
 
-  // `Pure` operation that whose operands are valid symbolic identifiers.
+  // `Pure` operation whose operands are valid symbolic identifiers. Non-index
+  // operands (e.g., the i64 operand of `arith.index_cast`) are accepted when
+  // they are constants, top-level integers, or valid pure-op integer results.
   if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
-        return affine::isValidSymbol(operand, region);
+        return isValidAffineSymbolOperand(operand, region);
       })) {
     return true;
   }
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 1562f5b1693c0..d59393a5143e5 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -466,3 +466,43 @@ func.func @parallel_minnumf_reduce() {
   return
 }
 
+
+// -----
+
+// CHECK-LABEL: func @affine_for_index_cast_symbol
+// Check that an integer constant cast to index via arith.index_cast is a valid
+// affine symbol even when it is not defined at the top level of the scope.
+func.func @affine_for_index_cast_symbol(%x: i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %result:1 = affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
+    // CHECK: affine.for %{{.*}} = 0 to %{{.*}} {
+    %sum = arith.addi %acc, %x : i64
+    %bound = arith.constant 2 : i64
+    %idx = arith.index_cast %bound : i64 to index
+    affine.for %j = 0 to %idx {
+    }
+    affine.yield %sum : i64
+  }
+  return %result : i64
+}
+
+// -----
+
+// CHECK-LABEL: func @affine_for_index_cast_fptosi_symbol
+// Check that a float constant converted to an integer via arith.fptosi and
+// then cast to index via arith.index_cast is a valid affine symbol.
+func.func @affine_for_index_cast_fptosi_symbol(%x: i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %result:1 = affine.for %i = 0 to 2 iter_args(%acc = %c0) -> (i64) {
+    // CHECK: affine.for
+    %sum = arith.addi %acc, %x : i64
+    %fbound = arith.constant 2.0 : f64
+    %ibound = arith.fptosi %fbound : f64 to i64
+    %idx = arith.index_cast %ibound : i64 to index
+    // CHECK: affine.for %{{.*}} = 0 to %{{.*}} {
+    affine.for %j = 0 to %idx {
+    }
+    affine.yield %sum : i64
+  }
+  return %result : i64
+}



More information about the Mlir-commits mailing list