[Mlir-commits] [mlir] [mlir][polynomial] implement add for polynomial data structure (PR #92169)
Jeremy Kun
llvmlistbot at llvm.org
Tue May 14 14:31:58 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/92169
>From 15bedc7e16a5be442b8af2668be79a52b1671aca Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Wed, 8 May 2024 13:20:04 -0700
Subject: [PATCH 1/2] implement add for polynomial data structure
- use CRTP for base classes
- Add unit test
---
.../mlir/Dialect/Polynomial/IR/Polynomial.h | 85 ++++++++++++++-----
mlir/unittests/Dialect/CMakeLists.txt | 1 +
.../Dialect/Polynomial/CMakeLists.txt | 8 ++
.../Dialect/Polynomial/PolynomialMathTest.cpp | 44 ++++++++++
4 files changed, 119 insertions(+), 19 deletions(-)
create mode 100644 mlir/unittests/Dialect/Polynomial/CMakeLists.txt
create mode 100644 mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 7f44c29a98707..45823275ebb33 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -30,7 +30,7 @@ namespace polynomial {
/// would want to specify 128-bit polynomials statically in the source code.
constexpr unsigned apintBitWidth = 64;
-template <typename CoefficientType>
+template <class Derived, typename CoefficientType>
class MonomialBase {
public:
MonomialBase(const CoefficientType &coeff, const APInt &expo)
@@ -55,12 +55,21 @@ class MonomialBase {
return (exponent.ult(other.exponent));
}
+ /// Returns a new monomial whose coefficient is the sum of this monomial's
+ /// coefficient and `other`'s coefficient, and whose exponent is the shared
+ /// exponent. Asserts that both operands have the same exponent.
+ ///
+ /// Const-qualified because neither operand is modified, which allows calls
+ /// on const monomials.
+ Derived add(const Derived &other) const {
+   assert(exponent == other.exponent &&
+          "cannot add monomials with different exponents");
+   CoefficientType newCoeff = coefficient + other.coefficient;
+   Derived result;
+   result.setCoefficient(newCoeff);
+   result.setExponent(exponent);
+   return result;
+ }
+
virtual bool isMonic() const = 0;
virtual void
coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
- template <typename T>
- friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
+ template <class D, typename T>
+ friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
protected:
CoefficientType coefficient;
@@ -69,7 +78,7 @@ class MonomialBase {
/// A class representing a monomial of a single-variable polynomial with integer
/// coefficients.
-class IntMonomial : public MonomialBase<APInt> {
+class IntMonomial : public MonomialBase<IntMonomial, APInt> {
public:
IntMonomial(int64_t coeff, uint64_t expo)
: MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
@@ -77,7 +86,7 @@ class IntMonomial : public MonomialBase<APInt> {
IntMonomial()
: MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
- ~IntMonomial() = default;
+ ~IntMonomial() override = default;
bool isMonic() const override { return coefficient == 1; }
@@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase<APInt> {
/// A class representing a monomial of a single-variable polynomial with
/// floating-point coefficients.
-class FloatMonomial : public MonomialBase<APFloat> {
+class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
public:
FloatMonomial(double coeff, uint64_t expo)
: MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
- ~FloatMonomial() = default;
+ ~FloatMonomial() override = default;
bool isMonic() const override { return coefficient == APFloat(1.0); }
@@ -104,12 +113,12 @@ class FloatMonomial : public MonomialBase<APFloat> {
}
};
-template <typename Monomial>
+template <class Derived, typename Monomial>
class PolynomialBase {
public:
PolynomialBase() = delete;
- explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms) {};
+ explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
explicit operator bool() const { return !terms.empty(); }
bool operator==(const PolynomialBase &other) const {
@@ -149,6 +158,44 @@ class PolynomialBase {
}
}
+ /// Returns the sum of this polynomial and `other` as a new polynomial.
+ ///
+ /// Walks both term lists in lockstep: a term whose exponent is smaller than
+ /// the current term of the other operand is copied through unchanged, and
+ /// terms with equal exponents are combined via the monomial's `add`.
+ /// NOTE(review): this merge assumes both term lists are ordered by strictly
+ /// increasing exponent -- confirm that the factory functions guarantee it.
+ Derived add(const Derived &other) {
+   SmallVector<Monomial> newTerms;
+   auto it1 = terms.begin();
+   auto it2 = other.terms.begin();
+   while (it1 != terms.end() || it2 != other.terms.end()) {
+     // Each branch consumes at least one term and then re-enters the loop,
+     // so the end checks below guard every dereference. The previous
+     // version advanced inside nested loops and could dereference an end
+     // iterator after breaking out of them.
+     if (it1 == terms.end()) {
+       newTerms.emplace_back(*it2);
+       ++it2;
+       continue;
+     }
+
+     if (it2 == other.terms.end()) {
+       newTerms.emplace_back(*it1);
+       ++it1;
+       continue;
+     }
+
+     // Only this polynomial has a term of the smaller exponent.
+     if (it1->getExponent().ult(it2->getExponent())) {
+       newTerms.emplace_back(*it1);
+       ++it1;
+       continue;
+     }
+
+     // Only `other` has a term of the smaller exponent.
+     if (it2->getExponent().ult(it1->getExponent())) {
+       newTerms.emplace_back(*it2);
+       ++it2;
+       continue;
+     }
+
+     // Equal exponents: combine the two coefficients into one term.
+     newTerms.emplace_back(it1->add(*it2));
+     ++it1;
+     ++it2;
+   }
+   return Derived(newTerms);
+ }
+
// Prints polynomial to 'os'.
void print(raw_ostream &os) const { print(os, " + ", "**"); }
@@ -168,8 +215,8 @@ class PolynomialBase {
ArrayRef<Monomial> getTerms() const { return terms; }
- template <typename T>
- friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
+ template <class D, typename T>
+ friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
private:
// The monomial terms for this polynomial.
@@ -179,7 +226,7 @@ class PolynomialBase {
/// A single-variable polynomial with integer coefficients.
///
/// Eg: x^1024 + x + 1
-class IntPolynomial : public PolynomialBase<IntMonomial> {
+class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
public:
explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
@@ -196,7 +243,7 @@ class IntPolynomial : public PolynomialBase<IntMonomial> {
/// A single-variable polynomial with double coefficients.
///
/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
-class FloatPolynomial : public PolynomialBase<FloatMonomial> {
+class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
public:
explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
: PolynomialBase(terms) {}
@@ -212,20 +259,20 @@ class FloatPolynomial : public PolynomialBase<FloatMonomial> {
};
// Make Polynomials hashable.
-template <typename T>
-inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
}
-template <typename T>
-inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
::llvm::hash_value(arg.exponent));
}
-template <typename T>
+template <class D, typename T>
inline raw_ostream &operator<<(raw_ostream &os,
- const PolynomialBase<T> &polynomial) {
+ const PolynomialBase<D, T> &polynomial) {
polynomial.print(os);
return os;
}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 13393569f36fe..90a75d5a46ad9 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(Index)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(OpenACC)
+add_subdirectory(Polynomial)
add_subdirectory(SCF)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/Polynomial/CMakeLists.txt b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
new file mode 100644
index 0000000000000..807deeca41c06
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIRPolynomialTests
+ PolynomialMathTest.cpp
+)
+target_link_libraries(MLIRPolynomialTests
+ PRIVATE
+ MLIRIR
+ MLIRPolynomialDialect
+)
diff --git a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
new file mode 100644
index 0000000000000..95906ad42588e
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
@@ -0,0 +1,44 @@
+//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+// Adds two integer polynomials that have the same number of coefficients and
+// checks the result against the coefficient-wise sum.
+TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
+  IntPolynomial lhs = IntPolynomial::fromCoefficients({1, 2, 3});
+  IntPolynomial rhs = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial sum = lhs.add(rhs);
+  EXPECT_EQ(IntPolynomial::fromCoefficients({3, 5, 7}), sum);
+}
+
+// Adds a single-term integer polynomial to a three-coefficient one, in both
+// operand orders, and expects the same result each time.
+TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
+  IntMonomial linearTerm(2, 1);
+  IntPolynomial shortPoly = IntPolynomial::fromMonomials({linearTerm}).value();
+  IntPolynomial longPoly = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial want = IntPolynomial::fromCoefficients({2, 5, 4});
+  EXPECT_EQ(want, shortPoly.add(longPoly));
+  EXPECT_EQ(want, longPoly.add(shortPoly));
+}
+
+// Adds two floating-point polynomials that have the same number of
+// coefficients and checks the result against the coefficient-wise sum.
+TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial lhs = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5});
+  FloatPolynomial rhs = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial sum = lhs.add(rhs);
+  EXPECT_EQ(FloatPolynomial::fromCoefficients({4, 6, 8}), sum);
+}
+
+// Adds a two-coefficient floating-point polynomial to a three-coefficient
+// one, in both operand orders, and expects the same result each time.
+TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial shortPoly = FloatPolynomial::fromCoefficients({1.5, 2.5});
+  FloatPolynomial longPoly = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial want = FloatPolynomial::fromCoefficients({4, 6, 4.5});
+  EXPECT_EQ(want, shortPoly.add(longPoly));
+  EXPECT_EQ(want, longPoly.add(shortPoly));
+}
>From cde7bbff053d242e770af88d7d104844dee346b8 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 14 May 2024 14:31:49 -0700
Subject: [PATCH 2/2] clang-format
---
mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 45823275ebb33..e14cef51185e0 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -118,7 +118,7 @@ class PolynomialBase {
public:
PolynomialBase() = delete;
- explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
+ explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms) {};
explicit operator bool() const { return !terms.empty(); }
bool operator==(const PolynomialBase &other) const {
More information about the Mlir-commits
mailing list