[Mlir-commits] [mlir] Fix bug in `visitDivExpr`, `visitModExpr` and `visitMulExpr` (PR #145290)

Arnab Dutta llvmlistbot at llvm.org
Thu Jun 26 22:53:26 PDT 2025


https://github.com/arnab-polymage updated https://github.com/llvm/llvm-project/pull/145290

>From 5fc2915568fb31985dbed95587fd1042303b7a39 Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Fri, 27 Jun 2025 11:22:59 +0530
Subject: [PATCH] Fix bug in `visitDivExpr`, `visitMulExpr` and `visitModExpr`

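Whenever the result of a semi-affine mul, div or mod expression
simplifies to a constant, dimension or symbol expression, place it
directly in the corresponding constant, dimension or symbol position of
the flattened expression instead of adding it as a local expression.
As a result, a local expression is always a binary op expression, so
`getSemiAffineExprFromFlatForm` now uses `cast<AffineBinaryOpExpr>`
instead of silently skipping non-binary local expressions.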
---
 mlir/lib/IR/AffineExpr.cpp                    | 57 +++++++++++++++++--
 .../Dialect/Affine/simplify-structures.mlir   | 31 +++++++---
 2 files changed, 76 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..18fd2b72c6f9b 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
 #include <numeric>
 #include <optional>
 
@@ -1177,10 +1178,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
     if (flatExprs[numDims + numSymbols + it.index()] == 0)
       continue;
     AffineExpr expr = it.value();
-    auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
-    if (!binaryExpr)
-      continue;
-
+    // A Local expression cannot be a dimension, symbol or a constant -- it
+    // should be a binary op expression.
+    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
     AffineExpr lhs = binaryExpr.getLHS();
     AffineExpr rhs = binaryExpr.getRHS();
     if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1295,7 +1295,23 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
                                              localExprs, context);
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
-    return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
+    AffineExpr mulExpr = a * b;
+    if (auto constMulExpr = dyn_cast<AffineConstantExpr>(mulExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constMulExpr.getValue();
+      return success();
+    }
+    if (auto dimMulExpr = dyn_cast<AffineDimExpr>(mulExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getDimStartIndex() + dimMulExpr.getPosition()] = 1;
+      return success();
+    }
+    if (auto symbolMulExpr = dyn_cast<AffineSymbolExpr>(mulExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getSymbolStartIndex() + symbolMulExpr.getPosition()] = 1;
+      return success();
+    }
+    return addLocalVariableSemiAffine(mulLhs, rhs, mulExpr, lhs, lhs.size());
   }
 
   // Get the RHS constant.
@@ -1348,6 +1364,21 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
+    if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constModExpr.getValue();
+      return success();
+    }
+    if (auto dimModExpr = dyn_cast<AffineDimExpr>(modExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getDimStartIndex() + dimModExpr.getPosition()] = 1;
+      return success();
+    }
+    if (auto symbolModExpr = dyn_cast<AffineSymbolExpr>(modExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getSymbolStartIndex() + symbolModExpr.getPosition()] = 1;
+      return success();
+    }
     return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
   }
 
@@ -1482,6 +1513,21 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+    if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getConstantIndex()] = constDivExpr.getValue();
+      return success();
+    }
+    if (auto dimDivExpr = dyn_cast<AffineDimExpr>(divExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getDimStartIndex() + dimDivExpr.getPosition()] = 1;
+      return success();
+    }
+    if (auto symbolDivExpr = dyn_cast<AffineSymbolExpr>(divExpr)) {
+      std::fill(lhs.begin(), lhs.end(), 0);
+      lhs[getSymbolStartIndex() + symbolDivExpr.getPosition()] = 1;
+      return success();
+    }
     return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
   }
 
@@ -1574,6 +1620,7 @@ int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
                                     unsigned numSymbols) {
   // Simplify semi-affine expressions separately.
+  expr.dump();
   if (!expr.isPureAffine())
     expr = simplifySemiAffine(expr, numDims, numSymbols);
 
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 6f2737a982752..4ebf054a4dff3 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -595,16 +595,33 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
 
 // -----
 
-// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
-func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (13 mod s0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-LABEL: semiaffine_simplification_local_expr_folded_into_non_binary_expr
+func.func @semiaffine_simplification_local_expr_folded_into_non_binary_expr(%arg0: memref<?x?xf32>) -> (index, index, index, index) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
   %c13 = arith.constant 13 : index
-  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+  // CHECK: %[[DIM:.*]] = memref.dim
+  %dim = memref.dim %arg0, %c0 : memref<?x?xf32>
+  // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+  // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
   %a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
   %b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
-  // CHECK:      %[[C6:.*]] = arith.constant 6 : index
-  // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
-  // CHECK-NEXT: return %[[C6]], %[[C7]]
-  return %a, %b : index, index
+  // CHECK: %0 = affine.apply #[[$MAP]]()[%[[DIM]]]
+  %c = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 mod (s1 + (-s1 + s3) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1, %dim]
+  %alloc = memref.alloc() : memref<1xindex>
+  affine.for %iv = 0 to 1 {
+    %d = affine.apply affine_map<(d0)[s1, s2] -> ((d0 - s1 + s1 * s2) * (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>(%iv)[%dim, %c1]
+    affine.store %d, %alloc[0] : memref<1xindex>
+  }
+  // CHECK:      affine.for %[[IV:.*]] = 0 to 1 {
+  // CHECK-NEXT:   %[[VAL:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+  // CHECK-NEXT:   affine.store %[[VAL]], %{{.*}}[0] : memref<1xindex>
+  // CHECK-NEXT: }
+  // CHECK: %[[VAL1:.*]] = affine.load %{{.*}}[0]
+  %d = affine.load %alloc[0] : memref<1xindex>
+
+  return %a, %b, %c, %d : index, index, index, index
 }



More information about the Mlir-commits mailing list