[Mlir-commits] [llvm] [mlir] Do not trigger UB during AffineExpr parsing. (PR #96896)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 27 04:10:54 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Johannes Reifferscheid (jreiffers)

<details>
<summary>Changes</summary>

Currently, parsing affine expressions whose constant folding is undefined
(e.g. `-9223372036854775808 / -1`) triggers UB during compilation. With this
change, such expressions are instead left unfolded, exactly as written.

This change is NFC (no functional change) for compilations that did not previously involve UB.

---
Full diff: https://github.com/llvm/llvm-project/pull/96896.diff


3 Files Affected:

- (modified) llvm/include/llvm/Support/MathExtras.h (+4-2) 
- (modified) mlir/lib/IR/AffineExpr.cpp (+34-10) 
- (modified) mlir/unittests/IR/AffineExprTest.cpp (+46) 


``````````diff
diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index 3bba999fb00e9..6de754f472635 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -435,7 +435,8 @@ inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
 }
 
 /// Returns the integer ceil(Numerator / Denominator). Signed version.
-/// Guaranteed to never overflow.
+/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
+/// is -1.
 inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
   assert(Denominator && "Division by zero");
   if (!Numerator)
@@ -448,7 +449,8 @@ inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
 }
 
 /// Returns the integer floor(Numerator / Denominator). Signed version.
-/// Guaranteed to never overflow.
+/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
+/// is -1.
 inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
   assert(Denominator && "Division by zero");
   if (!Numerator)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 1fab33327ba76..cf8157cf7bb8c 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <cstdint>
+#include <limits>
 #include <utility>
 
 #include "AffineExprDetail.h"
@@ -645,10 +647,14 @@ mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
   auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
-  // Fold if both LHS, RHS are a constant.
-  if (lhsConst && rhsConst)
-    return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
-                                 lhs.getContext());
+  // Fold if both LHS, RHS are a constant and the sum does not overflow.
+  if (lhsConst && rhsConst) {
+    int64_t sum;
+    if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
+      return nullptr;
+    }
+    return getAffineConstantExpr(sum, lhs.getContext());
+  }
 
   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
   // If only one of them is a symbolic expressions, make it the RHS.
@@ -774,9 +780,13 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
   auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
   auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
 
-  if (lhsConst && rhsConst)
-    return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
-                                 lhs.getContext());
+  if (lhsConst && rhsConst) {
+    int64_t product;
+    if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
+      return nullptr;
+    }
+    return getAffineConstantExpr(product, lhs.getContext());
+  }
 
   if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
     return nullptr;
@@ -849,10 +859,16 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
   if (!rhsConst || rhsConst.getValue() < 1)
     return nullptr;
 
-  if (lhsConst)
+  if (lhsConst) {
+    // divideFloorSigned can only overflow in this case:
+    if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
+        rhsConst.getValue() == -1) {
+      return nullptr;
+    }
     return getAffineConstantExpr(
         divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
         lhs.getContext());
+  }
 
   // Fold floordiv of a multiply with a constant that is a multiple of the
   // divisor. Eg: (i * 128) floordiv 64 = i * 2.
@@ -905,10 +921,16 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
   if (!rhsConst || rhsConst.getValue() < 1)
     return nullptr;
 
-  if (lhsConst)
+  if (lhsConst) {
+    // divideCeilSigned can only overflow in this case:
+    if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
+        rhsConst.getValue() == -1) {
+      return nullptr;
+    }
     return getAffineConstantExpr(
         divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
         lhs.getContext());
+  }
 
   // Fold ceildiv of a multiply with a constant that is a multiple of the
   // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
@@ -950,9 +972,11 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
   if (!rhsConst || rhsConst.getValue() < 1)
     return nullptr;
 
-  if (lhsConst)
+  if (lhsConst) {
+    // mod never overflows.
     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
                                  lhs.getContext());
+  }
 
   // Fold modulo of an expression that is known to be a multiple of a constant
   // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp
index ff154eb29807c..9740165c6b324 100644
--- a/mlir/unittests/IR/AffineExprTest.cpp
+++ b/mlir/unittests/IR/AffineExprTest.cpp
@@ -6,6 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <cstdint>
+#include <limits>
+
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "gtest/gtest.h"
@@ -30,3 +33,46 @@ TEST(AffineExprTest, constructFromBinaryOperators) {
   ASSERT_EQ(product.getKind(), AffineExprKind::Mul);
   ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
 }
+
+TEST(AffineExprTest, constantFolding) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  auto cn1 = b.getAffineConstantExpr(-1);
+  auto c0 = b.getAffineConstantExpr(0);
+  auto c1 = b.getAffineConstantExpr(1);
+  auto c2 = b.getAffineConstantExpr(2);
+  auto c3 = b.getAffineConstantExpr(3);
+  auto c6 = b.getAffineConstantExpr(6);
+  auto cmax = b.getAffineConstantExpr(std::numeric_limits<int64_t>::max());
+  auto cmin = b.getAffineConstantExpr(std::numeric_limits<int64_t>::min());
+
+  ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3);
+  ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6);
+  ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1);
+  ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2);
+
+  // Test division by zero:
+  auto c3ceildivc0 = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c0);
+  ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv);
+
+  auto c3floordivc0 = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c0);
+  ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv);
+
+  auto c3modc0 = getAffineBinaryOpExpr(AffineExprKind::Mod, c3, c0);
+  ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod);
+
+  // Test overflow:
+  auto cmaxplusc1 = getAffineBinaryOpExpr(AffineExprKind::Add, cmax, c1);
+  ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add);
+
+  auto cmaxtimesc2 = getAffineBinaryOpExpr(AffineExprKind::Mul, cmax, c2);
+  ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul);
+
+  auto cminceildivcn1 =
+      getAffineBinaryOpExpr(AffineExprKind::CeilDiv, cmin, cn1);
+  ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv);
+
+  auto cminfloordivcn1 =
+      getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
+  ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/96896


More information about the Mlir-commits mailing list