[Mlir-commits] [mlir] [mlir][Intrange] Fix materializing ShapedType constant values (PR #158359)
Jeff Niu
llvmlistbot at llvm.org
Fri Sep 12 13:27:05 PDT 2025
https://github.com/Mogball created https://github.com/llvm/llvm-project/pull/158359
When materializing integer ranges of splat tensors or vectors as constants, the materialized constants should use a DenseElementsAttr of the shaped type, not an IntegerAttr of the element type, since the latter can violate the invariants of tensor/vector ops.
>From 2217053dc5591ae0a794d82db3146f08ad6e5aa8 Mon Sep 17 00:00:00 2001
From: Jeff Niu <jeffniu at openai.com>
Date: Fri, 12 Sep 2025 13:24:15 -0700
Subject: [PATCH] [mlir][Intrange] Fix materializing ShapedType constant values
When materializing integer ranges of splat tensors or vectors as
constants, they should use DenseElementsAttr of the shaped type, not
IntegerAttrs of the element types, since this can violate the invariants
of tensor/vector ops.
---
.../Analysis/DataFlow/IntegerRangeAnalysis.cpp | 15 ++++++++++++---
.../Arith/Transforms/IntRangeOptimizations.cpp | 2 ++
mlir/test/Dialect/Arith/int-range-opts.mlir | 16 ++++++++++++++++
3 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index e79f6a8aec1cf..70b56ca77b2da 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -26,6 +26,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
@@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
else
dialect = value.getParentBlock()->getParentOp()->getDialect();
- Type type = getElementTypeOrSelf(value);
- solver->propagateIfChanged(
- cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
+ Attribute cstAttr;
+ if (isa<IntegerType, IndexType>(value.getType())) {
+ cstAttr = IntegerAttr::get(value.getType(), *constant);
+ } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
+ cstAttr = SplatElementsAttr::get(shapedTy, *constant);
+ } else {
+ llvm::report_fatal_error(
+ Twine("FIXME: Don't know how to create a constant for this type: ") +
+ mlir::debugString(value.getType()));
+ }
+ solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
}
LogicalResult IntegerRangeAnalysis::visitOperation(
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..2017905587b26 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,6 +8,7 @@
#include <utility>
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
MLIRContext *ctx = op->getContext();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
solver.load<IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a100258..e6e48d30cece5 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}
+
+// -----
+
+// CHECK-LABEL: @analysis_crash
+func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
+ %c0_i32 = arith.constant 0 : i32
+ %cst = arith.constant dense<-1> : tensor<128xi32>
+ %splat = tensor.splat %arg0 : tensor<128xi32>
+ %0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 {
+ scf.yield %arg3 : tensor<128xi32>
+ }
+ %1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
+ // Make sure the analysis doesn't crash when materializing the range as a tensor constant.
+ %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
+ return %2 : tensor<128xi64>
+}
More information about the Mlir-commits
mailing list