[Mlir-commits] [mlir] [mlir] Handle attempted construction of invalid `AffineExpr` products (PR #103010)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 12 23:31:14 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Felix Schneider (ubfx)

<details>
<summary>Changes</summary>

Binary `AffineExpr`s can be constructed using the overloaded operators as well as the corresponding methods (`AffineExpr::floorDiv()`, `ceilDiv()`). For a `Mul` expression to be a valid `AffineExpr`, at least one of the LHS and RHS has to be symbolic or constant.

This patch makes the `*` operator return `nullptr` when it is asked to construct an `AffineExpr` product in which neither operand is symbolic or constant.

Related: https://github.com/llvm/llvm-project/commit/a4b23638d23d603001c19285a7c7535a8ce81317#r144894547

---
Full diff: https://github.com/llvm/llvm-project/pull/103010.diff


2 Files Affected:

- (modified) mlir/lib/IR/AffineExpr.cpp (+9-3) 
- (modified) mlir/unittests/IR/AffineExprTest.cpp (+29-1) 


``````````diff
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index fc7ede279643ed..d0fcfb7d7825ea 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -781,6 +781,10 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
 
 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
+  // The caller should have checked that this constitutes a valid `AffineExpr`
+  // in principle.
+  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
+
   auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
 
@@ -792,9 +796,6 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
     return getAffineConstantExpr(product, lhs.getContext());
   }
 
-  if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
-    return nullptr;
-
   // Canonicalize the mul expression so that the constant/symbolic term is the
   // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
   // constant. (Note that a constant is trivially symbolic).
@@ -836,6 +837,11 @@ AffineExpr AffineExpr::operator*(int64_t v) const {
   return *this * getAffineConstantExpr(v, getContext());
 }
 AffineExpr AffineExpr::operator*(AffineExpr other) const {
+  // If neither LHS nor RHS are symbolic or constant, this product will not be a
+  // valid `AffineExpr`.
+  if (!this->isSymbolicOrConstant() && !other.isSymbolicOrConstant())
+    return nullptr;
+
   if (auto simplified = simplifyMul(*this, other))
     return simplified;
 
diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp
index dc78bbac85f3cf..09f36e42a5c776 100644
--- a/mlir/unittests/IR/AffineExprTest.cpp
+++ b/mlir/unittests/IR/AffineExprTest.cpp
@@ -16,7 +16,7 @@
 using namespace mlir;
 
 // Test creating AffineExprs using the overloaded binary operators.
-TEST(AffineExprTest, constructFromBinaryOperators) {
+TEST(AffineExprTest, constructFromBinaryOperatorsWithDimRHS) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
 
@@ -27,11 +27,39 @@ TEST(AffineExprTest, constructFromBinaryOperators) {
   auto difference = d0 - d1;
   auto product = d0 * d1;
   auto remainder = d0 % d1;
+  auto floorDiv = d0.floorDiv(d1);
+  auto ceilDiv = d0.ceilDiv(d1);
+
+  ASSERT_EQ(sum.getKind(), AffineExprKind::Add);
+  ASSERT_EQ(difference.getKind(), AffineExprKind::Add);
+  ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
+  ASSERT_EQ(floorDiv.getKind(), AffineExprKind::FloorDiv);
+  ASSERT_EQ(ceilDiv.getKind(), AffineExprKind::CeilDiv);
+
+  // Invalid (semi-)affine expressions.
+  ASSERT_EQ(product, nullptr);
+}
+
+TEST(AffineExprTest, constructFromBinaryOperatorsWithConstRHS) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  auto d0 = b.getAffineDimExpr(0);
+  auto d1 = b.getAffineConstantExpr(123);
+
+  auto sum = d0 + d1;
+  auto difference = d0 - d1;
+  auto product = d0 * d1;
+  auto remainder = d0 % d1;
+  auto floorDiv = d0.floorDiv(d1);
+  auto ceilDiv = d0.ceilDiv(d1);
 
   ASSERT_EQ(sum.getKind(), AffineExprKind::Add);
   ASSERT_EQ(difference.getKind(), AffineExprKind::Add);
   ASSERT_EQ(product.getKind(), AffineExprKind::Mul);
   ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
+  ASSERT_EQ(floorDiv.getKind(), AffineExprKind::FloorDiv);
+  ASSERT_EQ(ceilDiv.getKind(), AffineExprKind::CeilDiv);
 }
 
 TEST(AffineExprTest, constantFolding) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/103010


More information about the Mlir-commits mailing list