[Mlir-commits] [mlir] Restore #91137 (PR #92003)
Jeremy Kun
llvmlistbot at llvm.org
Mon May 13 10:27:13 PDT 2024
https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/92003
#91137 was reverted in #92001.
This PR restores it; I am working on a build fix.
>From f085443c71c6e729c972d9001a809193e2e53fea Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 13 May 2024 10:18:33 -0700
Subject: [PATCH 1/2] Revert "Support polynomial attributes with floating point
coefficients (#91137)"
This reverts commit 91a14dbf825b79ff143d1b16124763a4a80facab.
---
.../mlir/Dialect/Polynomial/IR/Polynomial.h | 193 ++++--------------
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 139 +++++--------
mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp | 78 ++++---
.../Polynomial/IR/PolynomialAttributes.cpp | 172 +++++++++-------
mlir/test/Dialect/Polynomial/attributes.mlir | 22 +-
mlir/test/Dialect/Polynomial/ops.mlir | 64 +++---
mlir/test/Dialect/Polynomial/ops_errors.mlir | 66 +++---
mlir/test/Dialect/Polynomial/types.mlir | 65 +++---
8 files changed, 348 insertions(+), 451 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 2b3f0e105c6c5..3325a6fa3f9fc 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -11,13 +11,10 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/raw_ostream.h"
+#include "llvm/ADT/SmallVector.h"
namespace mlir {
@@ -30,202 +27,98 @@ namespace polynomial {
/// would want to specify 128-bit polynomials statically in the source code.
constexpr unsigned apintBitWidth = 64;
-template <typename CoefficientType>
-class MonomialBase {
+/// A class representing a monomial of a single-variable polynomial with integer
+/// coefficients.
+class Monomial {
public:
- MonomialBase(const CoefficientType &coeff, const APInt &expo)
+ Monomial(int64_t coeff, uint64_t expo)
+ : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
+
+ Monomial(const APInt &coeff, const APInt &expo)
: coefficient(coeff), exponent(expo) {}
- virtual ~MonomialBase() = 0;
- const CoefficientType &getCoefficient() const { return coefficient; }
- CoefficientType &getMutableCoefficient() { return coefficient; }
- const APInt &getExponent() const { return exponent; }
- void setCoefficient(const CoefficientType &coeff) { coefficient = coeff; }
- void setExponent(const APInt &exp) { exponent = exp; }
+ Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {}
- bool operator==(const MonomialBase &other) const {
+ bool operator==(const Monomial &other) const {
return other.coefficient == coefficient && other.exponent == exponent;
}
- bool operator!=(const MonomialBase &other) const {
+ bool operator!=(const Monomial &other) const {
return other.coefficient != coefficient || other.exponent != exponent;
}
/// Monomials are ordered by exponent.
- bool operator<(const MonomialBase &other) const {
+ bool operator<(const Monomial &other) const {
return (exponent.ult(other.exponent));
}
- virtual bool isMonic() const = 0;
- virtual void
- coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
-
- template <typename T>
- friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
+ friend ::llvm::hash_code hash_value(const Monomial &arg);
-protected:
- CoefficientType coefficient;
- APInt exponent;
-};
-
-/// A class representing a monomial of a single-variable polynomial with integer
-/// coefficients.
-class IntMonomial : public MonomialBase<APInt> {
public:
- IntMonomial(int64_t coeff, uint64_t expo)
- : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
-
- IntMonomial()
- : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
-
- ~IntMonomial() = default;
-
- bool isMonic() const override { return coefficient == 1; }
+ APInt coefficient;
- void coefficientToString(llvm::SmallString<16> &coeffString) const override {
- coefficient.toStringSigned(coeffString);
- }
+ // Always unsigned
+ APInt exponent;
};
-/// A class representing a monomial of a single-variable polynomial with integer
-/// coefficients.
-class FloatMonomial : public MonomialBase<APFloat> {
+/// A single-variable polynomial with integer coefficients.
+///
+/// Eg: x^1024 + x + 1
+///
+/// The symbols used as the polynomial's indeterminate don't matter, so long as
+/// it is used consistently throughout the polynomial.
+class Polynomial {
public:
- FloatMonomial(double coeff, uint64_t expo)
- : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
-
- FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
+ Polynomial() = delete;
- ~FloatMonomial() = default;
+ explicit Polynomial(ArrayRef<Monomial> terms) : terms(terms){};
- bool isMonic() const override { return coefficient == APFloat(1.0); }
-
- void coefficientToString(llvm::SmallString<16> &coeffString) const override {
- coefficient.toString(coeffString);
- }
-};
-
-template <typename Monomial>
-class PolynomialBase {
-public:
- PolynomialBase() = delete;
+ // Returns a Polynomial from a list of monomials.
+ // Fails if two monomials have the same exponent.
+ static FailureOr<Polynomial> fromMonomials(ArrayRef<Monomial> monomials);
- explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
+ /// Returns a polynomial with coefficients given by `coeffs`. The value
+ /// coeffs[i] is converted to a monomial with exponent i.
+ static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs);
explicit operator bool() const { return !terms.empty(); }
- bool operator==(const PolynomialBase &other) const {
+ bool operator==(const Polynomial &other) const {
return other.terms == terms;
}
- bool operator!=(const PolynomialBase &other) const {
+ bool operator!=(const Polynomial &other) const {
return !(other.terms == terms);
}
- void print(raw_ostream &os, ::llvm::StringRef separator,
- ::llvm::StringRef exponentiation) const {
- bool first = true;
- for (const Monomial &term : getTerms()) {
- if (first) {
- first = false;
- } else {
- os << separator;
- }
- std::string coeffToPrint;
- if (term.isMonic() && term.getExponent().uge(1)) {
- coeffToPrint = "";
- } else {
- llvm::SmallString<16> coeffString;
- term.coefficientToString(coeffString);
- coeffToPrint = coeffString.str();
- }
-
- if (term.getExponent() == 0) {
- os << coeffToPrint;
- } else if (term.getExponent() == 1) {
- os << coeffToPrint << "x";
- } else {
- llvm::SmallString<16> expString;
- term.getExponent().toStringSigned(expString);
- os << coeffToPrint << "x" << exponentiation << expString;
- }
- }
- }
-
// Prints polynomial to 'os'.
- void print(raw_ostream &os) const { print(os, " + ", "**"); }
-
+ void print(raw_ostream &os) const;
+ void print(raw_ostream &os, ::llvm::StringRef separator,
+ ::llvm::StringRef exponentiation) const;
void dump() const;
// Prints polynomial so that it can be used as a valid identifier
- std::string toIdentifier() const {
- std::string result;
- llvm::raw_string_ostream os(result);
- print(os, "_", "");
- return os.str();
- }
+ std::string toIdentifier() const;
- unsigned getDegree() const {
- return terms.back().getExponent().getZExtValue();
- }
+ unsigned getDegree() const;
ArrayRef<Monomial> getTerms() const { return terms; }
- template <typename T>
- friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
+ friend ::llvm::hash_code hash_value(const Polynomial &arg);
private:
// The monomial terms for this polynomial.
SmallVector<Monomial> terms;
};
-/// A single-variable polynomial with integer coefficients.
-///
-/// Eg: x^1024 + x + 1
-class IntPolynomial : public PolynomialBase<IntMonomial> {
-public:
- explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
-
- // Returns a Polynomial from a list of monomials.
- // Fails if two monomials have the same exponent.
- static FailureOr<IntPolynomial>
- fromMonomials(ArrayRef<IntMonomial> monomials);
-
- /// Returns a polynomial with coefficients given by `coeffs`. The value
- /// coeffs[i] is converted to a monomial with exponent i.
- static IntPolynomial fromCoefficients(ArrayRef<int64_t> coeffs);
-};
-
-/// A single-variable polynomial with double coefficients.
-///
-/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
-class FloatPolynomial : public PolynomialBase<FloatMonomial> {
-public:
- explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
- : PolynomialBase(terms) {}
-
- // Returns a Polynomial from a list of monomials.
- // Fails if two monomials have the same exponent.
- static FailureOr<FloatPolynomial>
- fromMonomials(ArrayRef<FloatMonomial> monomials);
-
- /// Returns a polynomial with coefficients given by `coeffs`. The value
- /// coeffs[i] is converted to a monomial with exponent i.
- static FloatPolynomial fromCoefficients(ArrayRef<double> coeffs);
-};
-
-// Make Polynomials hashable.
-template <typename T>
-inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
+// Make Polynomial hashable.
+inline ::llvm::hash_code hash_value(const Polynomial &arg) {
return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
}
-template <typename T>
-inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
+inline ::llvm::hash_code hash_value(const Monomial &arg) {
return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
::llvm::hash_value(arg.exponent));
}
-template <typename T>
-inline raw_ostream &operator<<(raw_ostream &os,
- const PolynomialBase<T> &polynomial) {
+inline raw_ostream &operator<<(raw_ostream &os, const Polynomial &polynomial) {
polynomial.print(os);
return os;
}
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index ae8484501a50d..ed1f4ce8b7e59 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -39,14 +39,14 @@ def Polynomial_Dialect : Dialect {
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
// A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
- #modulus = #polynomial.int_polynomial<1 + x**1024>
+ #modulus = #polynomial.polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
// A constant polynomial in a ring with i32 coefficients, with a polynomial
// modulus of (x^1024 + 1) and a coefficient modulus of 17.
- #modulus = #polynomial.int_polynomial<1 + x**1024>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#modulus>
+ #modulus = #polynomial.polynomial<1 + x**1024>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17, polynomialModulus=#modulus>
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
```
}];
@@ -60,12 +60,12 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
let mnemonic = attrMnemonic;
}
-def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
- let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
+def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
+ let summary = "An attribute containing a single-variable polynomial.";
let description = [{
- A polynomial attribute represents a single-variable polynomial with integer
- coefficients, which is used to define the modulus of a `RingAttr`, as well
- as to define constants and perform constant folding for `polynomial` ops.
+ A polynomial attribute represents a single-variable polynomial, which
+ is used to define the modulus of a `RingAttr`, as well as to define constants
+ and perform constant folding for `polynomial` ops.
The polynomial must be expressed as a list of monomial terms, with addition
or subtraction between them. The choice of variable name is arbitrary, but
@@ -76,32 +76,10 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
Example:
```mlir
- #poly = #polynomial.int_polynomial<x**1024 + 1>
+ #poly = #polynomial.polynomial<x**1024 + 1>
```
}];
- let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial);
- let hasCustomAssemblyFormat = 1;
-}
-
-def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
- let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
- let description = [{
- A polynomial attribute represents a single-variable polynomial with double
- precision floating point coefficients.
-
- The polynomial must be expressed as a list of monomial terms, with addition
- or subtraction between them. The choice of variable name is arbitrary, but
- must be consistent across all the monomials used to define a single
- attribute. The order of monomial terms is arbitrary, each monomial degree
- must occur at most once.
-
- Example:
-
- ```mlir
- #poly = #polynomial.float_polynomial<0.5 x**7 + 1.5>
- ```
- }];
- let parameters = (ins "FloatPolynomial":$polynomial);
+ let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
let hasCustomAssemblyFormat = 1;
}
@@ -126,9 +104,9 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
`x**1024 - 1`.
```mlir
- #poly_mod = #polynomial.int_polynomial<-1 + x**1024>
+ #poly_mod = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32,
- coefficientModulus=4294967291:i32,
+ coefficientModulus=4294967291,
polynomialModulus=#poly_mod>
%0 = ... : polynomial.polynomial<#ring>
@@ -145,24 +123,19 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
- OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
+ OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
);
- let assemblyFormat = "`<` struct(params) `>`";
+
let builders = [
- AttrBuilderWithInferredContext<
+ AttrBuilder<
(ins "::mlir::Type":$coefficientTy,
- CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
- CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
- CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
- return $_get(
- coefficientTy.getContext(),
- coefficientTy,
- coefficientModulusAttr,
- polynomialModulusAttr,
- primitiveRootAttr);
- }]>,
+ "::mlir::IntegerAttr":$coefficientModulusAttr,
+ "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
+ return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
+ }]>
];
+ let hasCustomAssemblyFormat = 1;
}
class Polynomial_Type<string name, string typeMnemonic>
@@ -176,7 +149,7 @@ def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> {
A type for polynomials in a polynomial quotient ring.
}];
let parameters = (ins Polynomial_RingAttr:$ring);
- let assemblyFormat = "`<` struct(params) `>`";
+ let assemblyFormat = "`<` $ring `>`";
}
def PolynomialLike: TypeOrContainer<Polynomial_PolynomialType, "polynomial-like">;
@@ -214,10 +187,10 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
```mlir
// add two polynomials modulo x^1024 - 1
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -238,10 +211,10 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
```mlir
// subtract two polynomials modulo x^1024 - 1
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -262,10 +235,10 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
```mlir
// multiply two polynomials modulo x^1024 - 1
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -287,9 +260,9 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
```mlir
// multiply two polynomials modulo x^1024 - 1
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1 = arith.constant 3 : i32
%2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
```
@@ -318,9 +291,9 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
Example:
```mlir
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
```
}];
@@ -341,8 +314,8 @@ def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
Example:
```mlir
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
%deg = arith.constant 1023 : index
%five = arith.constant 5 : i32
%0 = polynomial.monomial %five, %deg : (i32, index) -> !polynomial.polynomial<#ring>
@@ -381,8 +354,8 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
Example:
```mlir
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
%two = arith.constant 2 : i32
%five = arith.constant 5 : i32
%coeffs = tensor.from_elements %two, %two, %five : tensor<3xi32>
@@ -420,8 +393,8 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
Example:
```mlir
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
%two = arith.constant 2 : i32
%five = arith.constant 5 : i32
%coeffs = tensor.from_elements %two, %two, %five : tensor<3xi32>
@@ -432,32 +405,24 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
let arguments = (ins Polynomial_PolynomialType:$input);
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
+
let hasVerifier = 1;
}
-def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
- Polynomial_FloatPolynomialAttr,
- Polynomial_IntPolynomialAttr
-]>;
-
-// Not deriving from Polynomial_Op due to need for custom assembly format
-def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
+def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let summary = "Define a constant polynomial via an attribute.";
let description = [{
Example:
```mlir
- #poly = #polynomial.int_polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-
- #float_ring = #polynomial.ring<coefficientType=f32>
- %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
+ #poly = #polynomial.polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
```
}];
- let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
+ let arguments = (ins Polynomial_PolynomialAttr:$input);
let results = (outs Polynomial_PolynomialType:$output);
- let assemblyFormat = "attr-dict `:` type($output)";
+ let assemblyFormat = "$input attr-dict `:` type($output)";
}
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
index 42e678fad060c..5916ffba78e24 100644
--- a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
@@ -9,63 +9,87 @@
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace polynomial {
-template <typename T>
-MonomialBase<T>::~MonomialBase() {}
-
-template <typename PolyT, typename MonomialT>
-FailureOr<PolyT> fromMonomialsImpl(ArrayRef<MonomialT> monomials) {
+FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) {
// A polynomial's terms are canonically stored in order of increasing degree.
- auto monomialsCopy = llvm::SmallVector<MonomialT>(monomials);
+ auto monomialsCopy = llvm::SmallVector<Monomial>(monomials);
std::sort(monomialsCopy.begin(), monomialsCopy.end());
// Ensure non-unique exponents are not present. Since we sorted the list by
// exponent, a linear scan of adjacent monomials suffices.
if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(),
- [](const MonomialT &lhs, const MonomialT &rhs) {
- return lhs.getExponent() == rhs.getExponent();
+ [](const Monomial &lhs, const Monomial &rhs) {
+ return lhs.exponent == rhs.exponent;
}) != monomialsCopy.end()) {
return failure();
}
- return PolyT(monomialsCopy);
-}
-
-FailureOr<IntPolynomial>
-IntPolynomial::fromMonomials(ArrayRef<IntMonomial> monomials) {
- return fromMonomialsImpl<IntPolynomial, IntMonomial>(monomials);
-}
-
-FailureOr<FloatPolynomial>
-FloatPolynomial::fromMonomials(ArrayRef<FloatMonomial> monomials) {
- return fromMonomialsImpl<FloatPolynomial, FloatMonomial>(monomials);
+ return Polynomial(monomialsCopy);
}
-template <typename PolyT, typename MonomialT, typename CoeffT>
-PolyT fromCoefficientsImpl(ArrayRef<CoeffT> coeffs) {
- llvm::SmallVector<MonomialT> monomials;
+Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
+ llvm::SmallVector<Monomial> monomials;
auto size = coeffs.size();
monomials.reserve(size);
for (size_t i = 0; i < size; i++) {
monomials.emplace_back(coeffs[i], i);
}
- auto result = PolyT::fromMonomials(monomials);
+ auto result = Polynomial::fromMonomials(monomials);
// Construction guarantees unique exponents, so the failure mode of
// fromMonomials can be bypassed.
assert(succeeded(result));
return result.value();
}
-IntPolynomial IntPolynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
- return fromCoefficientsImpl<IntPolynomial, IntMonomial, int64_t>(coeffs);
+void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator,
+ ::llvm::StringRef exponentiation) const {
+ bool first = true;
+ for (const Monomial &term : terms) {
+ if (first) {
+ first = false;
+ } else {
+ os << separator;
+ }
+ std::string coeffToPrint;
+ if (term.coefficient == 1 && term.exponent.uge(1)) {
+ coeffToPrint = "";
+ } else {
+ llvm::SmallString<16> coeffString;
+ term.coefficient.toStringSigned(coeffString);
+ coeffToPrint = coeffString.str();
+ }
+
+ if (term.exponent == 0) {
+ os << coeffToPrint;
+ } else if (term.exponent == 1) {
+ os << coeffToPrint << "x";
+ } else {
+ llvm::SmallString<16> expString;
+ term.exponent.toStringSigned(expString);
+ os << coeffToPrint << "x" << exponentiation << expString;
+ }
+ }
+}
+
+void Polynomial::print(raw_ostream &os) const { print(os, " + ", "**"); }
+
+std::string Polynomial::toIdentifier() const {
+ std::string result;
+ llvm::raw_string_ostream os(result);
+ print(os, "_", "");
+ return os.str();
}
-FloatPolynomial FloatPolynomial::fromCoefficients(ArrayRef<double> coeffs) {
- return fromCoefficientsImpl<FloatPolynomial, FloatMonomial, double>(coeffs);
+unsigned Polynomial::getDegree() const {
+ return terms.back().exponent.getZExtValue();
}
} // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 890ce5226c30f..236bb78966352 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -10,7 +10,6 @@
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
@@ -18,31 +17,22 @@
namespace mlir {
namespace polynomial {
-void IntPolynomialAttr::print(AsmPrinter &p) const {
- p << '<' << getPolynomial() << '>';
+void PolynomialAttr::print(AsmPrinter &p) const {
+ p << '<';
+ p << getPolynomial();
+ p << '>';
}
-void FloatPolynomialAttr::print(AsmPrinter &p) const {
- p << '<' << getPolynomial() << '>';
-}
-
-/// A callable that parses the coefficient using the appropriate method for the
-/// given monomial type, and stores the parsed coefficient value on the
-/// monomial.
-template <typename MonomialType>
-using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
-
/// Try to parse a monomial. If successful, populate the fields of the outparam
/// `monomial` with the results, and the `variable` outparam with the parsed
/// variable name. Sets shouldParseMore to true if the monomial is followed by
/// a '+'.
-///
-template <typename Monomial>
-ParseResult
-parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
- bool &isConstantTerm, bool &shouldParseMore,
- ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
- OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);
+ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
+ llvm::StringRef &variable, bool &isConstantTerm,
+ bool &shouldParseMore) {
+ APInt parsedCoeff(apintBitWidth, 1);
+ auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);
+ monomial.coefficient = parsedCoeff;
isConstantTerm = false;
shouldParseMore = false;
@@ -54,7 +44,7 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
if (!parsedCoeffResult.has_value()) {
return failure();
}
- monomial.setExponent(APInt(apintBitWidth, 0));
+ monomial.exponent = APInt(apintBitWidth, 0);
isConstantTerm = true;
shouldParseMore = true;
return success();
@@ -68,7 +58,7 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
return failure();
}
- monomial.setExponent(APInt(apintBitWidth, 0));
+ monomial.exponent = APInt(apintBitWidth, 0);
isConstantTerm = true;
return success();
}
@@ -90,9 +80,9 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
return failure();
}
- monomial.setExponent(parsedExponent);
+ monomial.exponent = parsedExponent;
} else {
- monomial.setExponent(APInt(apintBitWidth, 1));
+ monomial.exponent = APInt(apintBitWidth, 1);
}
if (succeeded(parser.parseOptionalPlus())) {
@@ -101,21 +91,22 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
return success();
}
-template <typename PolynoimalAttrTy, typename Monomial>
-LogicalResult
-parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
- llvm::StringSet<> &variables,
- ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
+Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
+ if (failed(parser.parseLess()))
+ return {};
+
+ llvm::SmallVector<Monomial> monomials;
+ llvm::StringSet<> variables;
+
while (true) {
Monomial parsedMonomial;
llvm::StringRef parsedVariableRef;
bool isConstantTerm;
bool shouldParseMore;
- if (failed(parseMonomial<Monomial>(
- parser, parsedMonomial, parsedVariableRef, isConstantTerm,
- shouldParseMore, parseAndStoreCoefficient))) {
+ if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
+ isConstantTerm, shouldParseMore))) {
parser.emitError(parser.getCurrentLocation(), "expected a monomial");
- return failure();
+ return {};
}
if (!isConstantTerm) {
@@ -133,7 +124,7 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
parser.emitError(
parser.getCurrentLocation(),
"expected + and more monomials, or > to end polynomial attribute");
- return failure();
+ return {};
}
if (variables.size() > 1) {
@@ -142,67 +133,96 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
parser.getCurrentLocation(),
"polynomials must have one indeterminate, but there were multiple: " +
vars);
- return failure();
}
- return success();
+ auto result = Polynomial::fromMonomials(monomials);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation())
+ << "parsed polynomial must have unique exponents among monomials";
+ return {};
+ }
+ return PolynomialAttr::get(parser.getContext(), result.value());
}
-Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
+void RingAttr::print(AsmPrinter &p) const {
+ p << "#polynomial.ring<coefficientType=" << getCoefficientType()
+ << ", coefficientModulus=" << getCoefficientModulus()
+ << ", polynomialModulus=" << getPolynomialModulus() << '>';
+}
+
+Attribute RingAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
- llvm::SmallVector<IntMonomial> monomials;
- llvm::StringSet<> variables;
+ if (failed(parser.parseKeyword("coefficientType")))
+ return {};
- if (failed(parsePolynomialAttr<IntPolynomialAttr, IntMonomial>(
- parser, monomials, variables,
- [&](IntMonomial &monomial) -> OptionalParseResult {
- APInt parsedCoeff(apintBitWidth, 1);
- OptionalParseResult result =
- parser.parseOptionalInteger(parsedCoeff);
- monomial.setCoefficient(parsedCoeff);
- return result;
- }))) {
+ if (failed(parser.parseEqual()))
return {};
- }
- auto result = IntPolynomial::fromMonomials(monomials);
- if (failed(result)) {
- parser.emitError(parser.getCurrentLocation())
- << "parsed polynomial must have unique exponents among monomials";
+ Type ty;
+ if (failed(parser.parseType(ty)))
return {};
- }
- return IntPolynomialAttr::get(parser.getContext(), result.value());
-}
-Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
- if (failed(parser.parseLess()))
+ if (failed(parser.parseComma()))
return {};
- llvm::SmallVector<FloatMonomial> monomials;
- llvm::StringSet<> variables;
+ IntegerAttr coefficientModulusAttr = nullptr;
+ if (succeeded(parser.parseKeyword("coefficientModulus"))) {
+ if (failed(parser.parseEqual()))
+ return {};
- ParseCoefficientFn<FloatMonomial> parseAndStoreCoefficient =
- [&](FloatMonomial &monomial) -> OptionalParseResult {
- double coeffValue = 1.0;
- ParseResult result = parser.parseFloat(coeffValue);
- monomial.setCoefficient(APFloat(coeffValue));
- return OptionalParseResult(result);
- };
+ IntegerType iType = mlir::dyn_cast<IntegerType>(ty);
+ if (!iType) {
+ parser.emitError(parser.getCurrentLocation(),
+ "coefficientType must specify an integer type");
+ return {};
+ }
+ APInt coefficientModulus(iType.getWidth(), 0);
+ auto result = parser.parseInteger(coefficientModulus);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation(),
+ "invalid coefficient modulus");
+ return {};
+ }
+ coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);
- if (failed(parsePolynomialAttr<FloatPolynomialAttr, FloatMonomial>(
- parser, monomials, variables, parseAndStoreCoefficient))) {
- return {};
+ if (failed(parser.parseComma()))
+ return {};
}
- auto result = FloatPolynomial::fromMonomials(monomials);
- if (failed(result)) {
- parser.emitError(parser.getCurrentLocation())
- << "parsed polynomial must have unique exponents among monomials";
- return {};
+ PolynomialAttr polyAttr = nullptr;
+ if (succeeded(parser.parseKeyword("polynomialModulus"))) {
+ if (failed(parser.parseEqual()))
+ return {};
+
+ PolynomialAttr attr;
+ if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
+ return {};
+ polyAttr = attr;
}
- return FloatPolynomialAttr::get(parser.getContext(), result.value());
+
+ Polynomial poly = polyAttr.getPolynomial();
+ APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
+ IntegerAttr rootAttr = nullptr;
+ if (succeeded(parser.parseOptionalComma())) {
+ if (failed(parser.parseKeyword("primitiveRoot")) ||
+ failed(parser.parseEqual()))
+ return {};
+
+ ParseResult result = parser.parseInteger(root);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
+ return {};
+ }
+ rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+ }
+
+ if (failed(parser.parseGreater()))
+ return {};
+
+ return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
+ polyAttr, rootAttr);
}
} // namespace polynomial
diff --git a/mlir/test/Dialect/Polynomial/attributes.mlir b/mlir/test/Dialect/Polynomial/attributes.mlir
index 4bdfd44fd4d15..3973ae3944335 100644
--- a/mlir/test/Dialect/Polynomial/attributes.mlir
+++ b/mlir/test/Dialect/Polynomial/attributes.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --split-input-file --verify-diagnostics
-#my_poly = #polynomial.int_polynomial<y + x**1024>
+#my_poly = #polynomial.polynomial<y + x**1024>
// expected-error@below {{polynomials must have one indeterminate, but there were multiple: x, y}}
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
@@ -9,31 +9,37 @@
// expected-error@below {{expected integer value}}
// expected-error@below {{expected a monomial}}
// expected-error@below {{found invalid integer exponent}}
-#my_poly = #polynomial.int_polynomial<5 + x**f>
+#my_poly = #polynomial.polynomial<5 + x**f>
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
// -----
-#my_poly = #polynomial.int_polynomial<5 + x**2 + 3x**2>
+#my_poly = #polynomial.polynomial<5 + x**2 + 3x**2>
// expected-error@below {{parsed polynomial must have unique exponents among monomials}}
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
// -----
// expected-error@below {{expected + and more monomials, or > to end polynomial attribute}}
-#my_poly = #polynomial.int_polynomial<5 + x**2 7>
+#my_poly = #polynomial.polynomial<5 + x**2 7>
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
// -----
// expected-error@below {{expected a monomial}}
-#my_poly = #polynomial.int_polynomial<5 + x**2 +>
+#my_poly = #polynomial.polynomial<5 + x**2 +>
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
// -----
-#my_poly = #polynomial.int_polynomial<5 + x**2>
-// expected-error@below {{failed to parse Polynomial_RingAttr parameter 'coefficientModulus' which is to be a `::mlir::IntegerAttr`}}
-// expected-error@below {{expected attribute value}}
+#my_poly = #polynomial.polynomial<5 + x**2>
+// expected-error@below {{coefficientType must specify an integer type}}
+#ring1 = #polynomial.ring<coefficientType=f64, coefficientModulus=2837465, polynomialModulus=#my_poly>
+
+// -----
+
+#my_poly = #polynomial.polynomial<5 + x**2>
+// expected-error@below {{expected integer value}}
+// expected-error@below {{invalid coefficient modulus}}
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=x, polynomialModulus=#my_poly>
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ff709960c50e9..a29cfc2e9cc54 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -2,87 +2,85 @@
// This simply tests for syntax.
-#my_poly = #polynomial.int_polynomial<1 + x**1024>
-#my_poly_2 = #polynomial.int_polynomial<2>
-#my_poly_3 = #polynomial.int_polynomial<3x>
-#my_poly_4 = #polynomial.int_polynomial<t**3 + 4t + 2>
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#my_poly_2 = #polynomial.polynomial<2>
+#my_poly_3 = #polynomial.polynomial<3x>
+#my_poly_4 = #polynomial.polynomial<t**3 + 4t + 2>
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
-#ring2 = #polynomial.ring<coefficientType=f32>
-#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
+#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
-#ideal = #polynomial.int_polynomial<-1 + x**1024>
+#ideal = #polynomial.polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
-!poly_ty = !polynomial.polynomial<ring=#ring>
+!poly_ty = !polynomial.polynomial<#ring>
-#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
+#ntt_poly = #polynomial.polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
-!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
module {
- func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
+ func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
%c0 = arith.constant 0 : index
%two = arith.constant 2 : i16
%five = arith.constant 5 : i16
%coeffs1 = tensor.from_elements %two, %two, %five : tensor<3xi16>
%coeffs2 = tensor.from_elements %five, %five, %two : tensor<3xi16>
- %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<ring=#ring1>
- %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<ring=#ring1>
+ %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
+ %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
- %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<ring=#ring1>
+ %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<#ring1>
- return %3 : !polynomial.polynomial<ring=#ring1>
+ return %3 : !polynomial.polynomial<#ring1>
}
- func.func @test_elementwise(%p0 : !polynomial.polynomial<ring=#ring1>, %p1: !polynomial.polynomial<ring=#ring1>) {
- %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
- %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<ring=#ring1>>
+ func.func @test_elementwise(%p0 : !polynomial.polynomial<#ring1>, %p1: !polynomial.polynomial<#ring1>) {
+ %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<#ring1>>
+ %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<#ring1>>
%c = arith.constant 2 : i32
- %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<ring=#ring1>>, i32
+ %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<#ring1>>, i32
- %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
- %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
- %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
+ %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+ %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+ %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
return
}
- func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<ring=#ring1>) {
+ func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<#ring1>) {
%c0 = arith.constant 0 : index
%two = arith.constant 2 : i16
%coeffs1 = tensor.from_elements %two, %two : tensor<2xi16>
// CHECK: from_tensor
- %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<ring=#ring1>
+ %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<#ring1>
// CHECK: to_tensor
- %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<ring=#ring1> -> tensor<1024xi16>
+ %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<#ring1> -> tensor<1024xi16>
return
}
- func.func @test_degree(%p0 : !polynomial.polynomial<ring=#ring1>) {
- %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<ring=#ring1> -> (index, i32)
+ func.func @test_degree(%p0 : !polynomial.polynomial<#ring1>) {
+ %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<#ring1> -> (index, i32)
return
}
func.func @test_monomial() {
%deg = arith.constant 1023 : index
%five = arith.constant 5 : i16
- %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<ring=#ring1>
+ %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<#ring1>
return
}
func.func @test_monic_monomial_mul() {
%five = arith.constant 5 : index
- %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
- %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
+ %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
+ %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<#ring1>, index) -> !polynomial.polynomial<#ring1>
return
}
func.func @test_constant() {
- %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
- %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<ring=#ring1>
- %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<ring=#ring2>
+ %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
+ %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
return
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index af8e4aa5da862..2c20e7bcbf1d6 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt --split-input-file --verify-diagnostics %s
-#my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring>
func.func @test_from_tensor_too_large_coeffs() {
%two = arith.constant 2 : i32
@@ -15,13 +15,13 @@ func.func @test_from_tensor_too_large_coeffs() {
// -----
-#my_poly = #polynomial.int_polynomial<1 + x**4>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<1 + x**4>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring>
func.func @test_from_tensor_wrong_tensor_type() {
%two = arith.constant 2 : i32
%coeffs1 = tensor.from_elements %two, %two, %two, %two, %two : tensor<5xi32>
- // expected-error@below {{input type 'tensor<5xi32>' does not match output type '!polynomial.polynomial<ring = <coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>'}}
+ // expected-error@below {{input type 'tensor<5xi32>' does not match output type '!polynomial.polynomial<#polynomial.ring<coefficientType=i32, coefficientModulus=256 : i32, polynomialModulus=#polynomial.polynomial<1 + x**4>>>'}}
// expected-note@below {{at most the degree of the polynomialModulus of the output type's ring attribute}}
%poly = polynomial.from_tensor %coeffs1 : tensor<5xi32> -> !ty
return
@@ -29,11 +29,11 @@ func.func @test_from_tensor_wrong_tensor_type() {
// -----
-#my_poly = #polynomial.int_polynomial<1 + x**4>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<1 + x**4>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring>
func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
- // expected-error@below {{input type '!polynomial.polynomial<ring = <coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
+ // expected-error@below {{input type '!polynomial.polynomial<#polynomial.ring<coefficientType=i32, coefficientModulus=256 : i32, polynomialModulus=#polynomial.polynomial<1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
// expected-note@below {{at most the degree of the polynomialModulus of the input type's ring attribute}}
%tensor = polynomial.to_tensor %arg0 : !ty -> tensor<5xi32>
return
@@ -41,9 +41,9 @@ func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
// -----
-#my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring>
func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
%scalar = arith.constant 2 : i32 // should be i16
@@ -54,9 +54,9 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
// -----
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
-!poly_ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
@@ -68,9 +68,9 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
// -----
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
-!poly_ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
@@ -82,10 +82,10 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
// -----
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
-!poly_ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
@@ -97,9 +97,9 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
// -----
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
-!poly_ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
@@ -112,9 +112,9 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
// -----
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-!poly_ty = !polynomial.polynomial<ring=#ring>
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
@@ -126,10 +126,10 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
// -----
-#my_poly = #polynomial.int_polynomial<-1 + x**8>
+#my_poly = #polynomial.polynomial<-1 + x**8>
// A valid root is 31
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=32:i16>
-!poly_ty = !polynomial.polynomial<ring=#ring>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
+!poly_ty = !polynomial.polynomial<#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
diff --git a/mlir/test/Dialect/Polynomial/types.mlir b/mlir/test/Dialect/Polynomial/types.mlir
index dcc5663ceb84c..00296a36e890f 100644
--- a/mlir/test/Dialect/Polynomial/types.mlir
+++ b/mlir/test/Dialect/Polynomial/types.mlir
@@ -2,13 +2,13 @@
// CHECK-LABEL: func @test_types
// CHECK-SAME: !polynomial.polynomial<
-// CHECK-SAME: ring = <
-// CHECK-SAME: coefficientType = i32,
-// CHECK-SAME: coefficientModulus = 2837465 : i32,
-// CHECK-SAME: polynomialModulus = <1 + x**1024>>>
-#my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 2837465 : i32, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<ring=#ring1>
+// CHECK-SAME: #polynomial.ring<
+// CHECK-SAME: coefficientType=i32,
+// CHECK-SAME: coefficientModulus=2837465 : i32,
+// CHECK-SAME: polynomialModulus=#polynomial.polynomial<1 + x**1024>>>
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring1>
func.func @test_types(%0: !ty) -> !ty {
return %0 : !ty
}
@@ -16,13 +16,13 @@ func.func @test_types(%0: !ty) -> !ty {
// CHECK-LABEL: func @test_non_x_variable_64_bit
// CHECK-SAME: !polynomial.polynomial<
-// CHECK-SAME: ring = <
-// CHECK-SAME: coefficientType = i64,
-// CHECK-SAME: coefficientModulus = 2837465 : i64,
-// CHECK-SAME: polynomialModulus = <2 + 4x + x**3>>>
-#my_poly_2 = #polynomial.int_polynomial<t**3 + 4t + 2>
-#ring2 = #polynomial.ring<coefficientType = i64, coefficientModulus = 2837465 : i64, polynomialModulus=#my_poly_2>
-!ty2 = !polynomial.polynomial<ring=#ring2>
+// CHECK-SAME: #polynomial.ring<
+// CHECK-SAME: coefficientType=i64,
+// CHECK-SAME: coefficientModulus=2837465 : i64,
+// CHECK-SAME: polynomialModulus=#polynomial.polynomial<2 + 4x + x**3>>>
+#my_poly_2 = #polynomial.polynomial<t**3 + 4t + 2>
+#ring2 = #polynomial.ring<coefficientType=i64, coefficientModulus=2837465, polynomialModulus=#my_poly_2>
+!ty2 = !polynomial.polynomial<#ring2>
func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
return %0 : !ty2
}
@@ -30,36 +30,27 @@ func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
// CHECK-LABEL: func @test_linear_poly
// CHECK-SAME: !polynomial.polynomial<
-// CHECK-SAME: ring = <
-// CHECK-SAME: coefficientType = i32,
-// CHECK-SAME: coefficientModulus = 12 : i32,
-// CHECK-SAME: polynomialModulus = <4x>>
-#my_poly_3 = #polynomial.int_polynomial<4x>
-#ring3 = #polynomial.ring<coefficientType = i32, coefficientModulus=12 : i32, polynomialModulus=#my_poly_3>
-!ty3 = !polynomial.polynomial<ring=#ring3>
+// CHECK-SAME: #polynomial.ring<
+// CHECK-SAME: coefficientType=i32,
+// CHECK-SAME: coefficientModulus=12 : i32,
+// CHECK-SAME: polynomialModulus=#polynomial.polynomial<4x>>
+#my_poly_3 = #polynomial.polynomial<4x>
+#ring3 = #polynomial.ring<coefficientType=i32, coefficientModulus=12, polynomialModulus=#my_poly_3>
+!ty3 = !polynomial.polynomial<#ring3>
func.func @test_linear_poly(%0: !ty3) -> !ty3 {
return %0 : !ty3
}
// CHECK-LABEL: func @test_negative_leading_1
// CHECK-SAME: !polynomial.polynomial<
-// CHECK-SAME: ring = <
-// CHECK-SAME: coefficientType = i32,
-// CHECK-SAME: coefficientModulus = 2837465 : i32,
-// CHECK-SAME: polynomialModulus = <-1 + x**1024>>>
-#my_poly_4 = #polynomial.int_polynomial<-1 + x**1024>
-#ring4 = #polynomial.ring<coefficientType = i32, coefficientModulus = 2837465 : i32, polynomialModulus=#my_poly_4>
-!ty4 = !polynomial.polynomial<ring=#ring4>
+// CHECK-SAME: #polynomial.ring<
+// CHECK-SAME: coefficientType=i32,
+// CHECK-SAME: coefficientModulus=2837465 : i32,
+// CHECK-SAME: polynomialModulus=#polynomial.polynomial<-1 + x**1024>>>
+#my_poly_4 = #polynomial.polynomial<-1 + x**1024>
+#ring4 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly_4>
+!ty4 = !polynomial.polynomial<#ring4>
func.func @test_negative_leading_1(%0: !ty4) -> !ty4 {
return %0 : !ty4
}
-// CHECK-LABEL: func @test_float_coefficients
-// CHECK-SAME: !polynomial.polynomial<ring = <coefficientType = f32>>
-#my_poly_5 = #polynomial.float_polynomial<0.5 + 1.6e03 x**1024>
-#ring5 = #polynomial.ring<coefficientType=f32>
-!ty5 = !polynomial.polynomial<ring=#ring5>
-func.func @test_float_coefficients(%0: !ty5) -> !ty5 {
- return %0 : !ty5
-}
-
>From c4815f7ac5158ce19901c44fcd6df7b3c0ba1144 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 13 May 2024 10:25:53 -0700
Subject: [PATCH 2/2] Revert "Revert "Support polynomial attributes with
floating point coefficients (#91137)""
This reverts commit f085443c71c6e729c972d9001a809193e2e53fea.
---
.../mlir/Dialect/Polynomial/IR/Polynomial.h | 193 ++++++++++++++----
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 139 ++++++++-----
mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp | 78 +++----
.../Polynomial/IR/PolynomialAttributes.cpp | 172 +++++++---------
mlir/test/Dialect/Polynomial/attributes.mlir | 22 +-
mlir/test/Dialect/Polynomial/ops.mlir | 64 +++---
mlir/test/Dialect/Polynomial/ops_errors.mlir | 66 +++---
mlir/test/Dialect/Polynomial/types.mlir | 65 +++---
8 files changed, 451 insertions(+), 348 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 3325a6fa3f9fc..2b3f0e105c6c5 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -11,10 +11,13 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
-#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
@@ -27,98 +30,202 @@ namespace polynomial {
/// would want to specify 128-bit polynomials statically in the source code.
constexpr unsigned apintBitWidth = 64;
-/// A class representing a monomial of a single-variable polynomial with integer
-/// coefficients.
-class Monomial {
+template <typename CoefficientType>
+class MonomialBase {
public:
- Monomial(int64_t coeff, uint64_t expo)
- : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
-
- Monomial(const APInt &coeff, const APInt &expo)
+ MonomialBase(const CoefficientType &coeff, const APInt &expo)
: coefficient(coeff), exponent(expo) {}
+ virtual ~MonomialBase() = 0;
- Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {}
+ const CoefficientType &getCoefficient() const { return coefficient; }
+ CoefficientType &getMutableCoefficient() { return coefficient; }
+ const APInt &getExponent() const { return exponent; }
+ void setCoefficient(const CoefficientType &coeff) { coefficient = coeff; }
+ void setExponent(const APInt &exp) { exponent = exp; }
- bool operator==(const Monomial &other) const {
+ bool operator==(const MonomialBase &other) const {
return other.coefficient == coefficient && other.exponent == exponent;
}
- bool operator!=(const Monomial &other) const {
+ bool operator!=(const MonomialBase &other) const {
return other.coefficient != coefficient || other.exponent != exponent;
}
/// Monomials are ordered by exponent.
- bool operator<(const Monomial &other) const {
+ bool operator<(const MonomialBase &other) const {
return (exponent.ult(other.exponent));
}
- friend ::llvm::hash_code hash_value(const Monomial &arg);
+ virtual bool isMonic() const = 0;
+ virtual void
+ coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
-public:
- APInt coefficient;
+ template <typename T>
+ friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
- // Always unsigned
+protected:
+ CoefficientType coefficient;
APInt exponent;
};
-/// A single-variable polynomial with integer coefficients.
-///
-/// Eg: x^1024 + x + 1
-///
-/// The symbols used as the polynomial's indeterminate don't matter, so long as
-/// it is used consistently throughout the polynomial.
-class Polynomial {
+/// A class representing a monomial of a single-variable polynomial with integer
+/// coefficients.
+class IntMonomial : public MonomialBase<APInt> {
public:
- Polynomial() = delete;
+ IntMonomial(int64_t coeff, uint64_t expo)
+ : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
- explicit Polynomial(ArrayRef<Monomial> terms) : terms(terms){};
+ IntMonomial()
+ : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
- // Returns a Polynomial from a list of monomials.
- // Fails if two monomials have the same exponent.
- static FailureOr<Polynomial> fromMonomials(ArrayRef<Monomial> monomials);
+ ~IntMonomial() = default;
- /// Returns a polynomial with coefficients given by `coeffs`. The value
- /// coeffs[i] is converted to a monomial with exponent i.
- static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs);
+ bool isMonic() const override { return coefficient == 1; }
+
+ void coefficientToString(llvm::SmallString<16> &coeffString) const override {
+ coefficient.toStringSigned(coeffString);
+ }
+};
+
+/// A class representing a monomial of a single-variable polynomial with
+/// floating-point coefficients.
+class FloatMonomial : public MonomialBase<APFloat> {
+public:
+ FloatMonomial(double coeff, uint64_t expo)
+ : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
+
+ FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
+
+ ~FloatMonomial() = default;
+
+ bool isMonic() const override { return coefficient == APFloat(1.0); }
+
+ void coefficientToString(llvm::SmallString<16> &coeffString) const override {
+ coefficient.toString(coeffString);
+ }
+};
+
+template <typename Monomial>
+class PolynomialBase {
+public:
+ PolynomialBase() = delete;
+
+ explicit PolynomialBase(ArrayRef<Monomial> terms) : terms(terms){};
explicit operator bool() const { return !terms.empty(); }
- bool operator==(const Polynomial &other) const {
+ bool operator==(const PolynomialBase &other) const {
return other.terms == terms;
}
- bool operator!=(const Polynomial &other) const {
+ bool operator!=(const PolynomialBase &other) const {
return !(other.terms == terms);
}
- // Prints polynomial to 'os'.
- void print(raw_ostream &os) const;
void print(raw_ostream &os, ::llvm::StringRef separator,
- ::llvm::StringRef exponentiation) const;
+ ::llvm::StringRef exponentiation) const {
+ bool first = true;
+ for (const Monomial &term : getTerms()) {
+ if (first) {
+ first = false;
+ } else {
+ os << separator;
+ }
+ std::string coeffToPrint;
+ if (term.isMonic() && term.getExponent().uge(1)) {
+ coeffToPrint = "";
+ } else {
+ llvm::SmallString<16> coeffString;
+ term.coefficientToString(coeffString);
+ coeffToPrint = coeffString.str();
+ }
+
+ if (term.getExponent() == 0) {
+ os << coeffToPrint;
+ } else if (term.getExponent() == 1) {
+ os << coeffToPrint << "x";
+ } else {
+ llvm::SmallString<16> expString;
+ term.getExponent().toStringSigned(expString);
+ os << coeffToPrint << "x" << exponentiation << expString;
+ }
+ }
+ }
+
+ // Prints polynomial to 'os'.
+ void print(raw_ostream &os) const { print(os, " + ", "**"); }
+
void dump() const;
// Prints polynomial so that it can be used as a valid identifier
- std::string toIdentifier() const;
+ std::string toIdentifier() const {
+ std::string result;
+ llvm::raw_string_ostream os(result);
+ print(os, "_", "");
+ return os.str();
+ }
- unsigned getDegree() const;
+ unsigned getDegree() const {
+ return terms.back().getExponent().getZExtValue();
+ }
ArrayRef<Monomial> getTerms() const { return terms; }
- friend ::llvm::hash_code hash_value(const Polynomial &arg);
+ template <typename T>
+ friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
private:
// The monomial terms for this polynomial.
SmallVector<Monomial> terms;
};
-// Make Polynomial hashable.
-inline ::llvm::hash_code hash_value(const Polynomial &arg) {
+/// A single-variable polynomial with integer coefficients.
+///
+/// Eg: x^1024 + x + 1
+class IntPolynomial : public PolynomialBase<IntMonomial> {
+public:
+ explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
+
+ // Returns an IntPolynomial built from a list of monomials.
+ // Fails if two monomials have the same exponent.
+ static FailureOr<IntPolynomial>
+ fromMonomials(ArrayRef<IntMonomial> monomials);
+
+ /// Returns a polynomial with coefficients given by `coeffs`. The value
+ /// coeffs[i] is converted to a monomial with exponent i.
+ static IntPolynomial fromCoefficients(ArrayRef<int64_t> coeffs);
+};
+
+/// A single-variable polynomial with double precision floating point coefficients.
+///
+/// Eg: 1.0 x^1024 + 3.5 x + 1e-05
+class FloatPolynomial : public PolynomialBase<FloatMonomial> {
+public:
+ explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
+ : PolynomialBase(terms) {}
+
+ // Returns a FloatPolynomial built from a list of monomials.
+ // Fails if two monomials have the same exponent.
+ static FailureOr<FloatPolynomial>
+ fromMonomials(ArrayRef<FloatMonomial> monomials);
+
+ /// Returns a polynomial with coefficients given by `coeffs`. The value
+ /// coeffs[i] is converted to a monomial with exponent i.
+ static FloatPolynomial fromCoefficients(ArrayRef<double> coeffs);
+};
+
+// Make Polynomials hashable.
+template <typename T>
+inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
}
-inline ::llvm::hash_code hash_value(const Monomial &arg) {
+template <typename T>
+inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
::llvm::hash_value(arg.exponent));
}
-inline raw_ostream &operator<<(raw_ostream &os, const Polynomial &polynomial) {
+template <typename T>
+inline raw_ostream &operator<<(raw_ostream &os,
+ const PolynomialBase<T> &polynomial) {
polynomial.print(os);
return os;
}
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index ed1f4ce8b7e59..ae8484501a50d 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -39,14 +39,14 @@ def Polynomial_Dialect : Dialect {
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
// A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
- #modulus = #polynomial.polynomial<1 + x**1024>
+ #modulus = #polynomial.int_polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
// A constant polynomial in a ring with i32 coefficients, with a polynomial
// modulus of (x^1024 + 1) and a coefficient modulus of 17.
- #modulus = #polynomial.polynomial<1 + x**1024>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17, polynomialModulus=#modulus>
+ #modulus = #polynomial.int_polynomial<1 + x**1024>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#modulus>
%a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
```
}];
@@ -60,12 +60,12 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
let mnemonic = attrMnemonic;
}
-def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
- let summary = "An attribute containing a single-variable polynomial.";
+def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
+ let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
let description = [{
- A polynomial attribute represents a single-variable polynomial, which
- is used to define the modulus of a `RingAttr`, as well as to define constants
- and perform constant folding for `polynomial` ops.
+ A polynomial attribute represents a single-variable polynomial with integer
+ coefficients, which is used to define the modulus of a `RingAttr`, as well
+ as to define constants and perform constant folding for `polynomial` ops.
The polynomial must be expressed as a list of monomial terms, with addition
or subtraction between them. The choice of variable name is arbitrary, but
@@ -76,10 +76,32 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
Example:
```mlir
- #poly = #polynomial.polynomial<x**1024 + 1>
+ #poly = #polynomial.int_polynomial<x**1024 + 1>
```
}];
- let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
+ let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial);
+ let hasCustomAssemblyFormat = 1;
+}
+
+def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
+ let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
+ let description = [{
+ A polynomial attribute represents a single-variable polynomial with double
+ precision floating point coefficients.
+
+ The polynomial must be expressed as a list of monomial terms, with addition
+ or subtraction between them. The choice of variable name is arbitrary, but
+ must be consistent across all the monomials used to define a single
+ attribute. The order of monomial terms is arbitrary, but each monomial
+ degree must occur at most once.
+
+ Example:
+
+ ```mlir
+ #poly = #polynomial.float_polynomial<0.5 x**7 + 1.5>
+ ```
+ }];
+ let parameters = (ins "FloatPolynomial":$polynomial);
let hasCustomAssemblyFormat = 1;
}
@@ -104,9 +126,9 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
`x**1024 - 1`.
```mlir
- #poly_mod = #polynomial.polynomial<-1 + x**1024>
+ #poly_mod = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32,
- coefficientModulus=4294967291,
+ coefficientModulus=4294967291:i32,
polynomialModulus=#poly_mod>
%0 = ... : polynomial.polynomial<#ring>
@@ -123,19 +145,24 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
- OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
+ OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
);
-
+ let assemblyFormat = "`<` struct(params) `>`";
let builders = [
- AttrBuilder<
+ AttrBuilderWithInferredContext<
(ins "::mlir::Type":$coefficientTy,
- "::mlir::IntegerAttr":$coefficientModulusAttr,
- "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
- return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
- }]>
+ CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
+ CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
+ CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+ return $_get(
+ coefficientTy.getContext(),
+ coefficientTy,
+ coefficientModulusAttr,
+ polynomialModulusAttr,
+ primitiveRootAttr);
+ }]>,
];
- let hasCustomAssemblyFormat = 1;
}
class Polynomial_Type<string name, string typeMnemonic>
@@ -149,7 +176,7 @@ def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> {
A type for polynomials in a polynomial quotient ring.
}];
let parameters = (ins Polynomial_RingAttr:$ring);
- let assemblyFormat = "`<` $ring `>`";
+ let assemblyFormat = "`<` struct(params) `>`";
}
def PolynomialLike: TypeOrContainer<Polynomial_PolynomialType, "polynomial-like">;
@@ -187,10 +214,10 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
```mlir
// add two polynomials modulo x^1024 - 1
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -211,10 +238,10 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
```mlir
// subtract two polynomials modulo x^1024 - 1
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -235,10 +262,10 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
```mlir
// multiply two polynomials modulo x^1024 - 1
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
- %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
%2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
```
}];
@@ -260,9 +287,9 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
```mlir
// multiply two polynomials modulo x^1024 - 1
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1 = arith.constant 3 : i32
%2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
```
@@ -291,9 +318,9 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
Example:
```mlir
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
%1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
```
}];
@@ -314,8 +341,8 @@ def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
Example:
```mlir
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%deg = arith.constant 1023 : index
%five = arith.constant 5 : i32
%0 = polynomial.monomial %five, %deg : (i32, index) -> !polynomial.polynomial<#ring>
@@ -354,8 +381,8 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
Example:
```mlir
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%two = arith.constant 2 : i32
%five = arith.constant 5 : i32
%coeffs = tensor.from_elements %two, %two, %five : tensor<3xi32>
@@ -393,8 +420,8 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
Example:
```mlir
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
%two = arith.constant 2 : i32
%five = arith.constant 5 : i32
%coeffs = tensor.from_elements %two, %two, %five : tensor<3xi32>
@@ -405,24 +432,32 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
let arguments = (ins Polynomial_PolynomialType:$input);
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
-
let hasVerifier = 1;
}
-def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
+def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
+ Polynomial_FloatPolynomialAttr,
+ Polynomial_IntPolynomialAttr
+]>;
+
+// Not deriving from Polynomial_Op due to need for custom assembly format
+def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
let summary = "Define a constant polynomial via an attribute.";
let description = [{
Example:
```mlir
- #poly = #polynomial.polynomial<x**1024 - 1>
- #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
- %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+ #poly = #polynomial.int_polynomial<x**1024 - 1>
+ #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
+ %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+
+ #float_ring = #polynomial.ring<coefficientType=f32>
+ %1 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
```
}];
- let arguments = (ins Polynomial_PolynomialAttr:$input);
+ let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
let results = (outs Polynomial_PolynomialType:$output);
- let assemblyFormat = "$input attr-dict `:` type($output)";
+ let assemblyFormat = "attr-dict `:` type($output)";
}
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
index 5916ffba78e24..42e678fad060c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
@@ -9,87 +9,63 @@
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace polynomial {
-FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) {
+template <typename T>
+MonomialBase<T>::~MonomialBase() {}
+
+template <typename PolyT, typename MonomialT>
+FailureOr<PolyT> fromMonomialsImpl(ArrayRef<MonomialT> monomials) {
// A polynomial's terms are canonically stored in order of increasing degree.
- auto monomialsCopy = llvm::SmallVector<Monomial>(monomials);
+ auto monomialsCopy = llvm::SmallVector<MonomialT>(monomials);
std::sort(monomialsCopy.begin(), monomialsCopy.end());
// Ensure non-unique exponents are not present. Since we sorted the list by
// exponent, a linear scan of adjancent monomials suffices.
if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(),
- [](const Monomial &lhs, const Monomial &rhs) {
- return lhs.exponent == rhs.exponent;
+ [](const MonomialT &lhs, const MonomialT &rhs) {
+ return lhs.getExponent() == rhs.getExponent();
}) != monomialsCopy.end()) {
return failure();
}
- return Polynomial(monomialsCopy);
+ return PolyT(monomialsCopy);
+}
+
+FailureOr<IntPolynomial>
+IntPolynomial::fromMonomials(ArrayRef<IntMonomial> monomials) {
+ return fromMonomialsImpl<IntPolynomial, IntMonomial>(monomials);
+}
+
+FailureOr<FloatPolynomial>
+FloatPolynomial::fromMonomials(ArrayRef<FloatMonomial> monomials) {
+ return fromMonomialsImpl<FloatPolynomial, FloatMonomial>(monomials);
}
-Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
- llvm::SmallVector<Monomial> monomials;
+template <typename PolyT, typename MonomialT, typename CoeffT>
+PolyT fromCoefficientsImpl(ArrayRef<CoeffT> coeffs) {
+ llvm::SmallVector<MonomialT> monomials;
auto size = coeffs.size();
monomials.reserve(size);
for (size_t i = 0; i < size; i++) {
monomials.emplace_back(coeffs[i], i);
}
- auto result = Polynomial::fromMonomials(monomials);
+ auto result = PolyT::fromMonomials(monomials);
// Construction guarantees unique exponents, so the failure mode of
// fromMonomials can be bypassed.
assert(succeeded(result));
return result.value();
}
-void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator,
- ::llvm::StringRef exponentiation) const {
- bool first = true;
- for (const Monomial &term : terms) {
- if (first) {
- first = false;
- } else {
- os << separator;
- }
- std::string coeffToPrint;
- if (term.coefficient == 1 && term.exponent.uge(1)) {
- coeffToPrint = "";
- } else {
- llvm::SmallString<16> coeffString;
- term.coefficient.toStringSigned(coeffString);
- coeffToPrint = coeffString.str();
- }
-
- if (term.exponent == 0) {
- os << coeffToPrint;
- } else if (term.exponent == 1) {
- os << coeffToPrint << "x";
- } else {
- llvm::SmallString<16> expString;
- term.exponent.toStringSigned(expString);
- os << coeffToPrint << "x" << exponentiation << expString;
- }
- }
-}
-
-void Polynomial::print(raw_ostream &os) const { print(os, " + ", "**"); }
-
-std::string Polynomial::toIdentifier() const {
- std::string result;
- llvm::raw_string_ostream os(result);
- print(os, "_", "");
- return os.str();
+IntPolynomial IntPolynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
+ return fromCoefficientsImpl<IntPolynomial, IntMonomial, int64_t>(coeffs);
}
-unsigned Polynomial::getDegree() const {
- return terms.back().exponent.getZExtValue();
+FloatPolynomial FloatPolynomial::fromCoefficients(ArrayRef<double> coeffs) {
+ return fromCoefficientsImpl<FloatPolynomial, FloatMonomial, double>(coeffs);
}
} // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 236bb78966352..890ce5226c30f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
@@ -17,22 +18,31 @@
namespace mlir {
namespace polynomial {
-void PolynomialAttr::print(AsmPrinter &p) const {
- p << '<';
- p << getPolynomial();
- p << '>';
+void IntPolynomialAttr::print(AsmPrinter &p) const {
+ p << '<' << getPolynomial() << '>';
}
+void FloatPolynomialAttr::print(AsmPrinter &p) const {
+ p << '<' << getPolynomial() << '>';
+}
+
+/// A callable that parses the coefficient using the appropriate method for the
+/// given monomial type, and stores the parsed coefficient value on the
+/// monomial.
+template <typename MonomialType>
+using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
+
/// Try to parse a monomial. If successful, populate the fields of the outparam
/// `monomial` with the results, and the `variable` outparam with the parsed
/// variable name. Sets shouldParseMore to true if the monomial is followed by
/// a '+'.
-ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
- llvm::StringRef &variable, bool &isConstantTerm,
- bool &shouldParseMore) {
- APInt parsedCoeff(apintBitWidth, 1);
- auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);
- monomial.coefficient = parsedCoeff;
+/// The coefficient is parsed and stored via `parseAndStoreCoefficient`.
+template <typename Monomial>
+ParseResult
+parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
+ bool &isConstantTerm, bool &shouldParseMore,
+ ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
+ OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);
isConstantTerm = false;
shouldParseMore = false;
@@ -44,7 +54,7 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
if (!parsedCoeffResult.has_value()) {
return failure();
}
- monomial.exponent = APInt(apintBitWidth, 0);
+ monomial.setExponent(APInt(apintBitWidth, 0));
isConstantTerm = true;
shouldParseMore = true;
return success();
@@ -58,7 +68,7 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
return failure();
}
- monomial.exponent = APInt(apintBitWidth, 0);
+ monomial.setExponent(APInt(apintBitWidth, 0));
isConstantTerm = true;
return success();
}
@@ -80,9 +90,9 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
return failure();
}
- monomial.exponent = parsedExponent;
+ monomial.setExponent(parsedExponent);
} else {
- monomial.exponent = APInt(apintBitWidth, 1);
+ monomial.setExponent(APInt(apintBitWidth, 1));
}
if (succeeded(parser.parseOptionalPlus())) {
@@ -91,22 +101,21 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
return success();
}
-Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
- if (failed(parser.parseLess()))
- return {};
-
- llvm::SmallVector<Monomial> monomials;
- llvm::StringSet<> variables;
-
+template <typename PolynoimalAttrTy, typename Monomial>
+LogicalResult
+parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
+ llvm::StringSet<> &variables,
+ ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
while (true) {
Monomial parsedMonomial;
llvm::StringRef parsedVariableRef;
bool isConstantTerm;
bool shouldParseMore;
- if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
- isConstantTerm, shouldParseMore))) {
+ if (failed(parseMonomial<Monomial>(
+ parser, parsedMonomial, parsedVariableRef, isConstantTerm,
+ shouldParseMore, parseAndStoreCoefficient))) {
parser.emitError(parser.getCurrentLocation(), "expected a monomial");
- return {};
+ return failure();
}
if (!isConstantTerm) {
@@ -124,7 +133,7 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
parser.emitError(
parser.getCurrentLocation(),
"expected + and more monomials, or > to end polynomial attribute");
- return {};
+ return failure();
}
if (variables.size() > 1) {
@@ -133,96 +142,67 @@ Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
parser.getCurrentLocation(),
"polynomials must have one indeterminate, but there were multiple: " +
vars);
+ return failure();
}
- auto result = Polynomial::fromMonomials(monomials);
- if (failed(result)) {
- parser.emitError(parser.getCurrentLocation())
- << "parsed polynomial must have unique exponents among monomials";
- return {};
- }
- return PolynomialAttr::get(parser.getContext(), result.value());
-}
-
-void RingAttr::print(AsmPrinter &p) const {
- p << "#polynomial.ring<coefficientType=" << getCoefficientType()
- << ", coefficientModulus=" << getCoefficientModulus()
- << ", polynomialModulus=" << getPolynomialModulus() << '>';
+ return success();
}
-Attribute RingAttr::parse(AsmParser &parser, Type type) {
+Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
- if (failed(parser.parseKeyword("coefficientType")))
- return {};
+ llvm::SmallVector<IntMonomial> monomials;
+ llvm::StringSet<> variables;
- if (failed(parser.parseEqual()))
+ if (failed(parsePolynomialAttr<IntPolynomialAttr, IntMonomial>(
+ parser, monomials, variables,
+ [&](IntMonomial &monomial) -> OptionalParseResult {
+ APInt parsedCoeff(apintBitWidth, 1);
+ OptionalParseResult result =
+ parser.parseOptionalInteger(parsedCoeff);
+ monomial.setCoefficient(parsedCoeff);
+ return result;
+ }))) {
return {};
+ }
- Type ty;
- if (failed(parser.parseType(ty)))
+ auto result = IntPolynomial::fromMonomials(monomials);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation())
+ << "parsed polynomial must have unique exponents among monomials";
return {};
+ }
+ return IntPolynomialAttr::get(parser.getContext(), result.value());
+}
- if (failed(parser.parseComma()))
+Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
+ if (failed(parser.parseLess()))
return {};
- IntegerAttr coefficientModulusAttr = nullptr;
- if (succeeded(parser.parseKeyword("coefficientModulus"))) {
- if (failed(parser.parseEqual()))
- return {};
-
- IntegerType iType = mlir::dyn_cast<IntegerType>(ty);
- if (!iType) {
- parser.emitError(parser.getCurrentLocation(),
- "coefficientType must specify an integer type");
- return {};
- }
- APInt coefficientModulus(iType.getWidth(), 0);
- auto result = parser.parseInteger(coefficientModulus);
- if (failed(result)) {
- parser.emitError(parser.getCurrentLocation(),
- "invalid coefficient modulus");
- return {};
- }
- coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);
-
- if (failed(parser.parseComma()))
- return {};
- }
-
- PolynomialAttr polyAttr = nullptr;
- if (succeeded(parser.parseKeyword("polynomialModulus"))) {
- if (failed(parser.parseEqual()))
- return {};
+ llvm::SmallVector<FloatMonomial> monomials;
+ llvm::StringSet<> variables;
- PolynomialAttr attr;
- if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
- return {};
- polyAttr = attr;
- }
+ ParseCoefficientFn<FloatMonomial> parseAndStoreCoefficient =
+ [&](FloatMonomial &monomial) -> OptionalParseResult {
+ double coeffValue = 1.0;
+ ParseResult result = parser.parseFloat(coeffValue);
+ monomial.setCoefficient(APFloat(coeffValue));
+ return OptionalParseResult(result);
+ };
- Polynomial poly = polyAttr.getPolynomial();
- APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
- IntegerAttr rootAttr = nullptr;
- if (succeeded(parser.parseOptionalComma())) {
- if (failed(parser.parseKeyword("primitiveRoot")) ||
- failed(parser.parseEqual()))
- return {};
-
- ParseResult result = parser.parseInteger(root);
- if (failed(result)) {
- parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
- return {};
- }
- rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+ if (failed(parsePolynomialAttr<FloatPolynomialAttr, FloatMonomial>(
+ parser, monomials, variables, parseAndStoreCoefficient))) {
+ return {};
}
- if (failed(parser.parseGreater()))
+ auto result = FloatPolynomial::fromMonomials(monomials);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation())
+ << "parsed polynomial must have unique exponents among monomials";
return {};
-
- return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
- polyAttr, rootAttr);
+ }
+ return FloatPolynomialAttr::get(parser.getContext(), result.value());
}
} // namespace polynomial
diff --git a/mlir/test/Dialect/Polynomial/attributes.mlir b/mlir/test/Dialect/Polynomial/attributes.mlir
index 3973ae3944335..4bdfd44fd4d15 100644
--- a/mlir/test/Dialect/Polynomial/attributes.mlir
+++ b/mlir/test/Dialect/Polynomial/attributes.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --split-input-file --verify-diagnostics
-#my_poly = #polynomial.polynomial<y + x**1024>
+#my_poly = #polynomial.int_polynomial<y + x**1024>
// expected-error at below {{polynomials must have one indeterminate, but there were multiple: x, y}}
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
@@ -9,37 +9,31 @@
// expected-error at below {{expected integer value}}
// expected-error at below {{expected a monomial}}
// expected-error at below {{found invalid integer exponent}}
-#my_poly = #polynomial.polynomial<5 + x**f>
+#my_poly = #polynomial.int_polynomial<5 + x**f>
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
// -----
-#my_poly = #polynomial.polynomial<5 + x**2 + 3x**2>
+#my_poly = #polynomial.int_polynomial<5 + x**2 + 3x**2>
// expected-error at below {{parsed polynomial must have unique exponents among monomials}}
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
// -----
// expected-error at below {{expected + and more monomials, or > to end polynomial attribute}}
-#my_poly = #polynomial.polynomial<5 + x**2 7>
+#my_poly = #polynomial.int_polynomial<5 + x**2 7>
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
// -----
// expected-error at below {{expected a monomial}}
-#my_poly = #polynomial.polynomial<5 + x**2 +>
+#my_poly = #polynomial.int_polynomial<5 + x**2 +>
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
// -----
-#my_poly = #polynomial.polynomial<5 + x**2>
-// expected-error at below {{coefficientType must specify an integer type}}
-#ring1 = #polynomial.ring<coefficientType=f64, coefficientModulus=2837465, polynomialModulus=#my_poly>
-
-// -----
-
-#my_poly = #polynomial.polynomial<5 + x**2>
-// expected-error at below {{expected integer value}}
-// expected-error at below {{invalid coefficient modulus}}
+#my_poly = #polynomial.int_polynomial<5 + x**2>
+// expected-error at below {{failed to parse Polynomial_RingAttr parameter 'coefficientModulus' which is to be a `::mlir::IntegerAttr`}}
+// expected-error at below {{expected attribute value}}
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=x, polynomialModulus=#my_poly>
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index a29cfc2e9cc54..ff709960c50e9 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -2,85 +2,87 @@
// This simply tests for syntax.
-#my_poly = #polynomial.polynomial<1 + x**1024>
-#my_poly_2 = #polynomial.polynomial<2>
-#my_poly_3 = #polynomial.polynomial<3x>
-#my_poly_4 = #polynomial.polynomial<t**3 + 4t + 2>
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#my_poly_2 = #polynomial.int_polynomial<2>
+#my_poly_3 = #polynomial.int_polynomial<3x>
+#my_poly_4 = #polynomial.int_polynomial<t**3 + 4t + 2>
#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
-#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
+#ring2 = #polynomial.ring<coefficientType=f32>
+#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
-#ideal = #polynomial.polynomial<-1 + x**1024>
+#ideal = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
-!poly_ty = !polynomial.polynomial<#ring>
+!poly_ty = !polynomial.polynomial<ring=#ring>
-#ntt_poly = #polynomial.polynomial<-1 + x**8>
+#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
-!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
module {
- func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
+ func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
%c0 = arith.constant 0 : index
%two = arith.constant 2 : i16
%five = arith.constant 5 : i16
%coeffs1 = tensor.from_elements %two, %two, %five : tensor<3xi16>
%coeffs2 = tensor.from_elements %five, %five, %two : tensor<3xi16>
- %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
- %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
+ %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<ring=#ring1>
+ %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<ring=#ring1>
- %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<#ring1>
+ %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<ring=#ring1>
- return %3 : !polynomial.polynomial<#ring1>
+ return %3 : !polynomial.polynomial<ring=#ring1>
}
- func.func @test_elementwise(%p0 : !polynomial.polynomial<#ring1>, %p1: !polynomial.polynomial<#ring1>) {
- %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<#ring1>>
- %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<#ring1>>
+ func.func @test_elementwise(%p0 : !polynomial.polynomial<ring=#ring1>, %p1: !polynomial.polynomial<ring=#ring1>) {
+ %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
+ %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<ring=#ring1>>
%c = arith.constant 2 : i32
- %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<#ring1>>, i32
+ %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<ring=#ring1>>, i32
- %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
- %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
- %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+ %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
+ %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
+ %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<ring=#ring1>>
return
}
- func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<#ring1>) {
+ func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<ring=#ring1>) {
%c0 = arith.constant 0 : index
%two = arith.constant 2 : i16
%coeffs1 = tensor.from_elements %two, %two : tensor<2xi16>
// CHECK: from_tensor
- %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<#ring1>
+ %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<ring=#ring1>
// CHECK: to_tensor
- %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<#ring1> -> tensor<1024xi16>
+ %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<ring=#ring1> -> tensor<1024xi16>
return
}
- func.func @test_degree(%p0 : !polynomial.polynomial<#ring1>) {
- %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<#ring1> -> (index, i32)
+ func.func @test_degree(%p0 : !polynomial.polynomial<ring=#ring1>) {
+ %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<ring=#ring1> -> (index, i32)
return
}
func.func @test_monomial() {
%deg = arith.constant 1023 : index
%five = arith.constant 5 : i16
- %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<#ring1>
+ %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<ring=#ring1>
return
}
func.func @test_monic_monomial_mul() {
%five = arith.constant 5 : index
- %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
- %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<#ring1>, index) -> !polynomial.polynomial<#ring1>
+ %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
+ %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
return
}
func.func @test_constant() {
- %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
- %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
+ %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
+ %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<ring=#ring1>
+ %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<ring=#ring2>
return
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index 2c20e7bcbf1d6..af8e4aa5da862 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt --split-input-file --verify-diagnostics %s
-#my_poly = #polynomial.polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<ring=#ring>
func.func @test_from_tensor_too_large_coeffs() {
%two = arith.constant 2 : i32
@@ -15,13 +15,13 @@ func.func @test_from_tensor_too_large_coeffs() {
// -----
-#my_poly = #polynomial.polynomial<1 + x**4>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<1 + x**4>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<ring=#ring>
func.func @test_from_tensor_wrong_tensor_type() {
%two = arith.constant 2 : i32
%coeffs1 = tensor.from_elements %two, %two, %two, %two, %two : tensor<5xi32>
- // expected-error@below {{input type 'tensor<5xi32>' does not match output type '!polynomial.polynomial<#polynomial.ring<coefficientType=i32, coefficientModulus=256 : i32, polynomialModulus=#polynomial.polynomial<1 + x**4>>>'}}
+ // expected-error@below {{input type 'tensor<5xi32>' does not match output type '!polynomial.polynomial<ring = <coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>'}}
// expected-note@below {{at most the degree of the polynomialModulus of the output type's ring attribute}}
%poly = polynomial.from_tensor %coeffs1 : tensor<5xi32> -> !ty
return
@@ -29,11 +29,11 @@ func.func @test_from_tensor_wrong_tensor_type() {
// -----
-#my_poly = #polynomial.polynomial<1 + x**4>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<1 + x**4>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<ring=#ring>
func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
- // expected-error@below {{input type '!polynomial.polynomial<#polynomial.ring<coefficientType=i32, coefficientModulus=256 : i32, polynomialModulus=#polynomial.polynomial<1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
+ // expected-error@below {{input type '!polynomial.polynomial<ring = <coefficientType = i32, coefficientModulus = 256 : i32, polynomialModulus = <1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
// expected-note@below {{at most the degree of the polynomialModulus of the input type's ring attribute}}
%tensor = polynomial.to_tensor %arg0 : !ty -> tensor<5xi32>
return
@@ -41,9 +41,9 @@ func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
// -----
-#my_poly = #polynomial.polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<ring=#ring>
func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
%scalar = arith.constant 2 : i32 // should be i16
@@ -54,9 +54,9 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
// -----
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
-!poly_ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
@@ -68,9 +68,9 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
// -----
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
-!poly_ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
@@ -82,10 +82,10 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
// -----
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
-!poly_ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
@@ -97,9 +97,9 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
// -----
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
-!poly_ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
@@ -112,9 +112,9 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
// -----
-#my_poly = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
-!poly_ty = !polynomial.polynomial<#ring>
+#my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
@@ -126,10 +126,10 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
// -----
-#my_poly = #polynomial.polynomial<-1 + x**8>
+#my_poly = #polynomial.int_polynomial<-1 + x**8>
// A valid root is 31
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
-!poly_ty = !polynomial.polynomial<#ring>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=32:i16>
+!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
diff --git a/mlir/test/Dialect/Polynomial/types.mlir b/mlir/test/Dialect/Polynomial/types.mlir
index 00296a36e890f..dcc5663ceb84c 100644
--- a/mlir/test/Dialect/Polynomial/types.mlir
+++ b/mlir/test/Dialect/Polynomial/types.mlir
@@ -2,13 +2,13 @@
// CHECK-LABEL: func @test_types
// CHECK-SAME: !polynomial.polynomial<
-// CHECK-SAME: #polynomial.ring<
-// CHECK-SAME: coefficientType=i32,
-// CHECK-SAME: coefficientModulus=2837465 : i32,
-// CHECK-SAME: polynomialModulus=#polynomial.polynomial<1 + x**1024>>>
-#my_poly = #polynomial.polynomial<1 + x**1024>
-#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
-!ty = !polynomial.polynomial<#ring1>
+// CHECK-SAME: ring = <
+// CHECK-SAME: coefficientType = i32,
+// CHECK-SAME: coefficientModulus = 2837465 : i32,
+// CHECK-SAME: polynomialModulus = <1 + x**1024>>>
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 2837465 : i32, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<ring=#ring1>
func.func @test_types(%0: !ty) -> !ty {
return %0 : !ty
}
@@ -16,13 +16,13 @@ func.func @test_types(%0: !ty) -> !ty {
// CHECK-LABEL: func @test_non_x_variable_64_bit
// CHECK-SAME: !polynomial.polynomial<
-// CHECK-SAME: #polynomial.ring<
-// CHECK-SAME: coefficientType=i64,
-// CHECK-SAME: coefficientModulus=2837465 : i64,
-// CHECK-SAME: polynomialModulus=#polynomial.polynomial<2 + 4x + x**3>>>
-#my_poly_2 = #polynomial.polynomial<t**3 + 4t + 2>
-#ring2 = #polynomial.ring<coefficientType=i64, coefficientModulus=2837465, polynomialModulus=#my_poly_2>
-!ty2 = !polynomial.polynomial<#ring2>
+// CHECK-SAME: ring = <
+// CHECK-SAME: coefficientType = i64,
+// CHECK-SAME: coefficientModulus = 2837465 : i64,
+// CHECK-SAME: polynomialModulus = <2 + 4x + x**3>>>
+#my_poly_2 = #polynomial.int_polynomial<t**3 + 4t + 2>
+#ring2 = #polynomial.ring<coefficientType = i64, coefficientModulus = 2837465 : i64, polynomialModulus=#my_poly_2>
+!ty2 = !polynomial.polynomial<ring=#ring2>
func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
return %0 : !ty2
}
@@ -30,27 +30,36 @@ func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
// CHECK-LABEL: func @test_linear_poly
// CHECK-SAME: !polynomial.polynomial<
-// CHECK-SAME: #polynomial.ring<
-// CHECK-SAME: coefficientType=i32,
-// CHECK-SAME: coefficientModulus=12 : i32,
-// CHECK-SAME: polynomialModulus=#polynomial.polynomial<4x>>
-#my_poly_3 = #polynomial.polynomial<4x>
-#ring3 = #polynomial.ring<coefficientType=i32, coefficientModulus=12, polynomialModulus=#my_poly_3>
-!ty3 = !polynomial.polynomial<#ring3>
+// CHECK-SAME: ring = <
+// CHECK-SAME: coefficientType = i32,
+// CHECK-SAME: coefficientModulus = 12 : i32,
+// CHECK-SAME: polynomialModulus = <4x>>
+#my_poly_3 = #polynomial.int_polynomial<4x>
+#ring3 = #polynomial.ring<coefficientType = i32, coefficientModulus=12 : i32, polynomialModulus=#my_poly_3>
+!ty3 = !polynomial.polynomial<ring=#ring3>
func.func @test_linear_poly(%0: !ty3) -> !ty3 {
return %0 : !ty3
}
// CHECK-LABEL: func @test_negative_leading_1
// CHECK-SAME: !polynomial.polynomial<
-// CHECK-SAME: #polynomial.ring<
-// CHECK-SAME: coefficientType=i32,
-// CHECK-SAME: coefficientModulus=2837465 : i32,
-// CHECK-SAME: polynomialModulus=#polynomial.polynomial<-1 + x**1024>>>
-#my_poly_4 = #polynomial.polynomial<-1 + x**1024>
-#ring4 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly_4>
-!ty4 = !polynomial.polynomial<#ring4>
+// CHECK-SAME: ring = <
+// CHECK-SAME: coefficientType = i32,
+// CHECK-SAME: coefficientModulus = 2837465 : i32,
+// CHECK-SAME: polynomialModulus = <-1 + x**1024>>>
+#my_poly_4 = #polynomial.int_polynomial<-1 + x**1024>
+#ring4 = #polynomial.ring<coefficientType = i32, coefficientModulus = 2837465 : i32, polynomialModulus=#my_poly_4>
+!ty4 = !polynomial.polynomial<ring=#ring4>
func.func @test_negative_leading_1(%0: !ty4) -> !ty4 {
return %0 : !ty4
}
+// CHECK-LABEL: func @test_float_coefficients
+// CHECK-SAME: !polynomial.polynomial<ring = <coefficientType = f32>>
+#my_poly_5 = #polynomial.float_polynomial<0.5 + 1.6e03 x**1024>
+#ring5 = #polynomial.ring<coefficientType=f32>
+!ty5 = !polynomial.polynomial<ring=#ring5>
+func.func @test_float_coefficients(%0: !ty5) -> !ty5 {
+ return %0 : !ty5
+}
+
More information about the Mlir-commits
mailing list