[Mlir-commits] [mlir] Fix bug in `visitDivExpr` and `visitModExpr` (PR #145290)

Arnab Dutta llvmlistbot at llvm.org
Tue Jun 24 05:01:28 PDT 2025


https://github.com/arnab-polymage updated https://github.com/llvm/llvm-project/pull/145290

>From 75ab97da101a5e42b02ea834b9706fe3ecb5f06a Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Tue, 24 Jun 2025 17:30:52 +0530
Subject: [PATCH] Fix bug in `visitDivExpr` and `visitModExpr`

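Whenever the result of a semi-affine div or mod expression simplifies to a
constant, dimension, or symbol expression, write it directly into the
corresponding column (constant, dimension, or symbol) of the flattened
expression instead of adding it as a new local expression. As a result,
every local expression is now a binary op expression, so
`getSemiAffineExprFromFlatForm` uses `cast<AffineBinaryOpExpr>` instead of
silently skipping non-binary local expressions.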
---
 mlir/lib/IR/AffineExpr.cpp | 37 +++++++++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..3dcf3ac91d760 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1177,10 +1177,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
     if (flatExprs[numDims + numSymbols + it.index()] == 0)
       continue;
     AffineExpr expr = it.value();
-    auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
-    if (!binaryExpr)
-      continue;
-
+    // Local expression cannot be a dimension, symbol or a constant -- it
+    // should be a binary op expression.
+    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
     AffineExpr lhs = binaryExpr.getLHS();
     AffineExpr rhs = binaryExpr.getRHS();
     if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1347,21 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
+    if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constModExpr.getValue();
+      return success();
+    }
+    if (auto dimModExpr = dyn_cast<AffineDimExpr>(modExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getDimStartIndex() + dimModExpr.getPosition()] = 1;
+      return success();
+    }
+    if (auto symbolModExpr = dyn_cast<AffineSymbolExpr>(modExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getSymbolStartIndex() + symbolModExpr.getPosition()] = 1;
+      return success();
+    }
     return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
   }
 
@@ -1482,6 +1496,21 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+    if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constDivExpr.getValue();
+      return success();
+    }
+    if (auto dimDivExpr = dyn_cast<AffineDimExpr>(divExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getDimStartIndex() + dimDivExpr.getPosition()] = 1;
+      return success();
+    }
+    if (auto symbolDivExpr = dyn_cast<AffineSymbolExpr>(divExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getSymbolStartIndex() + symbolDivExpr.getPosition()] = 1;
+      return success();
+    }
     return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
   }
 



More information about the Mlir-commits mailing list