[Mlir-commits] [mlir] [mlir][AffineExpr] Order arguments in the commutative affine exprs (PR #146895)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 3 07:09:41 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
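Canonicalize the operand order of commutative affine expressions (add and mul): dims and symbols are ordered by position, dims are placed before symbols, and constants are moved to the right. This helps affine simplification; for example, `s0 * s1 - s1 * s0` now folds to `0` because both products unique to the same expression.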
---
Full diff: https://github.com/llvm/llvm-project/pull/146895.diff
4 Files Affected:
- (modified) mlir/lib/IR/AffineExpr.cpp (+36-2)
- (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (+2-2)
- (modified) mlir/test/IR/affine-map.mlir (+1-1)
- (modified) mlir/unittests/IR/AffineExprTest.cpp (+23)
``````````diff
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..856b29125602d 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -784,6 +784,36 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
return nullptr;
}
+/// Returns the operands of a commutative binary affine expression (add or
+/// mul) in a canonical order, so that callers building `a op b` and `b op a`
+/// produce the same operand pair for the cases covered below:
+///   * two dims, or two symbols, are ordered by increasing position;
+///   * a dim is placed before a symbol;
+///   * a constant is placed after a non-constant operand.
+/// Any other combination is returned in its original order.
+static std::pair<AffineExpr, AffineExpr>
+orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) {
+  // Decides whether `first` must be placed after `second`.
+  auto mustSwap = [](AffineExpr first, AffineExpr second) {
+    auto firstDim = dyn_cast<AffineDimExpr>(first);
+    auto secondDim = dyn_cast<AffineDimExpr>(second);
+    auto firstSym = dyn_cast<AffineSymbolExpr>(first);
+    auto secondSym = dyn_cast<AffineSymbolExpr>(second);
+
+    // Same kind of identifier: the lower position goes first.
+    if (firstSym && secondSym)
+      return firstSym.getPosition() >= secondSym.getPosition();
+    if (firstDim && secondDim)
+      return firstDim.getPosition() >= secondDim.getPosition();
+
+    // Dims precede symbols.
+    if (firstDim && secondSym)
+      return false;
+    if (firstSym && secondDim)
+      return true;
+
+    // Constants trail any operand that is not itself a constant.
+    return isa<AffineConstantExpr>(first) && !isa<AffineConstantExpr>(second);
+  };
+
+  if (mustSwap(expr1, expr2))
+    return {expr2, expr1};
+  return {expr1, expr2};
+}
+
AffineExpr AffineExpr::operator+(int64_t v) const {
return *this + getAffineConstantExpr(v, getContext());
}
@@ -791,9 +821,11 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
if (auto simplified = simplifyAdd(*this, other))
return simplified;
+ // Addition is commutative: canonicalize the operand order before uniquing
+ // so that `a + b` and `b + a` map to the same storage where the ordering
+ // rules of orderCommutativeArgs apply.
+ auto [lhs, rhs] = orderCommutativeArgs(*this, other);
+
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
- /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
+ /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), lhs, rhs);
}
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
@@ -856,9 +888,11 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
if (auto simplified = simplifyMul(*this, other))
return simplified;
+ // Multiplication is commutative: canonicalize the operand order before
+ // uniquing so that `a * b` and `b * a` map to the same storage where the
+ // ordering rules of orderCommutativeArgs apply.
+ auto [lhs, rhs] = orderCommutativeArgs(*this, other);
+
StorageUniquer &uniquer = getContext()->getAffineUniquer();
return uniquer.get<AffineBinaryOpExprStorage>(
- /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
+ /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), lhs, rhs);
}
// Unary minus, delegate to operator*.
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 6f2737a982752..653c2cb521637 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -508,7 +508,7 @@ func.func @test_not_trivially_true_or_false_returning_three_results() -> (index,
// -----
// Test simplification of mod expressions.
-// CHECK-DAG: #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)>
+// CHECK-DAG: #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s4 + s3 + (s0 - s1) mod s2)>
// CHECK-DAG: #[[$SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))>
// CHECK-DAG: #[[$MODULO_AND_PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s3 - (-s0 + s3) mod s2)>
// CHECK-LABEL: func @semiaffine_simplification_mod
@@ -547,7 +547,7 @@ func.func @semiaffine_simplification_floordiv_and_ceildiv(%arg0: index, %arg1: i
// Test simplification of product expressions.
// CHECK-DAG: #[[$PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)>
-// CHECK-DAG: #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 + s2 * s0 + s3 + s3 * s0 + s3 * s1 + s4 + s4 * s1)>
+// CHECK-DAG: #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s2 + s0 * s3 + s1 * s3 + s1 * s4 + s2 + s3 + s4)>
// CHECK-LABEL: func @semiaffine_simplification_product
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index)
func.func @semiaffine_simplification_product(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) {
diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index 977aec2536b1e..6277b28561f36 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -139,7 +139,7 @@
#map44 = affine_map<(i, j) -> (i - 2*j, j * 6 floordiv 4)>
// Simplifications
-// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d2 + d1, (d0 * s0) * 8)>
+// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d1 + d2, (d0 * s0) * 8)>
#map45 = affine_map<(i, j, k) [N] -> (1 + i + 3 + j - 3 + k, k + 5 + j - 5, 2*i*4*N)>
// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (0, d1, d0 * 2, 0)>
diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp
index 8a2d697540d5c..f8494ecb971c2 100644
--- a/mlir/unittests/IR/AffineExprTest.cpp
+++ b/mlir/unittests/IR/AffineExprTest.cpp
@@ -84,6 +84,20 @@ TEST(AffineExprTest, constantFolding) {
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
}
+// Commutative binary expressions must unique to the same expression
+// regardless of operand order, and the canonical order must be the one
+// documented on orderCommutativeArgs.
+TEST(AffineExprTest, commutative) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  // Use a constant other than 1 (hence `c2`): a unit factor may be folded
+  // away by simplification, which would make the constant check trivial.
+  auto c2 = b.getAffineConstantExpr(2);
+  auto d0 = b.getAffineDimExpr(0);
+  auto d1 = b.getAffineDimExpr(1);
+  auto s0 = b.getAffineSymbolExpr(0);
+  auto s1 = b.getAffineSymbolExpr(1);
+
+  // Dims and symbols of the same kind are ordered by position.
+  ASSERT_EQ(d0 * d1, d1 * d0);
+  ASSERT_EQ(s0 + s1, s1 + s0);
+  ASSERT_EQ(cast<AffineBinaryOpExpr>(s1 + s0).getLHS(), s0);
+
+  // Dims are placed before symbols.
+  ASSERT_EQ(d0 + s0, s0 + d0);
+  ASSERT_EQ(cast<AffineBinaryOpExpr>(s0 + d0).getLHS(), d0);
+
+  // Constants are placed on the right.
+  ASSERT_EQ(s0 * c2, c2 * s0);
+  ASSERT_EQ(cast<AffineBinaryOpExpr>(c2 * s0).getRHS(), c2);
+}
+
TEST(AffineExprTest, divisionSimplification) {
MLIRContext ctx;
OpBuilder b(&ctx);
@@ -147,3 +161,12 @@ TEST(AffineExprTest, simpleAffineExprFlattenerRegression) {
ASSERT_TRUE(isa<AffineConstantExpr>(result));
ASSERT_EQ(cast<AffineConstantExpr>(result).getValue(), 7);
}
+
+// Products of the same symbols written in opposite order must unique to one
+// expression, so subtracting them cancels out.
+TEST(AffineExprTest, simplifyCommutative) {
+  MLIRContext ctx;
+  OpBuilder builder(&ctx);
+  AffineExpr sym0 = builder.getAffineSymbolExpr(0);
+  AffineExpr sym1 = builder.getAffineSymbolExpr(1);
+
+  // Only the trailing constant term is expected to survive simplification.
+  AffineExpr result = sym0 * sym1 - sym1 * sym0 + 1;
+  ASSERT_EQ(result, 1);
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/146895
More information about the Mlir-commits mailing list