[Mlir-commits] [mlir] Support polynomial attributes with floating point coefficients (PR #91137)

Jeremy Kun llvmlistbot at llvm.org
Sun May 5 11:07:12 PDT 2024


https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/91137

In summary:

- `Monomial` -> `MonomialBase`, with two subclasses, `IntMonomial` and `FloatMonomial`, for the different coefficient types.
- `Polynomial` -> `PolynomialBase`, with subclasses `IntPolynomial` and `FloatPolynomial`.
- `PolynomialAttr` -> `IntPolynomialAttr`, plus a new `FloatPolynomialAttr` attribute; either may be used as the input to `polynomial.constant`.
- Refactors the common parts of the attribute parsers into shared helpers.

From bf5da2ae01e7efa0415db06625809b8f96aee77c Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 3 May 2024 17:27:07 -0700
Subject: [PATCH] refactor and support Float polynomials

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.h   | 192 ++++++++++++++----
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  |  51 +++--
 mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp |  79 +++----
 .../Polynomial/IR/PolynomialAttributes.cpp    | 123 ++++++++---
 4 files changed, 308 insertions(+), 137 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 3325a6fa3f9fcf..5705deeadf7307 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -11,10 +11,14 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 
@@ -27,98 +30,221 @@ namespace polynomial {
 /// would want to specify 128-bit polynomials statically in the source code.
 constexpr unsigned apintBitWidth = 64;

-/// A class representing a monomial of a single-variable polynomial with integer
-/// coefficients.
-class Monomial {
+/// Common base for a monomial (a coefficient times a power of the
+/// indeterminate) of a single-variable polynomial with `CoefficientType`
+/// coefficients. Subclasses supply the monic check and coefficient printing.
+template <typename CoefficientType>
+class MonomialBase {
 public:
-  Monomial(int64_t coeff, uint64_t expo)
-      : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
-
-  Monomial(const APInt &coeff, const APInt &expo)
+  MonomialBase(const CoefficientType &coeff, const APInt &expo)
       : coefficient(coeff), exponent(expo) {}
+  // The destructor is defined out of line in Polynomial.cpp.
+  virtual ~MonomialBase() = 0;

-  Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {}
+  const CoefficientType &getCoefficient() const { return coefficient; }
+  CoefficientType &getMutableCoefficient() { return coefficient; }
+  const APInt &getExponent() const { return exponent; }
+  void setCoefficient(const CoefficientType &coeff) { coefficient = coeff; }
+  void setExponent(const APInt &exp) { exponent = exp; }

-  bool operator==(const Monomial &other) const {
+  bool operator==(const MonomialBase &other) const {
     return other.coefficient == coefficient && other.exponent == exponent;
   }
-  bool operator!=(const Monomial &other) const {
+  bool operator!=(const MonomialBase &other) const {
     return other.coefficient != coefficient || other.exponent != exponent;
   }

   /// Monomials are ordered by exponent.
-  bool operator<(const Monomial &other) const {
+  bool operator<(const MonomialBase &other) const {
     return (exponent.ult(other.exponent));
   }

-  friend ::llvm::hash_code hash_value(const Monomial &arg);
+  /// Returns true if the coefficient equals one.
+  virtual bool isMonic() const = 0;
+  /// Writes the textual form of the coefficient into `coeffString`.
+  virtual void coefficientToString(llvm::SmallString<16> &coeffString) const = 0;

-public:
-  APInt coefficient;
+  template <typename T>
+  friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);

-  // Always unsigned
+protected:
+  CoefficientType coefficient;
+  // Always unsigned.
   APInt exponent;
 };

-/// A single-variable polynomial with integer coefficients.
-///
-/// Eg: x^1024 + x + 1
-///
-/// The symbols used as the polynomial's indeterminate don't matter, so long as
-/// it is used consistently throughout the polynomial.
-class Polynomial {
+/// A class representing a monomial of a single-variable polynomial with integer
+/// coefficients.
+class IntMonomial : public MonomialBase<APInt> {
 public:
-  Polynomial() = delete;
+  IntMonomial(int64_t coeff, uint64_t expo)
+      : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}

-  explicit Polynomial(ArrayRef<Monomial> terms) : terms(terms){};
+  IntMonomial()
+      : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}

-  // Returns a Polynomial from a list of monomials.
-  // Fails if two monomials have the same exponent.
-  static FailureOr<Polynomial> fromMonomials(ArrayRef<Monomial> monomials);
+  ~IntMonomial() override = default;

-  /// Returns a polynomial with coefficients given by `coeffs`. The value
-  /// coeffs[i] is converted to a monomial with exponent i.
-  static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs);
+  bool isMonic() const override { return coefficient == 1; }
+
+  void coefficientToString(llvm::SmallString<16> &coeffString) const override {
+    coefficient.toStringSigned(coeffString);
+  }
+};
+
+/// A class representing a monomial of a single-variable polynomial with double
+/// precision floating point coefficients.
+class FloatMonomial : public MonomialBase<APFloat> {
+public:
+  FloatMonomial(double coeff, uint64_t expo)
+      : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
+
+  FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
+
+  ~FloatMonomial() override = default;
+
+  bool isMonic() const override { return coefficient == APFloat(1.0); }
+
+  void coefficientToString(llvm::SmallString<16> &coeffString) const override {
+    coefficient.toString(coeffString);
+  }
+};
+
+/// Base class for a single-variable polynomial whose terms have type
+/// `Monomial`.
+///
+/// The symbols used as the polynomial's indeterminate don't matter, so long as
+/// it is used consistently throughout the polynomial.
+template <typename Monomial>
+class PolynomialBase {
+public:
+  PolynomialBase() = delete;
+
+  explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};

   explicit operator bool() const { return !terms.empty(); }
-  bool operator==(const Polynomial &other) const {
+  bool operator==(const PolynomialBase &other) const {
     return other.terms == terms;
   }
-  bool operator!=(const Polynomial &other) const {
+  bool operator!=(const PolynomialBase &other) const {
     return !(other.terms == terms);
   }

-  // Prints polynomial to 'os'.
-  void print(raw_ostream &os) const;
+  /// Prints the terms to `os` joined by `separator`, using `x` as the
+  /// indeterminate and writing `exponentiation` before exponents greater than
+  /// one. A coefficient equal to one is omitted unless the exponent is zero.
   void print(raw_ostream &os, ::llvm::StringRef separator,
-             ::llvm::StringRef exponentiation) const;
+             ::llvm::StringRef exponentiation) const {
+    bool first = true;
+    for (const Monomial &term : getTerms()) {
+      if (first) {
+        first = false;
+      } else {
+        os << separator;
+      }
+      std::string coeffToPrint;
+      if (term.isMonic() && term.getExponent().uge(1)) {
+        coeffToPrint = "";
+      } else {
+        llvm::SmallString<16> coeffString;
+        term.coefficientToString(coeffString);
+        coeffToPrint = coeffString.str();
+      }
+
+      if (term.getExponent() == 0) {
+        os << coeffToPrint;
+      } else if (term.getExponent() == 1) {
+        os << coeffToPrint << "x";
+      } else {
+        llvm::SmallString<16> expString;
+        term.getExponent().toStringSigned(expString);
+        os << coeffToPrint << "x" << exponentiation << expString;
+      }
+    }
+  }
+
+  // Prints polynomial to 'os'.
+  void print(raw_ostream &os) const { print(os, " + ", "**"); }
+
   void dump() const;

   // Prints polynomial so that it can be used as a valid identifier
-  std::string toIdentifier() const;
+  // NOTE(review): coefficient text (e.g. '.' in floats, '-' for negative
+  // values) may not be valid in an identifier; confirm callers tolerate it.
+  std::string toIdentifier() const {
+    std::string result;
+    llvm::raw_string_ostream os(result);
+    print(os, "_", "");
+    return os.str();
+  }

-  unsigned getDegree() const;
+  /// Returns the exponent of the last term, which is the degree when the terms
+  /// are sorted by increasing exponent. The polynomial must be non-empty.
+  unsigned getDegree() const {
+    assert(!terms.empty() && "getDegree called on an empty polynomial");
+    return terms.back().getExponent().getZExtValue();
+  }

   ArrayRef<Monomial> getTerms() const { return terms; }

-  friend ::llvm::hash_code hash_value(const Polynomial &arg);
+  template <typename T>
+  friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);

 private:
   // The monomial terms for this polynomial.
   SmallVector<Monomial> terms;
 };

-// Make Polynomial hashable.
-inline ::llvm::hash_code hash_value(const Polynomial &arg) {
+/// A single-variable polynomial with integer coefficients.
+///
+/// Eg: x^1024 + x + 1
+class IntPolynomial : public PolynomialBase<IntMonomial> {
+public:
+  explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
+
+  // Returns a Polynomial from a list of monomials.
+  // Fails if two monomials have the same exponent.
+  static FailureOr<IntPolynomial>
+  fromMonomials(ArrayRef<IntMonomial> monomials);
+
+  /// Returns a polynomial with coefficients given by `coeffs`. The value
+  /// coeffs[i] is converted to a monomial with exponent i.
+  static IntPolynomial fromCoefficients(ArrayRef<int64_t> coeffs);
+};
+
+/// A single-variable polynomial with double coefficients.
+///
+/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
+class FloatPolynomial : public PolynomialBase<FloatMonomial> {
+public:
+  explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
+      : PolynomialBase(terms) {}
+
+  // Returns a Polynomial from a list of monomials.
+  // Fails if two monomials have the same exponent.
+  static FailureOr<FloatPolynomial>
+  fromMonomials(ArrayRef<FloatMonomial> monomials);
+
+  /// Returns a polynomial with coefficients given by `coeffs`. The value
+  /// coeffs[i] is converted to a monomial with exponent i.
+  static FloatPolynomial fromCoefficients(ArrayRef<double> coeffs);
+};
+
+// Make Polynomials hashable.
+template <typename T>
+inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
   return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
 }

-inline ::llvm::hash_code hash_value(const Monomial &arg) {
+template <typename T>
+inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
   return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
                             ::llvm::hash_value(arg.exponent));
 }

-inline raw_ostream &operator<<(raw_ostream &os, const Polynomial &polynomial) {
+template <typename T>
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const PolynomialBase<T> &polynomial) {
   polynomial.print(os);
   return os;
 }
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index ed1f4ce8b7e599..c55ede2a41af20 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -60,12 +60,12 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
   let mnemonic = attrMnemonic;
 }
 
-def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
-  let summary = "An attribute containing a single-variable polynomial.";
+def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
+  let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
   let description = [{
-    A polynomial attribute represents a single-variable polynomial, which
-    is used to define the modulus of a `RingAttr`, as well as to define constants
-    and perform constant folding for `polynomial` ops.
+    A polynomial attribute represents a single-variable polynomial with integer
+    coefficients, which is used to define the modulus of a `RingAttr`, as well
+    as to define constants and perform constant folding for `polynomial` ops.
 
     The polynomial must be expressed as a list of monomial terms, with addition
     or subtraction between them. The choice of variable name is arbitrary, but
@@ -76,10 +76,32 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
     Example:

     ```mlir
-    #poly = #polynomial.polynomial<x**1024 + 1>
+    #poly = #polynomial.int_polynomial<x**1024 + 1>
     ```
   }];
-  let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
+  let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial);
+  let hasCustomAssemblyFormat = 1;
+}
+
+def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
+  let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
+  let description = [{
+    A polynomial attribute represents a single-variable polynomial with double
+    precision floating point coefficients.
+
+    The polynomial must be expressed as a list of monomial terms, with addition
+    or subtraction between them. The choice of variable name is arbitrary, but
+    must be consistent across all the monomials used to define a single
+    attribute. The order of monomial terms is arbitrary, each monomial degree
+    must occur at most once.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.float_polynomial<0.5 x**7 + 1.5>
+    ```
+  }];
+  let parameters = (ins "::mlir::polynomial::FloatPolynomial":$polynomial);
   let hasCustomAssemblyFormat = 1;
 }
 
@@ -123,7 +146,7 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
+    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
     OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
   );
 
@@ -131,7 +154,7 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     AttrBuilder<
         (ins "::mlir::Type":$coefficientTy,
              "::mlir::IntegerAttr":$coefficientModulusAttr,
-             "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
+             "::mlir::polynomial::IntPolynomialAttr":$polynomialModulusAttr), [{
       return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
     }]>
   ];
@@ -405,10 +428,17 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
   let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";

   let hasVerifier = 1;
 }

+// Either a float or an integer polynomial attribute; used as the input of
+// polynomial.constant.
+def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
+  Polynomial_FloatPolynomialAttr,
+  Polynomial_IntPolynomialAttr
+]>;
+
 def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
   let summary = "Define a constant polynomial via an attribute.";
   let description = [{
@@ -420,9 +447,9 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
     %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     ```
   }];
-  let arguments = (ins Polynomial_PolynomialAttr:$input);
+  let arguments = (ins Polynomial_AnyPolynomialAttr:$input);
   let results = (outs Polynomial_PolynomialType:$output);
-  let assemblyFormat = "$input attr-dict `:` type($output)";
+  let assemblyFormat = "operands attr-dict `:` type($output)";
 }
 
 def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
index 5916ffba78e246..9d0d38ba927e25 100644
--- a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
@@ -9,87 +9,72 @@
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"

 #include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/raw_ostream.h"

 namespace mlir {
 namespace polynomial {

-FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) {
+template <typename T>
+MonomialBase<T>::~MonomialBase() {}
+
+// The pure virtual destructor above is defined only in this file, so
+// explicitly instantiate MonomialBase for both coefficient types; otherwise
+// other translation units that destroy monomials may fail to link.
+template class MonomialBase<APInt>;
+template class MonomialBase<APFloat>;
+
+/// Returns a `PolyT` whose terms are `monomials` sorted by increasing exponent,
+/// or failure if two monomials share an exponent.
+template <typename PolyT, typename MonomialT>
+static FailureOr<PolyT> fromMonomialsImpl(ArrayRef<MonomialT> monomials) {
   // A polynomial's terms are canonically stored in order of increasing degree.
-  auto monomialsCopy = llvm::SmallVector<Monomial>(monomials);
+  auto monomialsCopy = llvm::SmallVector<MonomialT>(monomials);
   std::sort(monomialsCopy.begin(), monomialsCopy.end());

   // Ensure non-unique exponents are not present. Since we sorted the list by
-  // exponent, a linear scan of adjancent monomials suffices.
+  // exponent, a linear scan of adjacent monomials suffices.
   if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(),
-                         [](const Monomial &lhs, const Monomial &rhs) {
-                           return lhs.exponent == rhs.exponent;
+                         [](const MonomialT &lhs, const MonomialT &rhs) {
+                           return lhs.getExponent() == rhs.getExponent();
                          }) != monomialsCopy.end()) {
     return failure();
   }

-  return Polynomial(monomialsCopy);
+  return PolyT(monomialsCopy);
+}
+
+FailureOr<IntPolynomial>
+IntPolynomial::fromMonomials(ArrayRef<IntMonomial> monomials) {
+  return fromMonomialsImpl<IntPolynomial, IntMonomial>(monomials);
+}
+
+FailureOr<FloatPolynomial>
+FloatPolynomial::fromMonomials(ArrayRef<FloatMonomial> monomials) {
+  return fromMonomialsImpl<FloatPolynomial, FloatMonomial>(monomials);
 }

-Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
-  llvm::SmallVector<Monomial> monomials;
+/// Returns a `PolyT` whose i-th term has coefficient coeffs[i] and exponent i.
+template <typename PolyT, typename MonomialT, typename CoeffT>
+static PolyT fromCoefficientsImpl(ArrayRef<CoeffT> coeffs) {
+  llvm::SmallVector<MonomialT> monomials;
   auto size = coeffs.size();
   monomials.reserve(size);
   for (size_t i = 0; i < size; i++) {
     monomials.emplace_back(coeffs[i], i);
   }
-  auto result = Polynomial::fromMonomials(monomials);
+  auto result = PolyT::fromMonomials(monomials);
   // Construction guarantees unique exponents, so the failure mode of
   // fromMonomials can be bypassed.
   assert(succeeded(result));
   return result.value();
 }

-void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator,
-                       ::llvm::StringRef exponentiation) const {
-  bool first = true;
-  for (const Monomial &term : terms) {
-    if (first) {
-      first = false;
-    } else {
-      os << separator;
-    }
-    std::string coeffToPrint;
-    if (term.coefficient == 1 && term.exponent.uge(1)) {
-      coeffToPrint = "";
-    } else {
-      llvm::SmallString<16> coeffString;
-      term.coefficient.toStringSigned(coeffString);
-      coeffToPrint = coeffString.str();
-    }
-
-    if (term.exponent == 0) {
-      os << coeffToPrint;
-    } else if (term.exponent == 1) {
-      os << coeffToPrint << "x";
-    } else {
-      llvm::SmallString<16> expString;
-      term.exponent.toStringSigned(expString);
-      os << coeffToPrint << "x" << exponentiation << expString;
-    }
-  }
-}
-
-void Polynomial::print(raw_ostream &os) const { print(os, " + ", "**"); }
-
-std::string Polynomial::toIdentifier() const {
-  std::string result;
-  llvm::raw_string_ostream os(result);
-  print(os, "_", "");
-  return os.str();
+IntPolynomial IntPolynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
+  return fromCoefficientsImpl<IntPolynomial, IntMonomial, int64_t>(coeffs);
 }

-unsigned Polynomial::getDegree() const {
-  return terms.back().exponent.getZExtValue();
+FloatPolynomial FloatPolynomial::fromCoefficients(ArrayRef<double> coeffs) {
+  return fromCoefficientsImpl<FloatPolynomial, FloatMonomial, double>(coeffs);
 }

 } // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 236bb789663529..c91ad9b979879c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
@@ -17,22 +18,33 @@
 namespace mlir {
 namespace polynomial {
 
-void PolynomialAttr::print(AsmPrinter &p) const {
-  p << '<';
-  p << getPolynomial();
-  p << '>';
+void IntPolynomialAttr::print(AsmPrinter &p) const {
+  p << '<' << getPolynomial() << '>';
 }

+void FloatPolynomialAttr::print(AsmPrinter &p) const {
+  p << '<' << getPolynomial() << '>';
+}
+
+/// A callable that parses a coefficient of type `CoefficientType` and, on
+/// success, stores it into the referenced value (a monomial's coefficient).
+/// An empty result means no coefficient was parsed.
+template <typename CoefficientType>
+using ParseCoefficientFn =
+    std::function<OptionalParseResult(CoefficientType &)>;
+
 /// Try to parse a monomial. If successful, populate the fields of the outparam
 /// `monomial` with the results, and the `variable` outparam with the parsed
 /// variable name. Sets shouldParseMore to true if the monomial is followed by
 /// a '+'.
-ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
-                          llvm::StringRef &variable, bool &isConstantTerm,
-                          bool &shouldParseMore) {
-  APInt parsedCoeff(apintBitWidth, 1);
-  auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);
-  monomial.coefficient = parsedCoeff;
+/// The coefficient is parsed and stored via `parseAndStoreCoefficient`.
+template <typename Monomial, typename CoefficientType>
+ParseResult
+parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
+              bool &isConstantTerm, bool &shouldParseMore,
+              ParseCoefficientFn<CoefficientType> parseAndStoreCoefficient) {
+  OptionalParseResult parsedCoeffResult =
+      parseAndStoreCoefficient(monomial.getMutableCoefficient());
 
   isConstantTerm = false;
   shouldParseMore = false;
@@ -44,7 +56,7 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
     if (!parsedCoeffResult.has_value()) {
       return failure();
     }
-    monomial.exponent = APInt(apintBitWidth, 0);
+    monomial.setExponent(APInt(apintBitWidth, 0));
     isConstantTerm = true;
     shouldParseMore = true;
     return success();
@@ -58,7 +70,7 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
       return failure();
     }
 
-    monomial.exponent = APInt(apintBitWidth, 0);
+    monomial.setExponent(APInt(apintBitWidth, 0));
     isConstantTerm = true;
     return success();
   }
@@ -80,9 +92,9 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
       return failure();
     }
 
-    monomial.exponent = parsedExponent;
+    monomial.setExponent(parsedExponent);
   } else {
-    monomial.exponent = APInt(apintBitWidth, 1);
+    monomial.setExponent(APInt(apintBitWidth, 1));
   }
 
   if (succeeded(parser.parseOptionalPlus())) {
@@ -91,22 +103,26 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
   return success();
 }

-Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-
-  llvm::SmallVector<Monomial> monomials;
-  llvm::StringSet<> variables;
-
+/// Parses the monomial terms of a polynomial attribute, populating `monomials`
+/// and collecting the variable names used in `variables`. Emits an error and
+/// fails on a malformed monomial, on a term not followed by `+` or `>`, or
+/// when more than one distinct variable name appears. `PolynomialAttrTy` is
+/// not used by the parsing logic.
+template <typename PolynomialAttrTy, typename Monomial, typename CoefficientTy>
+LogicalResult parsePolynomialAttr(
+    AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
+    llvm::StringSet<> &variables,
+    ParseCoefficientFn<CoefficientTy> parseAndStoreCoefficient) {
   while (true) {
     Monomial parsedMonomial;
     llvm::StringRef parsedVariableRef;
     bool isConstantTerm;
     bool shouldParseMore;
-    if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
-                             isConstantTerm, shouldParseMore))) {
+    if (failed(parseMonomial<Monomial, CoefficientTy>(
+            parser, parsedMonomial, parsedVariableRef, isConstantTerm,
+            shouldParseMore, parseAndStoreCoefficient))) {
       parser.emitError(parser.getCurrentLocation(), "expected a monomial");
-      return {};
+      return failure();
     }
 
     if (!isConstantTerm) {
@@ -124,7 +135,7 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
     parser.emitError(
         parser.getCurrentLocation(),
         "expected + and more monomials, or > to end polynomial attribute");
-    return {};
+    return failure();
   }
 
   if (variables.size() > 1) {
@@ -133,15 +144,70 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
         parser.getCurrentLocation(),
         "polynomials must have one indeterminate, but there were multiple: " +
             vars);
+    return failure();
+  }
+
+  return success();
+}
+
+Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  llvm::SmallVector<IntMonomial> monomials;
+  llvm::StringSet<> variables;
+
+  if (failed(parsePolynomialAttr<IntPolynomialAttr, IntMonomial, APInt>(
+          parser, monomials, variables,
+          [&](APInt &coeff) -> OptionalParseResult {
+            // Terms written without a coefficient (e.g. `x**2`) have
+            // coefficient 1, but the monomial starts at 0; seed 1 first.
+            coeff = APInt(apintBitWidth, 1);
+            return parser.parseOptionalInteger(coeff);
+          }))) {
+    return {};
+  }
+
+  auto result = IntPolynomial::fromMonomials(monomials);
+  if (failed(result)) {
+    parser.emitError(parser.getCurrentLocation())
+        << "parsed polynomial must have unique exponents among monomials";
+    return {};
+  }
+  return IntPolynomialAttr::get(parser.getContext(), result.value());
+}
+
+Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  llvm::SmallVector<FloatMonomial> monomials;
+  llvm::StringSet<> variables;
+
+  ParseCoefficientFn<APFloat> parseAndStoreCoefficient =
+      [&](APFloat &coeff) -> OptionalParseResult {
+    // NOTE(review): parseFloat is not optional, so terms likely need an
+    // explicit coefficient (e.g. `1.0 x**2`); confirm this is intended.
+    double coeffValue;
+    ParseResult result = parser.parseFloat(coeffValue);
+    if (succeeded(result)) {
+      coeff = APFloat(coeffValue);
+    }
+    return OptionalParseResult(result);
+  };
+
+  if (failed(parsePolynomialAttr<FloatPolynomialAttr, FloatMonomial, APFloat>(
+          parser, monomials, variables, parseAndStoreCoefficient))) {
+    return {};
   }
 
-  auto result = Polynomial::fromMonomials(monomials);
+  auto result = FloatPolynomial::fromMonomials(monomials);
   if (failed(result)) {
     parser.emitError(parser.getCurrentLocation())
         << "parsed polynomial must have unique exponents among monomials";
     return {};
   }
-  return PolynomialAttr::get(parser.getContext(), result.value());
+  return FloatPolynomialAttr::get(parser.getContext(), result.value());
 }
 
 void RingAttr::print(AsmPrinter &p) const {
@@ -191,18 +252,18 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
       return {};
   }
 
-  PolynomialAttr polyAttr = nullptr;
+  IntPolynomialAttr polyAttr = nullptr;
   if (succeeded(parser.parseKeyword("polynomialModulus"))) {
     if (failed(parser.parseEqual()))
       return {};
 
-    PolynomialAttr attr;
-    if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
+    IntPolynomialAttr attr;
+    if (failed(parser.parseAttribute<IntPolynomialAttr>(attr)))
       return {};
     polyAttr = attr;
   }
 
-  Polynomial poly = polyAttr.getPolynomial();
+  IntPolynomial poly = polyAttr.getPolynomial();
   APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
   IntegerAttr rootAttr = nullptr;
   if (succeeded(parser.parseOptionalComma())) {



More information about the Mlir-commits mailing list