[Mlir-commits] [mlir] [mlir][AffineExpr] Order arguments in the commutative affine exprs (PR #146895)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 3 07:09:41 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

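Canonicalize the operand order of commutative affine expressions (`+` and `*`): order symbol/dim operands by position, put dims before symbols, and put constants on the right. A canonical order helps affine simplifications; e.g. `s0 * s1 - s1 * s0` can now fold to 0.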

---
Full diff: https://github.com/llvm/llvm-project/pull/146895.diff


4 Files Affected:

- (modified) mlir/lib/IR/AffineExpr.cpp (+36-2) 
- (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (+2-2) 
- (modified) mlir/test/IR/affine-map.mlir (+1-1) 
- (modified) mlir/unittests/IR/AffineExprTest.cpp (+23) 


``````````diff
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..856b29125602d 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -784,6 +784,36 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
   return nullptr;
 }
 
+static std::pair<AffineExpr, AffineExpr>
+orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) {
+  auto sym1 = dyn_cast<AffineSymbolExpr>(expr1);
+  auto sym2 = dyn_cast<AffineSymbolExpr>(expr2);
+  // Same-kind pairs (two symbols, or two dims below): lower position first.
+  if (sym1 && sym2)
+    return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
+                                                   : std::pair{expr2, expr1};
+
+  auto dim1 = dyn_cast<AffineDimExpr>(expr1);
+  auto dim2 = dyn_cast<AffineDimExpr>(expr2);
+  if (dim1 && dim2)
+    return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
+                                                   : std::pair{expr2, expr1};
+
+  // Mixed dim/symbol pair: the dim always goes first.
+  if (dim1 && sym2)
+    return {dim1, sym2};
+
+  if (sym1 && dim2)
+    return {dim2, sym1};
+
+  // Constant with a non-constant: put the constant on the RHS.
+  if (isa<AffineConstantExpr>(expr1) && !isa<AffineConstantExpr>(expr2))
+    return {expr2, expr1};
+
+  // Any other pair (e.g. a binary-expr operand) keeps its original order.
+  return {expr1, expr2};
+}
+
 AffineExpr AffineExpr::operator+(int64_t v) const {
   return *this + getAffineConstantExpr(v, getContext());
 }
@@ -791,9 +821,11 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
   if (auto simplified = simplifyAdd(*this, other))
     return simplified;
 
+  auto [lhs, rhs] = orderCommutativeArgs(*this, other);
+  // Ordered operands make e.g. s1 + s0 and s0 + s1 the same uniqued expr.
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
-      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), lhs, rhs);
 }
 
 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
@@ -856,9 +888,11 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
   if (auto simplified = simplifyMul(*this, other))
     return simplified;
 
+  auto [lhs, rhs] = orderCommutativeArgs(*this, other);
+  // Ordered operands make e.g. d1 * d0 and d0 * d1 the same uniqued expr.
   StorageUniquer &uniquer = getContext()->getAffineUniquer();
   return uniquer.get<AffineBinaryOpExprStorage>(
-      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), lhs, rhs);
 }
 
 // Unary minus, delegate to operator*.
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 6f2737a982752..653c2cb521637 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -508,7 +508,7 @@ func.func @test_not_trivially_true_or_false_returning_three_results() -> (index,
 // -----
 
 // Test simplification of mod expressions.
-// CHECK-DAG:   #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)>
+// CHECK-DAG:   #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s4 + s3 + (s0 - s1) mod s2)>
 // CHECK-DAG:   #[[$SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))>
 // CHECK-DAG:   #[[$MODULO_AND_PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s3 - (-s0 + s3) mod s2)>
 // CHECK-LABEL: func @semiaffine_simplification_mod
@@ -547,7 +547,7 @@ func.func @semiaffine_simplification_floordiv_and_ceildiv(%arg0: index, %arg1: i
 
 // Test simplification of product expressions.
 // CHECK-DAG:   #[[$PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)>
-// CHECK-DAG:   #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 + s2 * s0 + s3 + s3 * s0 + s3 * s1 + s4 + s4 * s1)>
+// CHECK-DAG:   #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s2 + s0 * s3 + s1 * s3 + s1 * s4 + s2 + s3 + s4)>
 // CHECK-LABEL: func @semiaffine_simplification_product
 // CHECK-SAME:  (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index)
 func.func @semiaffine_simplification_product(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) {
diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index 977aec2536b1e..6277b28561f36 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -139,7 +139,7 @@
 #map44 = affine_map<(i, j) -> (i - 2*j, j * 6 floordiv 4)>
 
 // Simplifications
-// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d2 + d1, (d0 * s0) * 8)>
+// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d1 + d2, (d0 * s0) * 8)>
 #map45 = affine_map<(i, j, k) [N] -> (1 + i + 3 + j - 3 + k, k + 5 + j - 5, 2*i*4*N)>
 
 // CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (0, d1, d0 * 2, 0)>
diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp
index 8a2d697540d5c..f8494ecb971c2 100644
--- a/mlir/unittests/IR/AffineExprTest.cpp
+++ b/mlir/unittests/IR/AffineExprTest.cpp
@@ -84,6 +84,20 @@ TEST(AffineExprTest, constantFolding) {
   ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
 }
 
+TEST(AffineExprTest, commutative) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  auto c2 = b.getAffineConstantExpr(2);
+  auto d0 = b.getAffineDimExpr(0);
+  auto d1 = b.getAffineDimExpr(1);
+  auto s0 = b.getAffineSymbolExpr(0);
+  auto s1 = b.getAffineSymbolExpr(1);
+  // Swapped operands must yield the same uniqued expression.
+  ASSERT_EQ(d0 * d1, d1 * d0);
+  ASSERT_EQ(s0 + s1, s1 + s0);
+  ASSERT_EQ(s0 * c2, c2 * s0);
+}
+
 TEST(AffineExprTest, divisionSimplification) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
@@ -147,3 +161,12 @@ TEST(AffineExprTest, simpleAffineExprFlattenerRegression) {
   ASSERT_TRUE(isa<AffineConstantExpr>(result));
   ASSERT_EQ(cast<AffineConstantExpr>(result).getValue(), 7);
 }
+
+TEST(AffineExprTest, simplifyCommutative) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  auto s0 = b.getAffineSymbolExpr(0);
+  auto s1 = b.getAffineSymbolExpr(1);
+  // With canonical operand order, s0 * s1 - s1 * s0 cancels to 0.
+  ASSERT_EQ(s0 * s1 - s1 * s0 + 1, 1);
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/146895


More information about the Mlir-commits mailing list