[Mlir-commits] [llvm] [mlir] Add constant propagation for polynomial ops (PR #91655)

Jeremy Kun llvmlistbot at llvm.org
Thu May 9 14:07:55 PDT 2024


https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/91655

Rebased over:

- https://github.com/llvm/llvm-project/pull/91410
- https://github.com/llvm/llvm-project/pull/91137 

>From 10cfbcf6107d12ed084fb446af32fa430e653cee Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 3 May 2024 17:27:07 -0700
Subject: [PATCH 1/9] Support polynomials with float coefficients

Also involves significant refactoring of the attribute parser/printer
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.h   | 192 ++++++++++++++----
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  |  88 +++++---
 mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp |  79 +++----
 .../Polynomial/IR/PolynomialAttributes.cpp    | 172 +++++++---------
 mlir/test/Dialect/Polynomial/attributes.mlir  |  22 +-
 mlir/test/Dialect/Polynomial/ops.mlir         |  22 +-
 mlir/test/Dialect/Polynomial/ops_errors.mlir  |  46 ++---
 mlir/test/Dialect/Polynomial/types.mlir       |  49 +++--
 8 files changed, 383 insertions(+), 287 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 3325a6fa3f9fc..5705deeadf730 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -11,10 +11,13 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Hashing.h"
-#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 
@@ -27,98 +30,201 @@ namespace polynomial {
 /// would want to specify 128-bit polynomials statically in the source code.
 constexpr unsigned apintBitWidth = 64;
 
-/// A class representing a monomial of a single-variable polynomial with integer
-/// coefficients.
-class Monomial {
+template <typename CoefficientType>
+class MonomialBase {
 public:
-  Monomial(int64_t coeff, uint64_t expo)
-      : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
-
-  Monomial(const APInt &coeff, const APInt &expo)
+  MonomialBase(const CoefficientType &coeff, const APInt &expo)
       : coefficient(coeff), exponent(expo) {}
+  virtual ~MonomialBase() = 0;
 
-  Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {}
+  const CoefficientType &getCoefficient() const { return coefficient; }
+  CoefficientType &getMutableCoefficient() { return coefficient; }
+  const APInt &getExponent() const { return exponent; }
+  void setCoefficient(const CoefficientType &coeff) { coefficient = coeff; }
+  void setExponent(const APInt &exp) { exponent = exp; }
 
-  bool operator==(const Monomial &other) const {
+  bool operator==(const MonomialBase &other) const {
     return other.coefficient == coefficient && other.exponent == exponent;
   }
-  bool operator!=(const Monomial &other) const {
+  bool operator!=(const MonomialBase &other) const {
     return other.coefficient != coefficient || other.exponent != exponent;
   }
 
   /// Monomials are ordered by exponent.
-  bool operator<(const Monomial &other) const {
+  bool operator<(const MonomialBase &other) const {
     return (exponent.ult(other.exponent));
   }
 
-  friend ::llvm::hash_code hash_value(const Monomial &arg);
+  virtual bool isMonic() const = 0;
+  virtual void coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
 
-public:
-  APInt coefficient;
+  template <typename T>
+  friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
 
-  // Always unsigned
+protected:
+  CoefficientType coefficient;
   APInt exponent;
 };
 
-/// A single-variable polynomial with integer coefficients.
-///
-/// Eg: x^1024 + x + 1
-///
-/// The symbols used as the polynomial's indeterminate don't matter, so long as
-/// it is used consistently throughout the polynomial.
-class Polynomial {
+/// A class representing a monomial of a single-variable polynomial with integer
+/// coefficients.
+class IntMonomial : public MonomialBase<APInt> {
 public:
-  Polynomial() = delete;
+  IntMonomial(int64_t coeff, uint64_t expo)
+      : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
 
-  explicit Polynomial(ArrayRef<Monomial> terms) : terms(terms){};
+  IntMonomial()
+      : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
 
-  // Returns a Polynomial from a list of monomials.
-  // Fails if two monomials have the same exponent.
-  static FailureOr<Polynomial> fromMonomials(ArrayRef<Monomial> monomials);
+  ~IntMonomial() = default;
 
-  /// Returns a polynomial with coefficients given by `coeffs`. The value
-  /// coeffs[i] is converted to a monomial with exponent i.
-  static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs);
+  bool isMonic() const override { return coefficient == 1; }
+
+  void coefficientToString(llvm::SmallString<16> &coeffString) const override {
+    coefficient.toStringSigned(coeffString);
+  }
+};
+
+/// A class representing a monomial of a single-variable polynomial with
+/// floating point coefficients.
+class FloatMonomial : public MonomialBase<APFloat> {
+public:
+  FloatMonomial(double coeff, uint64_t expo)
+      : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
+
+  FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
+
+  ~FloatMonomial() = default;
+
+  bool isMonic() const override { return coefficient == APFloat(1.0); }
+
+  void coefficientToString(llvm::SmallString<16> &coeffString) const override {
+    coefficient.toString(coeffString);
+  }
+};
+
+template <typename Monomial>
+class PolynomialBase {
+public:
+  PolynomialBase() = delete;
+
+  explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
 
   explicit operator bool() const { return !terms.empty(); }
-  bool operator==(const Polynomial &other) const {
+  bool operator==(const PolynomialBase &other) const {
     return other.terms == terms;
   }
-  bool operator!=(const Polynomial &other) const {
+  bool operator!=(const PolynomialBase &other) const {
     return !(other.terms == terms);
   }
 
-  // Prints polynomial to 'os'.
-  void print(raw_ostream &os) const;
   void print(raw_ostream &os, ::llvm::StringRef separator,
-             ::llvm::StringRef exponentiation) const;
+             ::llvm::StringRef exponentiation) const {
+    bool first = true;
+    for (const Monomial &term : getTerms()) {
+      if (first) {
+        first = false;
+      } else {
+        os << separator;
+      }
+      std::string coeffToPrint;
+      if (term.isMonic() && term.getExponent().uge(1)) {
+        coeffToPrint = "";
+      } else {
+        llvm::SmallString<16> coeffString;
+        term.coefficientToString(coeffString);
+        coeffToPrint = coeffString.str();
+      }
+
+      if (term.getExponent() == 0) {
+        os << coeffToPrint;
+      } else if (term.getExponent() == 1) {
+        os << coeffToPrint << "x";
+      } else {
+        llvm::SmallString<16> expString;
+        term.getExponent().toStringSigned(expString);
+        os << coeffToPrint << "x" << exponentiation << expString;
+      }
+    }
+  }
+
+  // Prints polynomial to 'os'.
+  void print(raw_ostream &os) const { print(os, " + ", "**"); }
+
   void dump() const;
 
   // Prints polynomial so that it can be used as a valid identifier
-  std::string toIdentifier() const;
+  std::string toIdentifier() const {
+    std::string result;
+    llvm::raw_string_ostream os(result);
+    print(os, "_", "");
+    return os.str();
+  }
 
-  unsigned getDegree() const;
+  unsigned getDegree() const {
+    return terms.back().getExponent().getZExtValue();
+  }
 
   ArrayRef<Monomial> getTerms() const { return terms; }
 
-  friend ::llvm::hash_code hash_value(const Polynomial &arg);
+  template <typename T>
+  friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
 
 private:
   // The monomial terms for this polynomial.
   SmallVector<Monomial> terms;
 };
 
-// Make Polynomial hashable.
-inline ::llvm::hash_code hash_value(const Polynomial &arg) {
+/// A single-variable polynomial with integer coefficients.
+///
+/// Eg: x^1024 + x + 1
+class IntPolynomial : public PolynomialBase<IntMonomial> {
+public:
+  explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
+
+  // Returns a Polynomial from a list of monomials.
+  // Fails if two monomials have the same exponent.
+  static FailureOr<IntPolynomial>
+  fromMonomials(ArrayRef<IntMonomial> monomials);
+
+  /// Returns a polynomial with coefficients given by `coeffs`. The value
+  /// coeffs[i] is converted to a monomial with exponent i.
+  static IntPolynomial fromCoefficients(ArrayRef<int64_t> coeffs);
+};
+
+/// A single-variable polynomial with double coefficients.
+///
+/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
+class FloatPolynomial : public PolynomialBase<FloatMonomial> {
+public:
+  explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
+      : PolynomialBase(terms) {}
+
+  // Returns a Polynomial from a list of monomials.
+  // Fails if two monomials have the same exponent.
+  static FailureOr<FloatPolynomial>
+  fromMonomials(ArrayRef<FloatMonomial> monomials);
+
+  /// Returns a polynomial with coefficients given by `coeffs`. The value
+  /// coeffs[i] is converted to a monomial with exponent i.
+  static FloatPolynomial fromCoefficients(ArrayRef<double> coeffs);
+};
+
+// Make Polynomials hashable.
+template <typename T>
+inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
   return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
 }
 
-inline ::llvm::hash_code hash_value(const Monomial &arg) {
+template <typename T>
+inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
   return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
                             ::llvm::hash_value(arg.exponent));
 }
 
-inline raw_ostream &operator<<(raw_ostream &os, const Polynomial &polynomial) {
+template <typename T>
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const PolynomialBase<T> &polynomial) {
   polynomial.print(os);
   return os;
 }
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index ed1f4ce8b7e59..08ab1f1811eed 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -46,7 +46,7 @@ def Polynomial_Dialect : Dialect {
     // A constant polynomial in a ring with i32 coefficients, with a polynomial
     // modulus of (x^1024 + 1) and a coefficient modulus of 17.
     #modulus = #polynomial.polynomial<1 + x**1024>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17, polynomialModulus=#modulus>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#modulus>
     %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
     ```
   }];
@@ -60,12 +60,12 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
   let mnemonic = attrMnemonic;
 }
 
-def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
-  let summary = "An attribute containing a single-variable polynomial.";
+def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
+  let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
   let description = [{
-    A polynomial attribute represents a single-variable polynomial, which
-    is used to define the modulus of a `RingAttr`, as well as to define constants
-    and perform constant folding for `polynomial` ops.
+    A polynomial attribute represents a single-variable polynomial with integer
+    coefficients, which is used to define the modulus of a `RingAttr`, as well
+    as to define constants and perform constant folding for `polynomial` ops.
 
     The polynomial must be expressed as a list of monomial terms, with addition
     or subtraction between them. The choice of variable name is arbitrary, but
@@ -76,10 +76,32 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
     Example:
 
     ```mlir
-    #poly = #polynomial.polynomial<x**1024 + 1>
+    #poly = #polynomial.int_polynomial<x**1024 + 1>
     ```
   }];
-  let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
+  let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial);
+  let hasCustomAssemblyFormat = 1;
+}
+
+def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
+  let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
+  let description = [{
+    A polynomial attribute represents a single-variable polynomial with double
+    precision floating point coefficients.
+
+    The polynomial must be expressed as a list of monomial terms, with addition
+    or subtraction between them. The choice of variable name is arbitrary, but
+    must be consistent across all the monomials used to define a single
+    attribute. The order of monomial terms is arbitrary, each monomial degree
+    must occur at most once.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.float_polynomial<0.5 x**7 + 1.5>
+    ```
+  }];
+  let parameters = (ins "FloatPolynomial":$polynomial);
   let hasCustomAssemblyFormat = 1;
 }
 
@@ -106,7 +128,7 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     ```mlir
     #poly_mod = #polynomial.polynomial<-1 + x**1024>
     #ring = #polynomial.ring<coefficientType=i32,
-                             coefficientModulus=4294967291,
+                             coefficientModulus=4294967291:i32,
                              polynomialModulus=#poly_mod>
 
     %0 = ... : polynomial.polynomial<#ring>
@@ -123,19 +145,20 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
+    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
     OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
   );
-
+  let assemblyFormat = "`<` struct(params) `>`";
   let builders = [
     AttrBuilder<
         (ins "::mlir::Type":$coefficientTy,
-             "::mlir::IntegerAttr":$coefficientModulusAttr,
-             "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
-      return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
-    }]>
+              CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
+              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
+              CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+      return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, primitiveRootAttr);
+    }]>,
   ];
-  let hasCustomAssemblyFormat = 1;
+  let skipDefaultBuilders = 1;
 }
 
 class Polynomial_Type<string name, string typeMnemonic>
@@ -149,7 +172,7 @@ def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> {
     A type for polynomials in a polynomial quotient ring.
   }];
   let parameters = (ins Polynomial_RingAttr:$ring);
-  let assemblyFormat = "`<` $ring `>`";
+  let assemblyFormat = "`<` qualified($ring) `>`";
 }
 
 def PolynomialLike: TypeOrContainer<Polynomial_PolynomialType, "polynomial-like">;
@@ -188,7 +211,7 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
     ```mlir
     // add two polynomials modulo x^1024 - 1
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
@@ -212,7 +235,7 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
     ```mlir
     // subtract two polynomials modulo x^1024 - 1
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
@@ -236,7 +259,7 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
     ```mlir
     // multiply two polynomials modulo x^1024 - 1
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
@@ -261,7 +284,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
     ```mlir
     // multiply two polynomials modulo x^1024 - 1
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1 = arith.constant 3 : i32
     %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
@@ -292,7 +315,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
 
     ```mlir
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
     ```
@@ -315,7 +338,7 @@ def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
 
     ```mlir
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %deg = arith.constant 1023 : index
     %five = arith.constant 5 : i32
     %0 = polynomial.monomial %five, %deg : (i32, index) -> !polynomial.polynomial<#ring>
@@ -355,7 +378,7 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
 
     ```mlir
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %two = arith.constant 2 : i32
     %five = arith.constant 5 : i32
     %coeffs = tensor.from_elements %two, %two, %five : tensor<3xi32>
@@ -394,7 +417,7 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
 
     ```mlir
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %two = arith.constant 2 : i32
     %five = arith.constant 5 : i32
     %coeffs = tensor.from_elements %two, %two, %five : tensor<3xi32>
@@ -405,24 +428,29 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
   let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
-
   let hasVerifier = 1;
 }
 
-def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
+def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
+  Polynomial_FloatPolynomialAttr,
+  Polynomial_IntPolynomialAttr
+]>;
+
+// Not deriving from Polynomial_Op due to need for custom assembly format
+def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
   let summary = "Define a constant polynomial via an attribute.";
   let description = [{
     Example:
 
     ```mlir
     #poly = #polynomial.polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     ```
   }];
-  let arguments = (ins Polynomial_PolynomialAttr:$input);
+  let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
   let results = (outs Polynomial_PolynomialType:$output);
-  let assemblyFormat = "$input attr-dict `:` type($output)";
+  let assemblyFormat = "attr-dict `:` type($output)";
 }
 
 def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
index 5916ffba78e24..9d0d38ba927e2 100644
--- a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
@@ -9,87 +9,64 @@
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 
 #include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 namespace polynomial {
 
-FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) {
+template <typename T>
+MonomialBase<T>::~MonomialBase() {}
+
+template <typename PolyT, typename MonomialT>
+FailureOr<PolyT> fromMonomialsImpl(ArrayRef<MonomialT> monomials) {
   // A polynomial's terms are canonically stored in order of increasing degree.
-  auto monomialsCopy = llvm::SmallVector<Monomial>(monomials);
+  auto monomialsCopy = llvm::SmallVector<MonomialT>(monomials);
   std::sort(monomialsCopy.begin(), monomialsCopy.end());
 
   // Ensure non-unique exponents are not present. Since we sorted the list by
   // exponent, a linear scan of adjancent monomials suffices.
   if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(),
-                         [](const Monomial &lhs, const Monomial &rhs) {
-                           return lhs.exponent == rhs.exponent;
+                         [](const MonomialT &lhs, const MonomialT &rhs) {
+                           return lhs.getExponent() == rhs.getExponent();
                          }) != monomialsCopy.end()) {
     return failure();
   }
 
-  return Polynomial(monomialsCopy);
+  return PolyT(monomialsCopy);
+}
+
+
+FailureOr<IntPolynomial>
+IntPolynomial::fromMonomials(ArrayRef<IntMonomial> monomials) {
+  return fromMonomialsImpl<IntPolynomial, IntMonomial>(monomials);
+}
+
+FailureOr<FloatPolynomial>
+FloatPolynomial::fromMonomials(ArrayRef<FloatMonomial> monomials) {
+  return fromMonomialsImpl<FloatPolynomial, FloatMonomial>(monomials);
 }
 
-Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
-  llvm::SmallVector<Monomial> monomials;
+template <typename PolyT, typename MonomialT, typename CoeffT>
+PolyT fromCoefficientsImpl(ArrayRef<CoeffT> coeffs) {
+  llvm::SmallVector<MonomialT> monomials;
   auto size = coeffs.size();
   monomials.reserve(size);
   for (size_t i = 0; i < size; i++) {
     monomials.emplace_back(coeffs[i], i);
   }
-  auto result = Polynomial::fromMonomials(monomials);
+  auto result = PolyT::fromMonomials(monomials);
   // Construction guarantees unique exponents, so the failure mode of
   // fromMonomials can be bypassed.
   assert(succeeded(result));
   return result.value();
 }
 
-void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator,
-                       ::llvm::StringRef exponentiation) const {
-  bool first = true;
-  for (const Monomial &term : terms) {
-    if (first) {
-      first = false;
-    } else {
-      os << separator;
-    }
-    std::string coeffToPrint;
-    if (term.coefficient == 1 && term.exponent.uge(1)) {
-      coeffToPrint = "";
-    } else {
-      llvm::SmallString<16> coeffString;
-      term.coefficient.toStringSigned(coeffString);
-      coeffToPrint = coeffString.str();
-    }
-
-    if (term.exponent == 0) {
-      os << coeffToPrint;
-    } else if (term.exponent == 1) {
-      os << coeffToPrint << "x";
-    } else {
-      llvm::SmallString<16> expString;
-      term.exponent.toStringSigned(expString);
-      os << coeffToPrint << "x" << exponentiation << expString;
-    }
-  }
-}
-
-void Polynomial::print(raw_ostream &os) const { print(os, " + ", "**"); }
-
-std::string Polynomial::toIdentifier() const {
-  std::string result;
-  llvm::raw_string_ostream os(result);
-  print(os, "_", "");
-  return os.str();
+IntPolynomial IntPolynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
+  return fromCoefficientsImpl<IntPolynomial, IntMonomial, int64_t>(coeffs);
 }
 
-unsigned Polynomial::getDegree() const {
-  return terms.back().exponent.getZExtValue();
+FloatPolynomial FloatPolynomial::fromCoefficients(ArrayRef<double> coeffs) {
+  return fromCoefficientsImpl<FloatPolynomial, FloatMonomial, double>(coeffs);
 }
 
 } // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 236bb78966352..b5f674c98d835 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
@@ -17,22 +18,31 @@
 namespace mlir {
 namespace polynomial {
 
-void PolynomialAttr::print(AsmPrinter &p) const {
-  p << '<';
-  p << getPolynomial();
-  p << '>';
+void IntPolynomialAttr::print(AsmPrinter &p) const {
+  p << '<' << getPolynomial() << '>';
 }
 
+void FloatPolynomialAttr::print(AsmPrinter &p) const {
+  p << '<' << getPolynomial() << '>';
+}
+
+/// A callable that parses the coefficient using the appropriate method for the
+/// given monomial type, and stores the parsed coefficient value on the
+/// monomial.
+template <typename MonomialType>
+using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
+
 /// Try to parse a monomial. If successful, populate the fields of the outparam
 /// `monomial` with the results, and the `variable` outparam with the parsed
 /// variable name. Sets shouldParseMore to true if the monomial is followed by
 /// a '+'.
-ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
-                          llvm::StringRef &variable, bool &isConstantTerm,
-                          bool &shouldParseMore) {
-  APInt parsedCoeff(apintBitWidth, 1);
-  auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);
-  monomial.coefficient = parsedCoeff;
+/// The coefficient is parsed and stored by `parseAndStoreCoefficient`.
+template <typename Monomial>
+ParseResult
+parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
+              bool &isConstantTerm, bool &shouldParseMore,
+              ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
+  OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);
 
   isConstantTerm = false;
   shouldParseMore = false;
@@ -44,7 +54,7 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
     if (!parsedCoeffResult.has_value()) {
       return failure();
     }
-    monomial.exponent = APInt(apintBitWidth, 0);
+    monomial.setExponent(APInt(apintBitWidth, 0));
     isConstantTerm = true;
     shouldParseMore = true;
     return success();
@@ -58,7 +68,7 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
       return failure();
     }
 
-    monomial.exponent = APInt(apintBitWidth, 0);
+    monomial.setExponent(APInt(apintBitWidth, 0));
     isConstantTerm = true;
     return success();
   }
@@ -80,9 +90,9 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
       return failure();
     }
 
-    monomial.exponent = parsedExponent;
+    monomial.setExponent(parsedExponent);
   } else {
-    monomial.exponent = APInt(apintBitWidth, 1);
+    monomial.setExponent(APInt(apintBitWidth, 1));
   }
 
   if (succeeded(parser.parseOptionalPlus())) {
@@ -91,22 +101,21 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
   return success();
 }
 
-Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
-  if (failed(parser.parseLess()))
-    return {};
-
-  llvm::SmallVector<Monomial> monomials;
-  llvm::StringSet<> variables;
-
+template <typename PolynoimalAttrTy, typename Monomial>
+LogicalResult parsePolynomialAttr(
+    AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
+    llvm::StringSet<> &variables,
+    ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
   while (true) {
     Monomial parsedMonomial;
     llvm::StringRef parsedVariableRef;
     bool isConstantTerm;
     bool shouldParseMore;
-    if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
-                             isConstantTerm, shouldParseMore))) {
+    if (failed(parseMonomial<Monomial>(
+            parser, parsedMonomial, parsedVariableRef, isConstantTerm,
+            shouldParseMore, parseAndStoreCoefficient))) {
       parser.emitError(parser.getCurrentLocation(), "expected a monomial");
-      return {};
+      return failure();
     }
 
     if (!isConstantTerm) {
@@ -124,7 +133,7 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
     parser.emitError(
         parser.getCurrentLocation(),
         "expected + and more monomials, or > to end polynomial attribute");
-    return {};
+    return failure();
   }
 
   if (variables.size() > 1) {
@@ -133,96 +142,67 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
         parser.getCurrentLocation(),
         "polynomials must have one indeterminate, but there were multiple: " +
             vars);
+    return failure();
   }
 
-  auto result = Polynomial::fromMonomials(monomials);
-  if (failed(result)) {
-    parser.emitError(parser.getCurrentLocation())
-        << "parsed polynomial must have unique exponents among monomials";
-    return {};
-  }
-  return PolynomialAttr::get(parser.getContext(), result.value());
-}
-
-void RingAttr::print(AsmPrinter &p) const {
-  p << "#polynomial.ring<coefficientType=" << getCoefficientType()
-    << ", coefficientModulus=" << getCoefficientModulus()
-    << ", polynomialModulus=" << getPolynomialModulus() << '>';
+  return success();
 }
 
-Attribute RingAttr::parse(AsmParser &parser, Type type) {
+Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};
 
-  if (failed(parser.parseKeyword("coefficientType")))
-    return {};
+  llvm::SmallVector<IntMonomial> monomials;
+  llvm::StringSet<> variables;
 
-  if (failed(parser.parseEqual()))
+  if (failed(parsePolynomialAttr<IntPolynomialAttr, IntMonomial>(
+          parser, monomials, variables,
+          [&](IntMonomial &monomial) -> OptionalParseResult {
+            APInt parsedCoeff(apintBitWidth, 1);
+            OptionalParseResult result =
+                parser.parseOptionalInteger(parsedCoeff);
+            monomial.setCoefficient(parsedCoeff);
+            return result;
+          }))) {
     return {};
+  }
 
-  Type ty;
-  if (failed(parser.parseType(ty)))
+  auto result = IntPolynomial::fromMonomials(monomials);
+  if (failed(result)) {
+    parser.emitError(parser.getCurrentLocation())
+        << "parsed polynomial must have unique exponents among monomials";
     return {};
+  }
+  return IntPolynomialAttr::get(parser.getContext(), result.value());
+}
 
-  if (failed(parser.parseComma()))
+Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
+  if (failed(parser.parseLess()))
     return {};
 
-  IntegerAttr coefficientModulusAttr = nullptr;
-  if (succeeded(parser.parseKeyword("coefficientModulus"))) {
-    if (failed(parser.parseEqual()))
-      return {};
-
-    IntegerType iType = mlir::dyn_cast<IntegerType>(ty);
-    if (!iType) {
-      parser.emitError(parser.getCurrentLocation(),
-                       "coefficientType must specify an integer type");
-      return {};
-    }
-    APInt coefficientModulus(iType.getWidth(), 0);
-    auto result = parser.parseInteger(coefficientModulus);
-    if (failed(result)) {
-      parser.emitError(parser.getCurrentLocation(),
-                       "invalid coefficient modulus");
-      return {};
-    }
-    coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);
-
-    if (failed(parser.parseComma()))
-      return {};
-  }
-
-  PolynomialAttr polyAttr = nullptr;
-  if (succeeded(parser.parseKeyword("polynomialModulus"))) {
-    if (failed(parser.parseEqual()))
-      return {};
+  llvm::SmallVector<FloatMonomial> monomials;
+  llvm::StringSet<> variables;
 
-    PolynomialAttr attr;
-    if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
-      return {};
-    polyAttr = attr;
-  }
+  ParseCoefficientFn<FloatMonomial> parseAndStoreCoefficient =
+      [&](FloatMonomial &monomial) -> OptionalParseResult {
+    double coeffValue = 1.0;
+    ParseResult result = parser.parseFloat(coeffValue);
+    monomial.setCoefficient(APFloat(coeffValue));
+    return OptionalParseResult(result);
+  };
 
-  Polynomial poly = polyAttr.getPolynomial();
-  APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
-  IntegerAttr rootAttr = nullptr;
-  if (succeeded(parser.parseOptionalComma())) {
-    if (failed(parser.parseKeyword("primitiveRoot")) ||
-        failed(parser.parseEqual()))
-      return {};
-
-    ParseResult result = parser.parseInteger(root);
-    if (failed(result)) {
-      parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
-      return {};
-    }
-    rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+  if (failed(parsePolynomialAttr<FloatPolynomialAttr, FloatMonomial>(
+          parser, monomials, variables, parseAndStoreCoefficient))) {
+    return {};
   }
 
-  if (failed(parser.parseGreater()))
+  auto result = FloatPolynomial::fromMonomials(monomials);
+  if (failed(result)) {
+    parser.emitError(parser.getCurrentLocation())
+        << "parsed polynomial must have unique exponents among monomials";
     return {};
-
-  return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
-                       polyAttr, rootAttr);
+  }
+  return FloatPolynomialAttr::get(parser.getContext(), result.value());
 }
 
 } // namespace polynomial
diff --git a/mlir/test/Dialect/Polynomial/attributes.mlir b/mlir/test/Dialect/Polynomial/attributes.mlir
index 3973ae3944335..4bdfd44fd4d15 100644
--- a/mlir/test/Dialect/Polynomial/attributes.mlir
+++ b/mlir/test/Dialect/Polynomial/attributes.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s --split-input-file --verify-diagnostics
 
-#my_poly = #polynomial.polynomial<y + x**1024>
+#my_poly = #polynomial.int_polynomial<y + x**1024>
 // expected-error@below {{polynomials must have one indeterminate, but there were multiple: x, y}}
 #ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
 
@@ -9,37 +9,31 @@
 // expected-error@below {{expected integer value}}
 // expected-error@below {{expected a monomial}}
 // expected-error@below {{found invalid integer exponent}}
-#my_poly = #polynomial.polynomial<5 + x**f>
+#my_poly = #polynomial.int_polynomial<5 + x**f>
 #ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
 
 // -----
 
-#my_poly = #polynomial.polynomial<5 + x**2 + 3x**2>
+#my_poly = #polynomial.int_polynomial<5 + x**2 + 3x**2>
 // expected-error@below {{parsed polynomial must have unique exponents among monomials}}
 #ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
 
 // -----
 
 // expected-error@below {{expected + and more monomials, or > to end polynomial attribute}}
-#my_poly = #polynomial.polynomial<5 + x**2 7>
+#my_poly = #polynomial.int_polynomial<5 + x**2 7>
 #ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
 
 // -----
 
 // expected-error@below {{expected a monomial}}
-#my_poly = #polynomial.polynomial<5 + x**2 +>
+#my_poly = #polynomial.int_polynomial<5 + x**2 +>
 #ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
 
 
 // -----
 
-#my_poly = #polynomial.polynomial<5 + x**2>
-// expected-error@below {{coefficientType must specify an integer type}}
-#ring1 = #polynomial.ring<coefficientType=f64, coefficientModulus=2837465, polynomialModulus=#my_poly>
-
-// -----
-
-#my_poly = #polynomial.polynomial<5 + x**2>
-// expected-error@below {{expected integer value}}
-// expected-error@below {{invalid coefficient modulus}}
+#my_poly = #polynomial.int_polynomial<5 + x**2>
+// expected-error@below {{failed to parse Polynomial_RingAttr parameter 'coefficientModulus' which is to be a `::mlir::IntegerAttr`}}
+// expected-error@below {{expected attribute value}}
 #ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=x, polynomialModulus=#my_poly>
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index a29cfc2e9cc54..76dbc9156a282 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -2,18 +2,19 @@
 
 // This simply tests for syntax.
 
-#my_poly = #polynomial.polynomial<1 + x**1024>
-#my_poly_2 = #polynomial.polynomial<2>
-#my_poly_3 = #polynomial.polynomial<3x>
-#my_poly_4 = #polynomial.polynomial<t**3 + 4t + 2>
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#my_poly_2 = #polynomial.int_polynomial<2>
+#my_poly_3 = #polynomial.int_polynomial<3x>
+#my_poly_4 = #polynomial.int_polynomial<t**3 + 4t + 2>
 #ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
-#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
+#ring2 = #polynomial.ring<coefficientType=f32>
+#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
 
-#ideal = #polynomial.polynomial<-1 + x**1024>
+#ideal = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
 !poly_ty = !polynomial.polynomial<#ring>
 
-#ntt_poly = #polynomial.polynomial<-1 + x**8>
+#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
 #ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
 !ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
 
@@ -73,14 +74,15 @@ module {
 
   func.func @test_monic_monomial_mul() {
     %five = arith.constant 5 : index
-    %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
+    %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring1>
     %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<#ring1>, index) -> !polynomial.polynomial<#ring1>
     return
   }
 
   func.func @test_constant() {
-    %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
-    %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
+    %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring1>
+    %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring1>
+    %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<#ring2>
     return
   }
 
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index 2c20e7bcbf1d6..2bdf4f9d4c19a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s
 
-#my_poly = #polynomial.polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
 !ty = !polynomial.polynomial<#ring>
 
 func.func @test_from_tensor_too_large_coeffs() {
@@ -15,13 +15,13 @@ func.func @test_from_tensor_too_large_coeffs() {
 
 // -----
 
-#my_poly = #polynomial.polynomial<1 + x**4>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#my_poly>
+#my_poly = #polynomial.int_polynomial<1 + x**4>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#my_poly>
 !ty = !polynomial.polynomial<#ring>
 func.func @test_from_tensor_wrong_tensor_type() {
   %two = arith.constant 2 : i32
   %coeffs1 = tensor.from_elements %two, %two, %two, %two, %two : tensor<5xi32>
-  // expected-error@below {{input type 'tensor<5xi32>' does not match output type '!polynomial.polynomial<#polynomial.ring<coefficientType=i32, coefficientModulus=256 : i32, polynomialModulus=#polynomial.polynomial<1 + x**4>>>'}}
+  // expected-error@below {{input type 'tensor<5xi32>' does not match output type '!polynomial.polynomial<#polynomial.ring<coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>'}}
   // expected-note@below {{at most the degree of the polynomialModulus of the output type's ring attribute}}
   %poly = polynomial.from_tensor %coeffs1 : tensor<5xi32> -> !ty
   return
@@ -29,11 +29,11 @@ func.func @test_from_tensor_wrong_tensor_type() {
 
 // -----
 
-#my_poly = #polynomial.polynomial<1 + x**4>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#my_poly>
+#my_poly = #polynomial.int_polynomial<1 + x**4>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#my_poly>
 !ty = !polynomial.polynomial<#ring>
 func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
-  // expected-error@below {{input type '!polynomial.polynomial<#polynomial.ring<coefficientType=i32, coefficientModulus=256 : i32, polynomialModulus=#polynomial.polynomial<1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
+  // expected-error@below {{input type '!polynomial.polynomial<#polynomial.ring<coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
   // expected-note@below {{at most the degree of the polynomialModulus of the input type's ring attribute}}
   %tensor = polynomial.to_tensor %arg0 : !ty -> tensor<5xi32>
   return
@@ -41,8 +41,8 @@ func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
 
 // -----
 
-#my_poly = #polynomial.polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
 !ty = !polynomial.polynomial<#ring>
 
 func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
@@ -54,8 +54,8 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
 
 // -----
 
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
 !poly_ty = !polynomial.polynomial<#ring>
 
 // CHECK-NOT: @test_invalid_ntt
@@ -68,8 +68,8 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 // -----
 
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
 !poly_ty = !polynomial.polynomial<#ring>
 
 // CHECK-NOT: @test_invalid_ntt
@@ -82,9 +82,9 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 // -----
 
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
 !poly_ty = !polynomial.polynomial<#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -97,8 +97,8 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 
 // -----
 
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
 !poly_ty = !polynomial.polynomial<#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -112,8 +112,8 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
 
 // -----
 
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<#ring>
 
 // CHECK-NOT: @test_invalid_ntt
@@ -126,9 +126,9 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 // -----
 
-#my_poly = #polynomial.polynomial<-1 + x**8>
+#my_poly = #polynomial.int_polynomial<-1 + x**8>
 // A valid root is 31
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=32:i16>
 !poly_ty = !polynomial.polynomial<#ring>
 
 // CHECK-NOT: @test_invalid_intt
diff --git a/mlir/test/Dialect/Polynomial/types.mlir b/mlir/test/Dialect/Polynomial/types.mlir
index 00296a36e890f..2b46ec9209eab 100644
--- a/mlir/test/Dialect/Polynomial/types.mlir
+++ b/mlir/test/Dialect/Polynomial/types.mlir
@@ -3,11 +3,11 @@
 // CHECK-LABEL: func @test_types
 // CHECK-SAME:  !polynomial.polynomial<
 // CHECK-SAME:    #polynomial.ring<
-// CHECK-SAME:       coefficientType=i32,
-// CHECK-SAME:       coefficientModulus=2837465 : i32,
-// CHECK-SAME:       polynomialModulus=#polynomial.polynomial<1 + x**1024>>>
-#my_poly = #polynomial.polynomial<1 + x**1024>
-#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+// CHECK-SAME:       coefficientType = i32,
+// CHECK-SAME:       coefficientModulus = 2837465 : i32,
+// CHECK-SAME:       polynomialModulus = <1 + x**1024>>>
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 2837465 : i32, polynomialModulus=#my_poly>
 !ty = !polynomial.polynomial<#ring1>
 func.func @test_types(%0: !ty) -> !ty {
   return %0 : !ty
@@ -17,11 +17,11 @@ func.func @test_types(%0: !ty) -> !ty {
 // CHECK-LABEL: func @test_non_x_variable_64_bit
 // CHECK-SAME:  !polynomial.polynomial<
 // CHECK-SAME:    #polynomial.ring<
-// CHECK-SAME:       coefficientType=i64,
-// CHECK-SAME:       coefficientModulus=2837465 : i64,
-// CHECK-SAME:       polynomialModulus=#polynomial.polynomial<2 + 4x + x**3>>>
-#my_poly_2 = #polynomial.polynomial<t**3 + 4t + 2>
-#ring2 = #polynomial.ring<coefficientType=i64, coefficientModulus=2837465, polynomialModulus=#my_poly_2>
+// CHECK-SAME:       coefficientType = i64,
+// CHECK-SAME:       coefficientModulus = 2837465 : i64,
+// CHECK-SAME:       polynomialModulus = <2 + 4x + x**3>>>
+#my_poly_2 = #polynomial.int_polynomial<t**3 + 4t + 2>
+#ring2 = #polynomial.ring<coefficientType = i64, coefficientModulus = 2837465 : i64, polynomialModulus=#my_poly_2>
 !ty2 = !polynomial.polynomial<#ring2>
 func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
   return %0 : !ty2
@@ -31,11 +31,11 @@ func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
 // CHECK-LABEL: func @test_linear_poly
 // CHECK-SAME:  !polynomial.polynomial<
 // CHECK-SAME:    #polynomial.ring<
-// CHECK-SAME:       coefficientType=i32,
-// CHECK-SAME:       coefficientModulus=12 : i32,
-// CHECK-SAME:       polynomialModulus=#polynomial.polynomial<4x>>
-#my_poly_3 = #polynomial.polynomial<4x>
-#ring3 = #polynomial.ring<coefficientType=i32, coefficientModulus=12, polynomialModulus=#my_poly_3>
+// CHECK-SAME:       coefficientType = i32,
+// CHECK-SAME:       coefficientModulus = 12 : i32,
+// CHECK-SAME:       polynomialModulus = <4x>>
+#my_poly_3 = #polynomial.int_polynomial<4x>
+#ring3 = #polynomial.ring<coefficientType = i32, coefficientModulus=12 : i32, polynomialModulus=#my_poly_3>
 !ty3 = !polynomial.polynomial<#ring3>
 func.func @test_linear_poly(%0: !ty3) -> !ty3 {
   return %0 : !ty3
@@ -44,13 +44,22 @@ func.func @test_linear_poly(%0: !ty3) -> !ty3 {
 // CHECK-LABEL: func @test_negative_leading_1
 // CHECK-SAME:  !polynomial.polynomial<
 // CHECK-SAME:    #polynomial.ring<
-// CHECK-SAME:       coefficientType=i32,
-// CHECK-SAME:       coefficientModulus=2837465 : i32,
-// CHECK-SAME:       polynomialModulus=#polynomial.polynomial<-1 + x**1024>>>
-#my_poly_4 = #polynomial.polynomial<-1 + x**1024>
-#ring4 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly_4>
+// CHECK-SAME:       coefficientType = i32,
+// CHECK-SAME:       coefficientModulus = 2837465 : i32,
+// CHECK-SAME:       polynomialModulus = <-1 + x**1024>>>
+#my_poly_4 = #polynomial.int_polynomial<-1 + x**1024>
+#ring4 = #polynomial.ring<coefficientType = i32, coefficientModulus = 2837465 : i32, polynomialModulus=#my_poly_4>
 !ty4 = !polynomial.polynomial<#ring4>
 func.func @test_negative_leading_1(%0: !ty4) -> !ty4 {
   return %0 : !ty4
 }
 
+// CHECK-LABEL: func @test_float_coefficients
+// CHECK-SAME:  !polynomial.polynomial<#polynomial.ring<coefficientType = f32>>
+#my_poly_5 = #polynomial.float_polynomial<0.5 + 1.6e03 x**1024>
+#ring5 = #polynomial.ring<coefficientType=f32>
+!ty5 = !polynomial.polynomial<#ring5>
+func.func @test_float_coefficients(%0: !ty5) -> !ty5 {
+  return %0 : !ty5
+}
+

>From ff11eea3dcd57ed0075d85a5eb3b6d5e39c9f212 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sun, 5 May 2024 15:31:30 -0700
Subject: [PATCH 2/9] update docs

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 45 ++++++++++---------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 08ab1f1811eed..d8dafba6a2473 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -39,13 +39,13 @@ def Polynomial_Dialect : Dialect {
     %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
 
     // A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
-    #modulus = #polynomial.polynomial<1 + x**1024>
+    #modulus = #polynomial.int_polynomial<1 + x**1024>
     #ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
     %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
 
     // A constant polynomial in a ring with i32 coefficients, with a polynomial
     // modulus of (x^1024 + 1) and a coefficient modulus of 17.
-    #modulus = #polynomial.polynomial<1 + x**1024>
+    #modulus = #polynomial.int_polynomial<1 + x**1024>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#modulus>
     %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
     ```
@@ -126,7 +126,7 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     `x**1024 - 1`.
 
     ```mlir
-    #poly_mod = #polynomial.polynomial<-1 + x**1024>
+    #poly_mod = #polynomial.int_polynomial<-1 + x**1024>
     #ring = #polynomial.ring<coefficientType=i32,
                              coefficientModulus=4294967291:i32,
                              polynomialModulus=#poly_mod>
@@ -210,10 +210,10 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
 
     ```mlir
     // add two polynomials modulo x^1024 - 1
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -234,10 +234,10 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
 
     ```mlir
     // subtract two polynomials modulo x^1024 - 1
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -258,10 +258,10 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
 
     ```mlir
     // multiply two polynomials modulo x^1024 - 1
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -283,9 +283,9 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
 
     ```mlir
     // multiply two polynomials modulo x^1024 - 1
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1 = arith.constant 3 : i32
     %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
     ```
@@ -314,9 +314,9 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
     Example:
 
     ```mlir
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
     ```
   }];
@@ -337,7 +337,7 @@ def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
     Example:
 
     ```mlir
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %deg = arith.constant 1023 : index
     %five = arith.constant 5 : i32
@@ -377,7 +377,7 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
     Example:
 
     ```mlir
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %two = arith.constant 2 : i32
     %five = arith.constant 5 : i32
@@ -416,7 +416,7 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
     Example:
 
     ```mlir
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
     %two = arith.constant 2 : i32
     %five = arith.constant 5 : i32
@@ -443,9 +443,12 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
     Example:
 
     ```mlir
-    #poly = #polynomial.polynomial<x**1024 - 1>
+    #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+
+    #float_ring = #polynomial.ring<coefficientType=f32>
+    %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
     ```
   }];
   let arguments = (ins Polynomial_AnyPolynomialAttr:$value);

>From aaf668ea6186886a91b34a5b492a29bf646c1557 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sun, 5 May 2024 15:32:01 -0700
Subject: [PATCH 3/9] clang-format

---
 mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h    | 3 ++-
 mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp           | 1 -
 mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp | 8 ++++----
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 5705deeadf730..2b3f0e105c6c5 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -56,7 +56,8 @@ class MonomialBase {
   }
 
   virtual bool isMonic() const = 0;
-  virtual void coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
+  virtual void
+  coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
 
   template <typename T>
   friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
index 9d0d38ba927e2..42e678fad060c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
@@ -35,7 +35,6 @@ FailureOr<PolyT> fromMonomialsImpl(ArrayRef<MonomialT> monomials) {
   return PolyT(monomialsCopy);
 }
 
-
 FailureOr<IntPolynomial>
 IntPolynomial::fromMonomials(ArrayRef<IntMonomial> monomials) {
   return fromMonomialsImpl<IntPolynomial, IntMonomial>(monomials);
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index b5f674c98d835..890ce5226c30f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -102,10 +102,10 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
 }
 
 template <typename PolynoimalAttrTy, typename Monomial>
-LogicalResult parsePolynomialAttr(
-    AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
-    llvm::StringSet<> &variables,
-    ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
+LogicalResult
+parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
+                    llvm::StringSet<> &variables,
+                    ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
   while (true) {
     Monomial parsedMonomial;
     llvm::StringRef parsedVariableRef;

>From 1230b731e68c95b4b1e7f3d039c5c227d1684be8 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Tue, 7 May 2024 13:39:43 -0700
Subject: [PATCH 4/9] infer context in attr builder

---
 mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index d8dafba6a2473..387b0998a76b9 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -150,15 +150,19 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   );
   let assemblyFormat = "`<` struct(params) `>`";
   let builders = [
-    AttrBuilder<
+    AttrBuilderWithInferredContext<
         (ins "::mlir::Type":$coefficientTy,
               CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
               CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
               CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
-      return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, primitiveRootAttr);
+      return $_get(
+        coefficientTy.getContext(),
+        coefficientTy,
+        coefficientModulusAttr,
+        polynomialModulusAttr,
+        primitiveRootAttr);
     }]>,
   ];
-  let skipDefaultBuilders = 1;
 }
 
 class Polynomial_Type<string name, string typeMnemonic>

>From 43911fc8df750b49e0eec3fc49f58df2241f7452 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Wed, 8 May 2024 09:41:31 -0700
Subject: [PATCH 5/9] use struct(params)

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  |  2 +-
 mlir/test/Dialect/Polynomial/ops.mlir         | 50 +++++++++----------
 mlir/test/Dialect/Polynomial/ops_errors.mlir  | 24 ++++-----
 mlir/test/Dialect/Polynomial/types.mlir       | 20 ++++----
 4 files changed, 48 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 387b0998a76b9..ae8484501a50d 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -176,7 +176,7 @@ def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> {
     A type for polynomials in a polynomial quotient ring.
   }];
   let parameters = (ins Polynomial_RingAttr:$ring);
-  let assemblyFormat = "`<` qualified($ring) `>`";
+  let assemblyFormat = "`<` struct(params) `>`";
 }
 
 def PolynomialLike: TypeOrContainer<Polynomial_PolynomialType, "polynomial-like">;
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 76dbc9156a282..ff709960c50e9 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -12,77 +12,77 @@
 
 #ideal = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
-!poly_ty = !polynomial.polynomial<#ring>
+!poly_ty = !polynomial.polynomial<ring=#ring>
 
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
 #ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
-!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
 module {
-  func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
+  func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
     %c0 = arith.constant 0 : index
     %two = arith.constant 2 : i16
     %five = arith.constant 5 : i16
     %coeffs1 = tensor.from_elements %two, %two, %five : tensor<3xi16>
     %coeffs2 = tensor.from_elements %five, %five, %two : tensor<3xi16>
 
-    %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
-    %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
+    %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<ring=#ring1>
+    %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<ring=#ring1>
 
-    %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<#ring1>
+    %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<ring=#ring1>
 
-    return %3 : !polynomial.polynomial<#ring1>
+    return %3 : !polynomial.polynomial<ring=#ring1>
   }
 
-  func.func @test_elementwise(%p0 : !polynomial.polynomial<#ring1>, %p1: !polynomial.polynomial<#ring1>) {
-    %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<#ring1>>
-    %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<#ring1>>
+  func.func @test_elementwise(%p0 : !polynomial.polynomial<ring=#ring1>, %p1: !polynomial.polynomial<ring=#ring1>) {
+    %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
+    %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<ring=#ring1>>
 
     %c = arith.constant 2 : i32
-    %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<#ring1>>, i32
+    %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<ring=#ring1>>, i32
 
-    %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
-    %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
-    %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+    %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
+    %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
+    %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
 
     return
   }
 
-  func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<#ring1>) {
+  func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<ring=#ring1>) {
     %c0 = arith.constant 0 : index
     %two = arith.constant 2 : i16
     %coeffs1 = tensor.from_elements %two, %two : tensor<2xi16>
     // CHECK: from_tensor
-    %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<#ring1>
+    %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<ring=#ring1>
     // CHECK: to_tensor
-    %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<#ring1> -> tensor<1024xi16>
+    %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<ring=#ring1> -> tensor<1024xi16>
 
     return
   }
 
-  func.func @test_degree(%p0 : !polynomial.polynomial<#ring1>) {
-    %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<#ring1> -> (index, i32)
+  func.func @test_degree(%p0 : !polynomial.polynomial<ring=#ring1>) {
+    %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<ring=#ring1> -> (index, i32)
     return
   }
 
   func.func @test_monomial() {
     %deg = arith.constant 1023 : index
     %five = arith.constant 5 : i16
-    %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<#ring1>
+    %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<ring=#ring1>
     return
   }
 
   func.func @test_monic_monomial_mul() {
     %five = arith.constant 5 : index
-    %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring1>
-    %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<#ring1>, index) -> !polynomial.polynomial<#ring1>
+    %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
+    %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
     return
   }
 
   func.func @test_constant() {
-    %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring1>
-    %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring1>
-    %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<#ring2>
+    %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
+    %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<ring=#ring1>
+    %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<ring=#ring2>
     return
   }
 
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index 2bdf4f9d4c19a..af8e4aa5da862 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -2,7 +2,7 @@
 
 #my_poly = #polynomial.int_polynomial<1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring>
+!ty = !polynomial.polynomial<ring=#ring>
 
 func.func @test_from_tensor_too_large_coeffs() {
   %two = arith.constant 2 : i32
@@ -17,11 +17,11 @@ func.func @test_from_tensor_too_large_coeffs() {
 
 #my_poly = #polynomial.int_polynomial<1 + x**4>
 #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring>
+!ty = !polynomial.polynomial<ring=#ring>
 func.func @test_from_tensor_wrong_tensor_type() {
   %two = arith.constant 2 : i32
   %coeffs1 = tensor.from_elements %two, %two, %two, %two, %two : tensor<5xi32>
-  // expected-error@below {{input type 'tensor<5xi32>' does not match output type '!polynomial.polynomial<#polynomial.ring<coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>'}}
+  // expected-error@below {{input type 'tensor<5xi32>' does not match output type '!polynomial.polynomial<ring = <coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>'}}
   // expected-note@below {{at most the degree of the polynomialModulus of the output type's ring attribute}}
   %poly = polynomial.from_tensor %coeffs1 : tensor<5xi32> -> !ty
   return
@@ -31,9 +31,9 @@ func.func @test_from_tensor_wrong_tensor_type() {
 
 #my_poly = #polynomial.int_polynomial<1 + x**4>
 #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring>
+!ty = !polynomial.polynomial<ring=#ring>
 func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
-  // expected-error@below {{input type '!polynomial.polynomial<#polynomial.ring<coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
+  // expected-error@below {{input type '!polynomial.polynomial<ring = <coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
   // expected-note@below {{at most the degree of the polynomialModulus of the input type's ring attribute}}
   %tensor = polynomial.to_tensor %arg0 : !ty -> tensor<5xi32>
   return
@@ -43,7 +43,7 @@ func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
 
 #my_poly = #polynomial.int_polynomial<1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring>
+!ty = !polynomial.polynomial<ring=#ring>
 
 func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
   %scalar = arith.constant 2 : i32  // should be i16
@@ -56,7 +56,7 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
-!poly_ty = !polynomial.polynomial<#ring>
+!poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
@@ -70,7 +70,7 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
-!poly_ty = !polynomial.polynomial<#ring>
+!poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
@@ -85,7 +85,7 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 #ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
-!poly_ty = !polynomial.polynomial<#ring>
+!poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
@@ -99,7 +99,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
-!poly_ty = !polynomial.polynomial<#ring>
+!poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
@@ -114,7 +114,7 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-!poly_ty = !polynomial.polynomial<#ring>
+!poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
@@ -129,7 +129,7 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 #my_poly = #polynomial.int_polynomial<-1 + x**8>
 // A valid root is 31
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=32:i16>
-!poly_ty = !polynomial.polynomial<#ring>
+!poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
diff --git a/mlir/test/Dialect/Polynomial/types.mlir b/mlir/test/Dialect/Polynomial/types.mlir
index 2b46ec9209eab..dcc5663ceb84c 100644
--- a/mlir/test/Dialect/Polynomial/types.mlir
+++ b/mlir/test/Dialect/Polynomial/types.mlir
@@ -2,13 +2,13 @@
 
 // CHECK-LABEL: func @test_types
 // CHECK-SAME:  !polynomial.polynomial<
-// CHECK-SAME:    #polynomial.ring<
+// CHECK-SAME:    ring = <
 // CHECK-SAME:       coefficientType = i32,
 // CHECK-SAME:       coefficientModulus = 2837465 : i32,
 // CHECK-SAME:       polynomialModulus = <1 + x**1024>>>
 #my_poly = #polynomial.int_polynomial<1 + x**1024>
 #ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 2837465 : i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring1>
+!ty = !polynomial.polynomial<ring=#ring1>
 func.func @test_types(%0: !ty) -> !ty {
   return %0 : !ty
 }
@@ -16,13 +16,13 @@ func.func @test_types(%0: !ty) -> !ty {
 
 // CHECK-LABEL: func @test_non_x_variable_64_bit
 // CHECK-SAME:  !polynomial.polynomial<
-// CHECK-SAME:    #polynomial.ring<
+// CHECK-SAME:    ring = <
 // CHECK-SAME:       coefficientType = i64,
 // CHECK-SAME:       coefficientModulus = 2837465 : i64,
 // CHECK-SAME:       polynomialModulus = <2 + 4x + x**3>>>
 #my_poly_2 = #polynomial.int_polynomial<t**3 + 4t + 2>
 #ring2 = #polynomial.ring<coefficientType = i64, coefficientModulus = 2837465 : i64, polynomialModulus=#my_poly_2>
-!ty2 = !polynomial.polynomial<#ring2>
+!ty2 = !polynomial.polynomial<ring=#ring2>
 func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
   return %0 : !ty2
 }
@@ -30,35 +30,35 @@ func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
 
 // CHECK-LABEL: func @test_linear_poly
 // CHECK-SAME:  !polynomial.polynomial<
-// CHECK-SAME:    #polynomial.ring<
+// CHECK-SAME:    ring = <
 // CHECK-SAME:       coefficientType = i32,
 // CHECK-SAME:       coefficientModulus = 12 : i32,
 // CHECK-SAME:       polynomialModulus = <4x>>
 #my_poly_3 = #polynomial.int_polynomial<4x>
 #ring3 = #polynomial.ring<coefficientType = i32, coefficientModulus=12 : i32, polynomialModulus=#my_poly_3>
-!ty3 = !polynomial.polynomial<#ring3>
+!ty3 = !polynomial.polynomial<ring=#ring3>
 func.func @test_linear_poly(%0: !ty3) -> !ty3 {
   return %0 : !ty3
 }
 
 // CHECK-LABEL: func @test_negative_leading_1
 // CHECK-SAME:  !polynomial.polynomial<
-// CHECK-SAME:    #polynomial.ring<
+// CHECK-SAME:    ring = <
 // CHECK-SAME:       coefficientType = i32,
 // CHECK-SAME:       coefficientModulus = 2837465 : i32,
 // CHECK-SAME:       polynomialModulus = <-1 + x**1024>>>
 #my_poly_4 = #polynomial.int_polynomial<-1 + x**1024>
 #ring4 = #polynomial.ring<coefficientType = i32, coefficientModulus = 2837465 : i32, polynomialModulus=#my_poly_4>
-!ty4 = !polynomial.polynomial<#ring4>
+!ty4 = !polynomial.polynomial<ring=#ring4>
 func.func @test_negative_leading_1(%0: !ty4) -> !ty4 {
   return %0 : !ty4
 }
 
 // CHECK-LABEL: func @test_float_coefficients
-// CHECK-SAME:  !polynomial.polynomial<#polynomial.ring<coefficientType = f32>>
+// CHECK-SAME:  !polynomial.polynomial<ring = <coefficientType = f32>>
 #my_poly_5 = #polynomial.float_polynomial<0.5 + 1.6e03 x**1024>
 #ring5 = #polynomial.ring<coefficientType=f32>
-!ty5 = !polynomial.polynomial<#ring5>
+!ty5 = !polynomial.polynomial<ring=#ring5>
 func.func @test_float_coefficients(%0: !ty5) -> !ty5 {
   return %0 : !ty5
 }

>From 4aee96c7580e1b0a6991ad63bbd6391c9d69813f Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Tue, 7 May 2024 15:07:38 -0700
Subject: [PATCH 6/9] add basic polynomial canonicalization patterns

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  |  3 ++
 mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt |  4 ++
 .../IR/PolynomialCanonicalization.td          | 37 ++++++++++++++
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 25 ++++++++++
 .../Dialect/Polynomial/canonicalization.mlir  | 49 +++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     | 18 +++++++
 6 files changed, 136 insertions(+)
 create mode 100644 mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
 create mode 100644 mlir/test/Dialect/Polynomial/canonicalization.mlir

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index ae8484501a50d..537be4832e8f8 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -245,6 +245,7 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
     %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
@@ -480,6 +481,7 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
   let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -498,6 +500,7 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
   let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
index d6e703b8b3591..6dcdcb257674f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS PolynomialCanonicalization.td)
+mlir_tablegen(PolynomialCanonicalization.inc -gen-rewriters)
+add_public_tablegen_target(MLIRPolynomialCanonicalizationIncGen)
+
 add_mlir_dialect_library(MLIRPolynomialDialect
   Polynomial.cpp
   PolynomialAttributes.cpp
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
new file mode 100644
index 0000000000000..1292ececa2309
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -0,0 +1,37 @@
+//===- PolynomialCanonicalization.td - Polynomial patterns -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef POLYNOMIAL_CANONICALIZATION
+#define POLYNOMIAL_CANONICALIZATION
+
+include "mlir/Dialect/Polynomial/IR/Polynomial.td"
+include "mlir/Dialect/Arith/IR/ArithOps.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/PatternBase.td"
+
+// TODO: get the proper scalar type from the operand polynomial ring attribute
+def SubAsAdd : Pat<
+  (Polynomial_SubOp $f, $g),
+  (Polynomial_AddOp $f,
+    (Polynomial_MulScalarOp $g,
+      (Arith_ConstantOp
+        ConstantAttr<I32Attr, "-1">)))>;
+
+def INTTAfterNTT : Pat<
+  (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+  (replaceWithValue $poly),
+  []
+>;
+
+def NTTAfterINTT : Pat<
+  (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+  (replaceWithValue $tensor),
+  []
+>;
+
+#endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 12010de348237..329d5fec9b7c6 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -7,12 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/APInt.h"
 
@@ -183,3 +185,26 @@ LogicalResult INTTOp::verify() {
   auto ring = getOutput().getType().getRing();
   return verifyNTTOp(this->getOperation(), ring, tensorType);
 }
+
+//===----------------------------------------------------------------------===//
+// TableGen'd canonicalization patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+#include "PolynomialCanonicalization.inc"
+} // namespace
+
+void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  populateWithGenerated(results);
+}
+
+void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  populateWithGenerated(results);
+}
+
+void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  populateWithGenerated(results);
+}
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
new file mode 100644
index 0000000000000..54759ff00c966
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -canonicalize %s | FileCheck %s
+#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+!tensor_ty = tensor<8xi32, #ntt_ring>
+
+// CHECK-LABEL: @test_canonicalize_intt_after_ntt
+// CHECK: (%[[P:.*]]: [[T:.*]]) -> [[T]]
+func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty {
+  // CHECK-NOT: polynomial.ntt
+  // CHECK-NOT: polynomial.intt
+  // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]]  : [[T]]
+  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+  %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
+  // CHECK: return %[[RESULT]] : [[T]]
+  return %p2 : !ntt_poly_ty
+}
+
+// CHECK-LABEL: @test_canonicalize_ntt_after_intt
+// CHECK: (%[[X:.*]]: [[T:.*]]) -> [[T]]
+func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
+  // CHECK-NOT: polynomial.intt
+  // CHECK-NOT: polynomial.ntt
+  // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
+  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %t2 = arith.addi %t1, %t1 : !tensor_ty
+  // CHECK: return %[[RESULT]] : [[T]]
+  return %t2 : !tensor_ty
+}
+
+#cycl_2048 = #polynomial.int_polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
+#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
+#one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
+
+// CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
+func.func @test_canonicalize_sub_power_of_two_cmod() -> !polynomial.polynomial<#ring> {
+  %poly0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring>
+  %poly1 = polynomial.constant {value=#one_minus_x_squared} : !polynomial.polynomial<#ring>
+  %0 = polynomial.sub %poly0, %poly1  : !polynomial.polynomial<#ring>
+  // CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
+  // CHECK: %[[p1:.+]] = polynomial.constant
+  // CHECK: %[[p2:.+]] = polynomial.constant
+  // CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
+  // CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
+  return %0 : !polynomial.polynomial<#ring>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b75ed99df5b2c..eff2e787c977e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6703,6 +6703,7 @@ cc_library(
         ":IR",
         ":InferTypeOpInterface",
         ":PolynomialAttributesIncGen",
+        ":PolynomialCanonicalizationIncGen",
         ":PolynomialIncGen",
         ":Support",
         "//llvm:Support",
@@ -6793,6 +6794,23 @@ gentbl_cc_library(
     deps = [":PolynomialTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "PolynomialCanonicalizationIncGen",
+    strip_include_prefix = "include/mlir/Dialect/Polynomial/IR",
+    tbl_outs = [
+        (
+            ["-gen-rewriters"],
+            "include/mlir/Dialect/Polynomial/IR/PolynomialCanonicalization.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td",
+    deps = [
+        ":ArithOpsTdFiles",
+        ":PolynomialTdFiles",
+    ],
+)
+
 td_library(
     name = "SPIRVOpsTdFiles",
     srcs = glob(["include/mlir/Dialect/SPIRV/IR/*.td"]),

>From 24cbcd72c51462206cbe475db8e3bfd5b2cb6aae Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Wed, 8 May 2024 10:04:20 -0700
Subject: [PATCH 7/9] use struct for type

---
 mlir/test/Dialect/Polynomial/canonicalization.mlir | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 54759ff00c966..2b4cf6aa8997f 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -canonicalize %s | FileCheck %s
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
 #ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
-!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 !tensor_ty = tensor<8xi32, #ntt_ring>
 
 // CHECK-LABEL: @test_canonicalize_intt_after_ntt
@@ -34,16 +34,17 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
 #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#cycl_2048>
 #one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
 #one_minus_x_squared = #polynomial.int_polynomial<1 + -1x**2>
+!sub_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-LABEL: test_canonicalize_sub_power_of_two_cmod
-func.func @test_canonicalize_sub_power_of_two_cmod() -> !polynomial.polynomial<#ring> {
-  %poly0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<#ring>
-  %poly1 = polynomial.constant {value=#one_minus_x_squared} : !polynomial.polynomial<#ring>
-  %0 = polynomial.sub %poly0, %poly1  : !polynomial.polynomial<#ring>
+func.func @test_canonicalize_sub_power_of_two_cmod() -> !sub_ty {
+  %poly0 = polynomial.constant {value=#one_plus_x_squared} : !sub_ty
+  %poly1 = polynomial.constant {value=#one_minus_x_squared} : !sub_ty
+  %0 = polynomial.sub %poly0, %poly1  : !sub_ty
   // CHECK: %[[minus_one:.+]] = arith.constant -1 : i32
   // CHECK: %[[p1:.+]] = polynomial.constant
   // CHECK: %[[p2:.+]] = polynomial.constant
   // CHECK: %[[p2neg:.+]] = polynomial.mul_scalar %[[p2]], %[[minus_one]]
   // CHECK: [[ADD:%.+]] = polynomial.add %[[p1]], %[[p2neg]]
-  return %0 : !polynomial.polynomial<#ring>
+  return %0 : !sub_ty
 }

>From b340e5ccea189bb4083a2cc3263bc07c0e90c311 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Wed, 8 May 2024 13:20:04 -0700
Subject: [PATCH 8/9] implement add for polynomial data structure

- use CRTP for base classes
- Add unit test
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.h   | 69 ++++++++++++++-----
 mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp |  4 +-
 mlir/unittests/Dialect/CMakeLists.txt         |  1 +
 .../Dialect/Polynomial/CMakeLists.txt         |  8 +++
 .../Dialect/Polynomial/PolynomialMathTest.cpp | 43 ++++++++++++
 5 files changed, 105 insertions(+), 20 deletions(-)
 create mode 100644 mlir/unittests/Dialect/Polynomial/CMakeLists.txt
 create mode 100644 mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 2b3f0e105c6c5..f0e5bdf16036c 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -30,7 +30,7 @@ namespace polynomial {
 /// would want to specify 128-bit polynomials statically in the source code.
 constexpr unsigned apintBitWidth = 64;
 
-template <typename CoefficientType>
+template <class Derived, typename CoefficientType>
 class MonomialBase {
 public:
   MonomialBase(const CoefficientType &coeff, const APInt &expo)
@@ -55,12 +55,21 @@ class MonomialBase {
     return (exponent.ult(other.exponent));
   }
 
+  Derived add(const Derived &other) {
+    assert(exponent == other.exponent);
+    CoefficientType newCoeff = coefficient + other.coefficient;
+    Derived result;
+    result.setCoefficient(newCoeff);
+    result.setExponent(exponent);
+    return result;
+  }
+
   virtual bool isMonic() const = 0;
   virtual void
   coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
 
-  template <typename T>
-  friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
+  template <class D, typename T>
+  friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
 
 protected:
   CoefficientType coefficient;
@@ -69,7 +78,7 @@ class MonomialBase {
 
 /// A class representing a monomial of a single-variable polynomial with integer
 /// coefficients.
-class IntMonomial : public MonomialBase<APInt> {
+class IntMonomial : public MonomialBase<IntMonomial, APInt> {
 public:
   IntMonomial(int64_t coeff, uint64_t expo)
       : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
@@ -77,7 +86,7 @@ class IntMonomial : public MonomialBase<APInt> {
   IntMonomial()
       : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
 
-  ~IntMonomial() = default;
+  ~IntMonomial() override = default;
 
   bool isMonic() const override { return coefficient == 1; }
 
@@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase<APInt> {
 
 /// A class representing a monomial of a single-variable polynomial with integer
 /// coefficients.
-class FloatMonomial : public MonomialBase<APFloat> {
+class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
 public:
   FloatMonomial(double coeff, uint64_t expo)
       : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
 
   FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
 
-  ~FloatMonomial() = default;
+  ~FloatMonomial() override = default;
 
   bool isMonic() const override { return coefficient == APFloat(1.0); }
 
@@ -104,7 +113,7 @@ class FloatMonomial : public MonomialBase<APFloat> {
   }
 };
 
-template <typename Monomial>
+template <class Derived, typename Monomial>
 class PolynomialBase {
 public:
   PolynomialBase() = delete;
@@ -149,6 +158,30 @@ class PolynomialBase {
     }
   }
 
+  Derived add(const Derived &other) {
+    SmallVector<Monomial> newTerms;
+    auto it1 = terms.begin();
+    auto it2 = other.terms.begin();
+    while (it1 != terms.end() || it2 != other.terms.end()) {
+      if (it1 == terms.end()) {
+        newTerms.emplace_back(*it2);
+        it2++;
+        continue;
+      }
+
+      if (it2 == other.terms.end()) {
+        newTerms.emplace_back(*it1);
+        it1++;
+        continue;
+      }
+
+      newTerms.emplace_back(it1->add(*it2));
+      it1++;
+      it2++;
+    }
+    return Derived(newTerms);
+  }
+
   // Prints polynomial to 'os'.
   void print(raw_ostream &os) const { print(os, " + ", "**"); }
 
@@ -168,8 +201,8 @@ class PolynomialBase {
 
   ArrayRef<Monomial> getTerms() const { return terms; }
 
-  template <typename T>
-  friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
+  template <class D, typename T>
+  friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
 
 private:
   // The monomial terms for this polynomial.
@@ -179,7 +212,7 @@ class PolynomialBase {
 /// A single-variable polynomial with integer coefficients.
 ///
 /// Eg: x^1024 + x + 1
-class IntPolynomial : public PolynomialBase<IntMonomial> {
+class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
 public:
   explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
 
@@ -196,7 +229,7 @@ class IntPolynomial : public PolynomialBase<IntMonomial> {
 /// A single-variable polynomial with double coefficients.
 ///
 /// Eg: 1.0 x^1024 + 3.5 x + 1e-05
-class FloatPolynomial : public PolynomialBase<FloatMonomial> {
+class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
 public:
   explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
       : PolynomialBase(terms) {}
@@ -212,20 +245,20 @@ class FloatPolynomial : public PolynomialBase<FloatMonomial> {
 };
 
 // Make Polynomials hashable.
-template <typename T>
-inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
   return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
 }
 
-template <typename T>
-inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
   return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
                             ::llvm::hash_value(arg.exponent));
 }
 
-template <typename T>
+template <class D, typename T>
 inline raw_ostream &operator<<(raw_ostream &os,
-                               const PolynomialBase<T> &polynomial) {
+                               const PolynomialBase<D, T> &polynomial) {
   polynomial.print(os);
   return os;
 }
diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
index 42e678fad060c..0fb2d6b81992f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
@@ -14,8 +14,8 @@
 namespace mlir {
 namespace polynomial {
 
-template <typename T>
-MonomialBase<T>::~MonomialBase() {}
+template <class D, typename T>
+MonomialBase<D, T>::~MonomialBase() = default;
 
 template <typename PolyT, typename MonomialT>
 FailureOr<PolyT> fromMonomialsImpl(ArrayRef<MonomialT> monomials) {
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 13393569f36fe..90a75d5a46ad9 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(Index)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
 add_subdirectory(OpenACC)
+add_subdirectory(Polynomial)
 add_subdirectory(SCF)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/Polynomial/CMakeLists.txt b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
new file mode 100644
index 0000000000000..807deeca41c06
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIRPolynomialTests
+  PolynomialMathTest.cpp
+)
+target_link_libraries(MLIRPolynomialTests
+  PRIVATE
+  MLIRIR
+  MLIRPolynomialDialect
+)
diff --git a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
new file mode 100644
index 0000000000000..485c2b64e4f21
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
@@ -0,0 +1,43 @@
+//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
+  IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3});
+  IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7});
+  EXPECT_EQ(expected, x.add(y));
+}
+
+TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
+  IntPolynomial x = IntPolynomial::fromCoefficients({1, 2});
+  IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 4});
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}
+
+TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5});
+  FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8});
+  EXPECT_EQ(expected, x.add(y));
+}
+
+TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5});
+  FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5});
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}

>From 04614be503d94506df1b722d9444911337008968 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Wed, 8 May 2024 18:50:12 -0700
Subject: [PATCH 9/9] try constBinaryFold

---
 mlir/include/mlir/Dialect/CommonFolders.h     |  3 +-
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 38 ++++++++++++++++++-
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 34 +++++++++++++++++
 3 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 7dabc781cd595..29e6fccdf2553 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -82,7 +82,8 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
     if (!elementResult)
       return {};
 
-    return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
+    return DenseElementsAttr::get(cast<ShapedType>(resultType),
+                                  llvm::ArrayRef(*elementResult));
   }
 
   if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 537be4832e8f8..82b7d0c82952e 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -83,6 +83,23 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
   let hasCustomAssemblyFormat = 1;
 }
 
+def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
+    "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
+  let summary = "A typed variant of int_polynomial for constant folding.";
+  let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomial":$value);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const IntPolynomial &":$value), [{
+      return $_get(type.getContext(), type, value);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    // constFoldBinaryOp reads the C++ value type from this alias.
+    using ValueType = ::mlir::polynomial::IntPolynomial;
+  }];
+}
+
 def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
   let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
   let description = [{
@@ -105,6 +122,23 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
   let hasCustomAssemblyFormat = 1;
 }
 
+def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
+    "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
+  let summary = "A typed variant of float_polynomial for constant folding.";
+  let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomial":$value);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const FloatPolynomial &":$value), [{
+      return $_get(type.getContext(), type, value);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    // constFoldBinaryOp reads the C++ value type from this alias.
+    using ValueType = ::mlir::polynomial::FloatPolynomial;
+  }];
+}
+
 def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let summary = "An attribute specifying a polynomial ring.";
   let description = [{
@@ -221,6 +255,7 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
     %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
+  let hasFolder = 1;
 }
 
 def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
@@ -442,7 +477,7 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
 ]>;
 
 // Not deriving from Polynomial_Op due to need for custom assembly format
-def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
+def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure, ConstantLike]> {
   let summary = "Define a constant polynomial via an attribute.";
   let description = [{
     Example:
@@ -459,6 +494,7 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
   let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "attr-dict `:` type($output)";
+  let hasFolder = 1;
 }
 
 def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 329d5fec9b7c6..db71e88c86d36 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -8,10 +8,12 @@
 
 #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/PatternMatch.h"
@@ -21,6 +23,38 @@
 using namespace mlir;
 using namespace mlir::polynomial;
 
+OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) {
+  PolynomialType ty = dyn_cast<PolynomialType>(getOutput().getType());
+
+  if (isa<FloatPolynomialAttr>(ty.getRing().getPolynomialModulus()))
+    return TypedFloatPolynomialAttr::get(ty, cast<FloatPolynomialAttr>(getValue()).getPolynomial());
+
+  assert(isa<IntPolynomialAttr>(ty.getRing().getPolynomialModulus()) &&
+         "expected float or integer polynomial");
+  return TypedIntPolynomialAttr::get(ty,cast<IntPolynomialAttr>(getValue()).getPolynomial());
+}
+
+OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) {
+  // Folded input attributes can either be typed_int_polynomial or
+  // typed_float_polynomial, and those require different invocations of
+  // constFoldBinaryOp.
+  PolynomialType ty = dyn_cast<PolynomialType>(getLhs().getType());
+  if (!ty) {
+    ShapedType shapedTy = dyn_cast<ShapedType>(getLhs().getType());
+    assert(shapedTy && "lhs must be a polynomial or a shaped type");
+    ty = cast<PolynomialType>(shapedTy.getElementType());
+  }
+
+  if (isa<FloatPolynomialAttr>(ty.getRing().getPolynomialModulus()))
+    return constFoldBinaryOp<TypedFloatPolynomialAttr>(
+        adaptor.getOperands(), getLhs().getType(),
+        [](FloatPolynomial a, const FloatPolynomial &b) { return a.add(b); });
+
+  return constFoldBinaryOp<TypedIntPolynomialAttr>(
+      adaptor.getOperands(), getLhs().getType(),
+      [](IntPolynomial a, const IntPolynomial &b) { return a.add(b); });
+}
+
 void FromTensorOp::build(OpBuilder &builder, OperationState &result,
                          Value input, RingAttr ring) {
   TensorType tensorType = dyn_cast<TensorType>(input.getType());



More information about the Mlir-commits mailing list