[Mlir-commits] [mlir] [mlir] Handle attempted construction of invalid `AffineExpr` products (PR #103010)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 12 23:31:14 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Felix Schneider (ubfx)
<details>
<summary>Changes</summary>
Binary `AffineExpr`s can be constructed using the overloaded operators as well as the corresponding methods (`AffineExpr::floorDiv()`, `ceilDiv()`). For a `Mul` expression, at least one of LHS or RHS must be symbolic or constant for the result to be a valid `AffineExpr`.
This patch makes an attempted construction of an `AffineExpr` through the `*` operator return `nullptr` when neither operand is symbolic or constant.
Related: https://github.com/llvm/llvm-project/commit/a4b23638d23d603001c19285a7c7535a8ce81317#r144894547
---
Full diff: https://github.com/llvm/llvm-project/pull/103010.diff
2 Files Affected:
- (modified) mlir/lib/IR/AffineExpr.cpp (+9-3)
- (modified) mlir/unittests/IR/AffineExprTest.cpp (+29-1)
``````````diff
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index fc7ede279643ed..d0fcfb7d7825ea 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -781,6 +781,10 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const {
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
+ // The caller should have checked that this constitutes a valid `AffineExpr`
+ // in principle.
+ assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
+
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
@@ -792,9 +796,6 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
return getAffineConstantExpr(product, lhs.getContext());
}
- if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
- return nullptr;
-
// Canonicalize the mul expression so that the constant/symbolic term is the
// RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
// constant. (Note that a constant is trivially symbolic).
@@ -836,6 +837,11 @@ AffineExpr AffineExpr::operator*(int64_t v) const {
return *this * getAffineConstantExpr(v, getContext());
}
AffineExpr AffineExpr::operator*(AffineExpr other) const {
+ // If neither LHS nor RHS are symbolic or constant, this product will not be a
+ // valid `AffineExpr`.
+ if (!this->isSymbolicOrConstant() && !other.isSymbolicOrConstant())
+ return nullptr;
+
if (auto simplified = simplifyMul(*this, other))
return simplified;
diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp
index dc78bbac85f3cf..09f36e42a5c776 100644
--- a/mlir/unittests/IR/AffineExprTest.cpp
+++ b/mlir/unittests/IR/AffineExprTest.cpp
@@ -16,7 +16,7 @@
using namespace mlir;
// Test creating AffineExprs using the overloaded binary operators.
-TEST(AffineExprTest, constructFromBinaryOperators) {
+TEST(AffineExprTest, constructFromBinaryOperatorsWithDimRHS) {
MLIRContext ctx;
OpBuilder b(&ctx);
@@ -27,11 +27,39 @@ TEST(AffineExprTest, constructFromBinaryOperators) {
auto difference = d0 - d1;
auto product = d0 * d1;
auto remainder = d0 % d1;
+ auto floorDiv = d0.floorDiv(d1);
+ auto ceilDiv = d0.ceilDiv(d1);
+
+ ASSERT_EQ(sum.getKind(), AffineExprKind::Add);
+ ASSERT_EQ(difference.getKind(), AffineExprKind::Add);
+ ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
+ ASSERT_EQ(floorDiv.getKind(), AffineExprKind::FloorDiv);
+ ASSERT_EQ(ceilDiv.getKind(), AffineExprKind::CeilDiv);
+
+ // Invalid (semi-)affine expressions.
+ ASSERT_EQ(product, nullptr);
+}
+
+TEST(AffineExprTest, constructFromBinaryOperatorsWithConstRHS) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ auto d0 = b.getAffineDimExpr(0);
+ auto d1 = b.getAffineConstantExpr(123);
+
+ auto sum = d0 + d1;
+ auto difference = d0 - d1;
+ auto product = d0 * d1;
+ auto remainder = d0 % d1;
+ auto floorDiv = d0.floorDiv(d1);
+ auto ceilDiv = d0.ceilDiv(d1);
ASSERT_EQ(sum.getKind(), AffineExprKind::Add);
ASSERT_EQ(difference.getKind(), AffineExprKind::Add);
ASSERT_EQ(product.getKind(), AffineExprKind::Mul);
ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
+ ASSERT_EQ(floorDiv.getKind(), AffineExprKind::FloorDiv);
+ ASSERT_EQ(ceilDiv.getKind(), AffineExprKind::CeilDiv);
}
TEST(AffineExprTest, constantFolding) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/103010
More information about the Mlir-commits mailing list