[Mlir-commits] [mlir] [MLIR][Presburger] Fix reduce bug in Fraction class and add tests (PR #68298)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 5 03:22:17 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-presburger
<details>
<summary>Changes</summary>
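We modify the free function `reduce(const Fraction &)` so that it also works for negative fractions: the gcd is now computed from the absolute values of the numerator and denominator.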
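We add unit tests (`FractionTest.cpp`) covering `getAsInteger`, `floor`/`ceil`, and the arithmetic and relational operators on `Fraction`. We also include formatting-only cleanups in `Fraction.h`.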
---
Full diff: https://github.com/llvm/llvm-project/pull/68298.diff
5 Files Affected:
- (added) libcxx/modules/std/mdspan.cppm (+33)
- (added) libcxx/modules/std/print.cppm (+25)
- (modified) mlir/include/mlir/Analysis/Presburger/Fraction.h (+9-7)
- (modified) mlir/unittests/Analysis/Presburger/CMakeLists.txt (+1)
- (added) mlir/unittests/Analysis/Presburger/FractionTest.cpp (+51)
``````````diff
diff --git a/libcxx/modules/std/mdspan.cppm b/libcxx/modules/std/mdspan.cppm
new file mode 100644
index 000000000000000..40426cce3fce8c2
--- /dev/null
+++ b/libcxx/modules/std/mdspan.cppm
@@ -0,0 +1,33 @@
+// -*- C++ -*-
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+module;
+#include <mdspan>
+
+export module std:mdspan;
+export namespace std {
+#if _LIBCPP_STD_VER >= 23
+ // [mdspan.extents], class template extents
+ using std::extents;
+
+ // [mdspan.extents.dextents], alias template dextents
+ using std::dextents;
+
+ // [mdspan.layout], layout mapping
+ using std::layout_left;
+ using std::layout_right;
+ // using std::layout_stride;
+
+ // [mdspan.accessor.default], class template default_accessor
+ using std::default_accessor;
+
+ // [mdspan.mdspan], class template mdspan
+ using std::mdspan;
+#endif // _LIBCPP_STD_VER >= 23
+} // namespace std
diff --git a/libcxx/modules/std/print.cppm b/libcxx/modules/std/print.cppm
new file mode 100644
index 000000000000000..02362633c6d9fbb
--- /dev/null
+++ b/libcxx/modules/std/print.cppm
@@ -0,0 +1,25 @@
+// -*- C++ -*-
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+module;
+#include <print>
+
+export module std:print;
+export namespace std {
+#if _LIBCPP_STD_VER >= 23
+ // [print.fun], print functions
+ using std::print;
+ using std::println;
+
+ using std::vprint_nonunicode;
+# ifndef _LIBCPP_HAS_NO_UNICODE
+ using std::vprint_unicode;
+# endif // _LIBCPP_HAS_NO_UNICODE
+#endif // _LIBCPP_STD_VER >= 23
+} // namespace std
diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index 74127a900d53ed2..a410f528e1f8001 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -30,7 +30,8 @@ struct Fraction {
Fraction() = default;
/// Construct a Fraction from a numerator and denominator.
- Fraction(const MPInt &oNum, const MPInt &oDen = MPInt(1)) : num(oNum), den(oDen) {
+ Fraction(const MPInt &oNum, const MPInt &oDen = MPInt(1))
+ : num(oNum), den(oDen) {
if (den < 0) {
num = -num;
den = -den;
@@ -38,7 +39,8 @@ struct Fraction {
}
/// Overloads for passing literals.
Fraction(const MPInt &num, int64_t den = 1) : Fraction(num, MPInt(den)) {}
- Fraction(int64_t num, const MPInt &den = MPInt(1)) : Fraction(MPInt(num), den) {}
+ Fraction(int64_t num, const MPInt &den = MPInt(1))
+ : Fraction(MPInt(num), den) {}
Fraction(int64_t num, int64_t den) : Fraction(MPInt(num), MPInt(den)) {}
// Return the value of the fraction as an integer. This should only be called
@@ -102,7 +104,7 @@ inline bool operator>=(const Fraction &x, const Fraction &y) {
inline Fraction reduce(const Fraction &f) {
if (f == Fraction(0))
return Fraction(0, 1);
- MPInt g = gcd(f.num, f.den);
+ MPInt g = gcd(abs(f.num), abs(f.den));
return Fraction(f.num / g, f.den / g);
}
@@ -122,22 +124,22 @@ inline Fraction operator-(const Fraction &x, const Fraction &y) {
return reduce(Fraction(x.num * y.den - x.den * y.num, x.den * y.den));
}
-inline Fraction& operator+=(Fraction &x, const Fraction &y) {
+inline Fraction &operator+=(Fraction &x, const Fraction &y) {
x = x + y;
return x;
}
-inline Fraction& operator-=(Fraction &x, const Fraction &y) {
+inline Fraction &operator-=(Fraction &x, const Fraction &y) {
x = x - y;
return x;
}
-inline Fraction& operator/=(Fraction &x, const Fraction &y) {
+inline Fraction &operator/=(Fraction &x, const Fraction &y) {
x = x / y;
return x;
}
-inline Fraction& operator*=(Fraction &x, const Fraction &y) {
+inline Fraction &operator*=(Fraction &x, const Fraction &y) {
x = x * y;
return x;
}
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
index 7b0124ee24c352e..b6ce273e35a0ee7 100644
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_unittest(MLIRPresburgerTests
+ FractionTest.cpp
IntegerPolyhedronTest.cpp
IntegerRelationTest.cpp
LinearTransformTest.cpp
diff --git a/mlir/unittests/Analysis/Presburger/FractionTest.cpp b/mlir/unittests/Analysis/Presburger/FractionTest.cpp
new file mode 100644
index 000000000000000..aafa689588c921b
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/FractionTest.cpp
@@ -0,0 +1,51 @@
+#include "mlir/Analysis/Presburger/Fraction.h"
+#include "./Utils.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+
+TEST(FractionTest, getAsInteger) {
+ Fraction f(3, 1);
+ EXPECT_EQ(f.getAsInteger(), MPInt(3));
+}
+
+TEST(FractionTest, nearIntegers) {
+ Fraction f(52, 14);
+
+ EXPECT_EQ(floor(f), 3);
+ EXPECT_EQ(ceil(f), 4);
+}
+
+TEST(FractionTest, reduce) {
+ Fraction f(20, 35), g(-56, 63);
+ EXPECT_EQ(f, Fraction(4, 7));
+ EXPECT_EQ(g, Fraction(-8, 9));
+}
+
+TEST(FractionTest, arithmetic) {
+ Fraction f(3, 4), g(-2, 3);
+
+ EXPECT_EQ(f / g, Fraction(-9, 8));
+ EXPECT_EQ(f * g, Fraction(-1, 2));
+ EXPECT_EQ(f + g, Fraction(1, 12));
+ EXPECT_EQ(f - g, Fraction(17, 12));
+
+ f /= g;
+ EXPECT_EQ(f, Fraction(-9, 8));
+ f *= g;
+ EXPECT_EQ(f, Fraction(3, 4));
+ f += g;
+ EXPECT_EQ(f, Fraction(Fraction(1, 12)));
+ f -= g;
+ EXPECT_EQ(f, Fraction(3, 4));
+}
+
+TEST(FractionTest, relational) {
+ Fraction f(2, 5), g(3, 7);
+ ASSERT_TRUE(f < g);
+ ASSERT_FALSE(g < f);
+
+ EXPECT_EQ(f, Fraction(4, 10));
+}
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/68298
More information about the Mlir-commits
mailing list