[Mlir-commits] [mlir] Fix bug in `visitDivExpr` and `visitModExpr` (PR #145290)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 23 01:37:07 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Arnab Dutta (arnab-polymage)
<details>
<summary>Changes</summary>
Whenever the result of a div or mod affine expression simplifies to a constant expression, place its value in the constant-term position of the flattened expression instead of adding it as a local expression.
---
Full diff: https://github.com/llvm/llvm-project/pull/145290.diff
1 File Affected:
- (modified) mlir/lib/IR/AffineExpr.cpp (+14-4)
``````````diff
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..feedef46c66b8 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
AffineExpr expr = it.value();
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
- if (!binaryExpr)
- continue;
-
+ assert(isa<AffineBinaryOpExpr>(expr) &&
+ "local expression cannot be a dimension, symbol or a constant -- it "
+ "should be a binary op expression");
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binaryExpr.getLHS();
AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
+ if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constModExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
}
@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+ if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constDivExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/145290
More information about the Mlir-commits mailing list