[Mlir-commits] [mlir] [mlir][polynomial] implement add for polynomial data structure (PR #92169)

Jeremy Kun llvmlistbot at llvm.org
Tue May 14 14:31:58 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/92169

>From 15bedc7e16a5be442b8af2668be79a52b1671aca Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Wed, 8 May 2024 13:20:04 -0700
Subject: [PATCH 1/2] implement add for polynomial data structure

- use CRTP for base classes
- Add unit test
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.h   | 85 ++++++++++++++-----
 mlir/unittests/Dialect/CMakeLists.txt         |  1 +
 .../Dialect/Polynomial/CMakeLists.txt         |  8 ++
 .../Dialect/Polynomial/PolynomialMathTest.cpp | 44 ++++++++++
 4 files changed, 119 insertions(+), 19 deletions(-)
 create mode 100644 mlir/unittests/Dialect/Polynomial/CMakeLists.txt
 create mode 100644 mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 7f44c29a98707..45823275ebb33 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -30,7 +30,7 @@ namespace polynomial {
 /// would want to specify 128-bit polynomials statically in the source code.
 constexpr unsigned apintBitWidth = 64;
 
-template <typename CoefficientType>
+template <class Derived, typename CoefficientType>
 class MonomialBase {
 public:
   MonomialBase(const CoefficientType &coeff, const APInt &expo)
@@ -55,12 +55,21 @@ class MonomialBase {
     return (exponent.ult(other.exponent));
   }
 
+  /// Returns `*this + other` as a new term; asserts the exponents match.
+  Derived add(const Derived &other) const {
+    assert(exponent == other.exponent);
+    Derived result;
+    result.setCoefficient(coefficient + other.coefficient);
+    result.setExponent(exponent);
+    return result;
+  }
+
   virtual bool isMonic() const = 0;
   virtual void
   coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
 
-  template <typename T>
-  friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
+  template <class D, typename T>
+  friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
 
 protected:
   CoefficientType coefficient;
@@ -69,7 +78,7 @@ class MonomialBase {
 
 /// A class representing a monomial of a single-variable polynomial with integer
 /// coefficients.
-class IntMonomial : public MonomialBase<APInt> {
+class IntMonomial : public MonomialBase<IntMonomial, APInt> {
 public:
   IntMonomial(int64_t coeff, uint64_t expo)
       : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
@@ -77,7 +86,7 @@ class IntMonomial : public MonomialBase<APInt> {
   IntMonomial()
       : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
 
-  ~IntMonomial() = default;
+  ~IntMonomial() override = default;
 
   bool isMonic() const override { return coefficient == 1; }
 
@@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase<APInt> {
 
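-/// A class representing a monomial of a single-variable polynomial with integer
-/// coefficients.
+/// A class representing a monomial of a single-variable polynomial with
+/// floating-point coefficients.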
-class FloatMonomial : public MonomialBase<APFloat> {
+class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
 public:
   FloatMonomial(double coeff, uint64_t expo)
       : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
 
   FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
 
-  ~FloatMonomial() = default;
+  ~FloatMonomial() override = default;
 
   bool isMonic() const override { return coefficient == APFloat(1.0); }
 
@@ -104,12 +113,12 @@ class FloatMonomial : public MonomialBase<APFloat> {
   }
 };
 
-template <typename Monomial>
+template <class Derived, typename Monomial>
 class PolynomialBase {
 public:
   PolynomialBase() = delete;
 
-  explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms) {};
+  explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
 
   explicit operator bool() const { return !terms.empty(); }
   bool operator==(const PolynomialBase &other) const {
@@ -149,6 +158,44 @@ class PolynomialBase {
     }
   }
 
+  /// Returns the sum of this polynomial and `other` as a new polynomial.
+  ///
+  /// The terms are combined with a two-pointer merge, which assumes both term
+  /// lists are sorted by strictly increasing exponent. NOTE(review): that
+  /// ordering is not checked here; confirm every factory establishes it.
+  /// Terms with equal exponents are combined via Monomial::add; all others
+  /// are copied unchanged. Terms whose coefficients cancel are kept.
+  Derived add(const Derived &other) {
+    SmallVector<Monomial> newTerms;
+    newTerms.reserve(terms.size() + other.terms.size());
+    auto it1 = terms.begin();
+    auto it2 = other.terms.begin();
+    auto end1 = terms.end();
+    auto end2 = other.terms.end();
+
+    // Each iterator is dereferenced only after checking it against the end of
+    // its *own* container, so inputs of different lengths are handled safely.
+    while (it1 != end1 && it2 != end2) {
+      if (it1->getExponent().ult(it2->getExponent())) {
+        newTerms.emplace_back(*it1);
+        ++it1;
+      } else if (it2->getExponent().ult(it1->getExponent())) {
+        newTerms.emplace_back(*it2);
+        ++it2;
+      } else {
+        newTerms.emplace_back(it1->add(*it2));
+        ++it1;
+        ++it2;
+      }
+    }
+
+    // At most one input still has terms left; since both inputs are sorted,
+    // they all have larger exponents than anything emitted so far.
+    newTerms.append(it1, end1);
+    newTerms.append(it2, end2);
+    return Derived(newTerms);
+  }
+
   // Prints polynomial to 'os'.
   void print(raw_ostream &os) const { print(os, " + ", "**"); }
 
@@ -168,8 +215,8 @@ class PolynomialBase {
 
   ArrayRef<Monomial> getTerms() const { return terms; }
 
-  template <typename T>
-  friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
+  template <class D, typename T>
+  friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
 
 private:
   // The monomial terms for this polynomial.
@@ -179,7 +226,7 @@ class PolynomialBase {
 /// A single-variable polynomial with integer coefficients.
 ///
 /// Eg: x^1024 + x + 1
-class IntPolynomial : public PolynomialBase<IntMonomial> {
+class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
 public:
   explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
 
@@ -196,7 +243,7 @@ class IntPolynomial : public PolynomialBase<IntMonomial> {
 /// A single-variable polynomial with double coefficients.
 ///
 /// Eg: 1.0 x^1024 + 3.5 x + 1e-05
-class FloatPolynomial : public PolynomialBase<FloatMonomial> {
+class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
 public:
   explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
       : PolynomialBase(terms) {}
@@ -212,20 +259,20 @@ class FloatPolynomial : public PolynomialBase<FloatMonomial> {
 };
 
 // Make Polynomials hashable.
-template <typename T>
-inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
   return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
 }
 
-template <typename T>
-inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
   return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
                             ::llvm::hash_value(arg.exponent));
 }
 
-template <typename T>
+template <class D, typename T>
 inline raw_ostream &operator<<(raw_ostream &os,
-                               const PolynomialBase<T> &polynomial) {
+                               const PolynomialBase<D, T> &polynomial) {
   polynomial.print(os);
   return os;
 }
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 13393569f36fe..90a75d5a46ad9 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(Index)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
 add_subdirectory(OpenACC)
+add_subdirectory(Polynomial)
 add_subdirectory(SCF)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/Polynomial/CMakeLists.txt b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
new file mode 100644
index 0000000000000..807deeca41c06
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Unit tests for the polynomial classes in mlir/Dialect/Polynomial/IR/Polynomial.h.
+add_mlir_unittest(MLIRPolynomialTests
+  PolynomialMathTest.cpp
+)
+target_link_libraries(MLIRPolynomialTests PRIVATE
+  MLIRIR
+  MLIRPolynomialDialect
+)
diff --git a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
new file mode 100644
index 0000000000000..95906ad42588e
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
@@ -0,0 +1,57 @@
+//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
+  IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3});
+  IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7});
+  EXPECT_EQ(expected, x.add(y));
+}
+
+TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
+  IntMonomial term2t = IntMonomial(2, 1);
+  IntPolynomial x = IntPolynomial::fromMonomials({term2t}).value();
+  IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({2, 5, 4});
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}
+
+TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5});
+  FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8});
+  EXPECT_EQ(expected, x.add(y));
+}
+
+TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5});
+  FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5});
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}
+
+TEST(AddTest, checkDisjointDegreeAdditionOfIntPolynomial) {
+  // Every exponent in x is smaller than every exponent in y, so one input is
+  // exhausted while the other still has terms left to merge.
+  IntPolynomial x = IntPolynomial::fromCoefficients({1, 2});
+  IntPolynomial y = IntPolynomial::fromMonomials({IntMonomial(5, 3)}).value();
+  IntPolynomial expected =
+      IntPolynomial::fromMonomials(
+          {IntMonomial(1, 0), IntMonomial(2, 1), IntMonomial(5, 3)})
+          .value();
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}

>From cde7bbff053d242e770af88d7d104844dee346b8 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 14 May 2024 14:31:49 -0700
Subject: [PATCH 2/2] clang-format

---
 mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 45823275ebb33..e14cef51185e0 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -118,7 +118,7 @@ class PolynomialBase {
 public:
   PolynomialBase() = delete;
 
-  explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
+  explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms) {}
 
   explicit operator bool() const { return !terms.empty(); }
   bool operator==(const PolynomialBase &other) const {



More information about the Mlir-commits mailing list