[Mlir-commits] [mlir] Fix bug in `visitDivExpr`, `visitModExpr` and `visitMulExpr` (PR #145290)
Arnab Dutta
llvmlistbot at llvm.org
Fri Jun 27 01:17:23 PDT 2025
https://github.com/arnab-polymage updated https://github.com/llvm/llvm-project/pull/145290
>From c2493e69eabdac76ae1dd8850ab469284f0dde5f Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Fri, 27 Jun 2025 13:46:55 +0530
Subject: [PATCH] Fix bug in `visitDivExpr`, `visitMulExpr` and `visitModExpr`
Whenever a semi-affine mul, div or mod expression simplifies to a
constant, a dimension or a symbol, place it at the corresponding
constant, dimension or symbol index of the flattened expression
instead of adding it as a local expression.
---
mlir/include/mlir/IR/AffineExprVisitor.h | 8 +++++
mlir/lib/IR/AffineExpr.cpp | 36 +++++++++++++++----
.../Dialect/Affine/simplify-structures.mlir | 31 ++++++++++++----
3 files changed, 61 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 1826f5fd8ad35..51c66483c90bf 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -418,6 +418,14 @@ class SimpleAffineExprFlattener
AffineExpr localExpr);
private:
+ /// Flatten `expr` and store it in `result`. If `expr` is a dimension, symbol
+ /// or constant, `result` is overwritten with just that term at its index.
+ /// Otherwise `expr` is added to the local variable section.
+ LogicalResult addExprToFlattenedList(ArrayRef<int64_t> lhs,
+ ArrayRef<int64_t> rhs,
+ SmallVectorImpl<int64_t> &result,
+ AffineExpr expr);
+
/// Adds `localExpr`, which may be mod, ceildiv, floordiv or mod expression
/// representing the affine expression corresponding to the quantifier
/// introduced as the local variable corresponding to `localExpr`. If the
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..b1a1396db2a0e 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -19,6 +19,7 @@
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <optional>
@@ -1177,10 +1178,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
AffineExpr expr = it.value();
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
- if (!binaryExpr)
- continue;
-
+ // A local expression cannot be a dimension, symbol or constant -- it
+ // must be a binary op expression.
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binaryExpr.getLHS();
AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1274,6 +1274,27 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
operandExprStack.reserve(8);
}
+LogicalResult SimpleAffineExprFlattener::addExprToFlattenedList(
+ ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
+ SmallVectorImpl<int64_t> &result, AffineExpr expr) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
+ std::fill(result.begin(), result.end(), 0);
+ result[getConstantIndex()] = constExpr.getValue();
+ return success();
+ }
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+ std::fill(result.begin(), result.end(), 0);
+ result[getDimStartIndex() + dimExpr.getPosition()] = 1;
+ return success();
+ }
+ if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
+ std::fill(result.begin(), result.end(), 0);
+ result[getSymbolStartIndex() + symExpr.getPosition()] = 1;
+ return success();
+ }
+ return addLocalVariableSemiAffine(lhs, rhs, expr, result, result.size());
+}
+
// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
//
// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
@@ -1295,7 +1316,8 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
localExprs, context);
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
- return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
+ AffineExpr mulExpr = a * b;
+ return addExprToFlattenedList(mulLhs, rhs, lhs, mulExpr);
}
// Get the RHS constant.
@@ -1348,7 +1370,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
- return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
+ return addExprToFlattenedList(modLhs, rhs, lhs, modExpr);
}
int64_t rhsConst = rhs[getConstantIndex()];
@@ -1482,7 +1504,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
- return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
+ return addExprToFlattenedList(divLhs, rhs, lhs, divExpr);
}
// This is a pure affine expr; the RHS is a positive constant.
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 6f2737a982752..190589ad86976 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -595,16 +595,33 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
// -----
-// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
-func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (13 mod s0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-LABEL: func @semiaffine_simplification_local_expr_folded_into_non_binary_expr
+func.func @semiaffine_simplification_local_expr_folded_into_non_binary_expr(%arg0: memref<?x?xf32>) -> (index, index, index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
%c13 = arith.constant 13 : index
- %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ // CHECK: %[[DIM:.*]] = memref.dim
+ %dim = memref.dim %arg0, %c0 : memref<?x?xf32>
+ // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+ // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
%a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
%b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
- // CHECK: %[[C6:.*]] = arith.constant 6 : index
- // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
- // CHECK-NEXT: return %[[C6]], %[[C7]]
- return %a, %b : index, index
+ // CHECK: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[DIM]]]
+ %c = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 mod (s1 + (-s1 + s3) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1, %dim]
+ %alloc = memref.alloc() : memref<1xindex>
+ affine.for %iv = 0 to 1 {
+ %d = affine.apply affine_map<(d0)[s1, s2] -> ((d0 - s1 + s1 * s2) * (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>(%iv)[%dim, %c1]
+ affine.store %d, %alloc[0] : memref<1xindex>
+ }
+ // CHECK: affine.for %[[IV:.*]] = 0 to 1 {
+ // CHECK-NEXT: %[[VAL:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+ // CHECK-NEXT: affine.store %[[VAL]], %{{.*}}[0] : memref<1xindex>
+ // CHECK-NEXT: }
+ // CHECK: %[[VAL1:.*]] = affine.load %{{.*}}[0]
+ %d = affine.load %alloc[0] : memref<1xindex>
+ // CHECK: return %[[C6]], %[[C7]], %[[VAL0]], %[[VAL1]]
+ return %a, %b, %c, %d : index, index, index, index
}
More information about the Mlir-commits
mailing list