[Mlir-commits] [mlir] [mlir][affine] Set overflow flags when lowering [de]linearize_index (PR #139612)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 12 12:53:39 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-affine
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
By analogy to some changes to the affine.apply lowering which put `nsw`s on various multiplications, add appropriate overflow flags to the multiplications and additions that are emitted when lowering affine.delinearize_index and affine.linearize_index to arith ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/139612.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+8)
- (modified) mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp (+28-10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dfe2a57df587..19fbcf64b2360 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1113,6 +1113,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
+
+ As with other affine operations, lowerings of delinearize_index may assume
+ that the underlying computations do not overflow the index type in a signed sense
+ - that is, the product of all basis elements is positive as an `index` as well.
}];
let arguments = (ins Index:$linear_index,
@@ -1195,9 +1199,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
+ In addition, `disjoint` is an assertion that all basis elements are non-negative.
Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
+ As with other affine ops, undefined behavior occurs if the linearization
+ computation overflows in the signed sense.
+
Example:
```mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 7e335ea929c4f..35205a6ca2eee 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,10 +35,13 @@ using namespace mlir::affine;
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
-/// needing to manipulate the dynamic value array.
+/// needing to manipulate the dynamic value array. `knownNonNegative`
+/// indicates that the values being used to compute the strides are known
+/// to be non-negative.
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
ValueRange dynamicBasis,
- ArrayRef<int64_t> staticBasis) {
+ ArrayRef<int64_t> staticBasis,
+ bool knownNonNegative) {
if (staticBasis.empty())
return {};
@@ -47,11 +50,18 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
size_t dynamicIndex = dynamicBasis.size();
Value dynamicPart = nullptr;
int64_t staticPart = 1;
+ // The products of the strides can't have overflow by definition of
+ // affine.*_index.
+ arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
+ if (knownNonNegative)
+ ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
for (int64_t elem : llvm::reverse(staticBasis)) {
if (ShapedType::isDynamic(elem)) {
+ // Note: basis elements and their products are, definitionally,
+ // non-negative, so `nuw` is justified.
if (dynamicPart)
dynamicPart = rewriter.create<arith::MulIOp>(
- loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
+ loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
else
dynamicPart = dynamicBasis[dynamicIndex - 1];
--dynamicIndex;
@@ -65,7 +75,8 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
Value stride =
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
if (dynamicPart)
- stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
+ stride =
+ rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
result.push_back(stride);
}
}
@@ -96,7 +107,8 @@ struct LowerDelinearizeIndexOps
SmallVector<Value> results;
results.reserve(numResults);
SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/true);
Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
@@ -108,7 +120,11 @@ struct LowerDelinearizeIndexOps
Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
Value remainderNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, remainder, zero);
- Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
+ // If the correction is relevant, this term is <= stride, which is known
+ // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+ // this branch won't be taken, so the risk of `poison` is fine.
+ Value corrected = rewriter.create<arith::AddIOp>(
+ loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
corrected, remainder);
return mod;
@@ -155,7 +171,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
staticBasis = staticBasis.drop_front();
SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/op.getDisjoint());
SmallVector<std::pair<Value, int64_t>> scaledValues;
scaledValues.reserve(numIndexes);
@@ -164,8 +181,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
// our hands on an `OpOperand&` for the loop invariant counting function.
for (auto [stride, idxOp] :
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
- Value scaledIdx =
- rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
+ Value scaledIdx = rewriter.create<arith::MulIOp>(
+ loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
}
@@ -182,7 +199,8 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
for (auto [scaledValue, numHoistableLoops] :
llvm::drop_begin(scaledValues)) {
std::ignore = numHoistableLoops;
- result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
+ result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+ arith::IntegerOverflowFlags::nsw);
}
rewriter.replaceOp(op, result);
return success();
``````````
</details>
https://github.com/llvm/llvm-project/pull/139612
More information about the Mlir-commits
mailing list