[Mlir-commits] [mlir] 55b6f17 - Add a polynomial dialect shell, attributes, and types (#72081)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 15 09:12:03 PDT 2024


Author: Jeremy Kun
Date: 2024-04-15T09:12:00-07:00
New Revision: 55b6f17071d25b77fcdc910ca9b15f89305137e0

URL: https://github.com/llvm/llvm-project/commit/55b6f17071d25b77fcdc910ca9b15f89305137e0
DIFF: https://github.com/llvm/llvm-project/commit/55b6f17071d25b77fcdc910ca9b15f89305137e0.diff

LOG: Add a polynomial dialect shell, attributes, and types (#72081)

RFC:
https://discourse.llvm.org/t/rfc-a-poly-dialect-for-polynomial-arithmetic/73891

This PR implements the minimal work needed to represent the polynomial
type such that it can be tested with `lit`.

In this PR:

- Dialect shell
- `Polynomial` data structure needed for folding
- Polynomial attributes (`PolynomialAttr` and `RingAttr` which store a polynomial)
- `polynomial.polynomial` type
- Basic lit tests

---------

Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>

Added: 
    mlir/include/mlir/Dialect/Polynomial/CMakeLists.txt
    mlir/include/mlir/Dialect/Polynomial/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
    mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
    mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h
    mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h
    mlir/include/mlir/Dialect/Polynomial/IR/PolynomialOps.h
    mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h
    mlir/lib/Dialect/Polynomial/CMakeLists.txt
    mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
    mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
    mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
    mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
    mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
    mlir/test/Dialect/Polynomial/attributes.mlir
    mlir/test/Dialect/Polynomial/types.mlir

Modified: 
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 2da79011fa26a3..4bd7f12fabf7ba 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
+add_subdirectory(Polynomial)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)

diff  --git a/mlir/include/mlir/Dialect/Polynomial/CMakeLists.txt b/mlir/include/mlir/Dialect/Polynomial/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polynomial/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/include/mlir/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Polynomial/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..d8039deb5ee217
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect(Polynomial polynomial)
+
+# Polynomial.td holds the dialect, attribute, type, and op definitions, so a
+# single dialect doc covers all of them. `-dialect` selects this dialect among
+# the records brought in by the .td's includes.
+add_mlir_doc(Polynomial Polynomial Dialects/ -gen-dialect-doc -dialect=polynomial)
+
+# add_mlir_dialect generates the dialect, op, and type definitions but not
+# attribute definitions, so those are generated separately here.
+set(LLVM_TARGET_DEFINITIONS Polynomial.td)
+mlir_tablegen(PolynomialAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=polynomial)
+mlir_tablegen(PolynomialAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=polynomial)
+add_public_tablegen_target(MLIRPolynomialAttributesIncGen)

diff  --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
new file mode 100644
index 00000000000000..39b05b9d3ad14b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -0,0 +1,142 @@
+//===- Polynomial.h - A data class for polynomials --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
+#define MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SmallVector.h"
+#include <string>
+
+namespace mlir {
+
+class MLIRContext;
+
+namespace polynomial {
+
+/// This restricts statically defined polynomials to have at most 64-bit
+/// coefficients. This may be relaxed in the future, but it seems unlikely one
+/// would want to specify 128-bit polynomials statically in the source code.
+constexpr unsigned apintBitWidth = 64;
+
+/// A class representing a monomial of a single-variable polynomial with integer
+/// coefficients, i.e. a single term `coefficient * x**exponent`.
+class Monomial {
+public:
+  Monomial(int64_t coeff, uint64_t expo)
+      : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
+
+  Monomial(const APInt &coeff, const APInt &expo)
+      : coefficient(coeff), exponent(expo) {}
+
+  /// Constructs the monomial `0 * x**0`.
+  Monomial() : coefficient(apintBitWidth, 0), exponent(apintBitWidth, 0) {}
+
+  /// Monomials are equal when both coefficient and exponent are equal. The
+  /// APInt comparisons require both operands to use the same bit widths.
+  bool operator==(const Monomial &other) const {
+    return other.coefficient == coefficient && other.exponent == exponent;
+  }
+  bool operator!=(const Monomial &other) const { return !(*this == other); }
+
+  /// Monomials are ordered by exponent alone, compared as unsigned integers.
+  bool operator<(const Monomial &other) const {
+    return exponent.ult(other.exponent);
+  }
+
+  /// Prints this monomial to `os`.
+  void print(raw_ostream &os) const;
+
+  friend ::llvm::hash_code hash_value(const Monomial &arg);
+
+public:
+  /// The coefficient of the term, interpreted as a signed integer.
+  APInt coefficient;
+
+  /// The exponent of the indeterminate, always interpreted as unsigned.
+  APInt exponent;
+};
+
+/// A single-variable polynomial with integer coefficients.
+///
+/// Eg: x^1024 + x + 1
+///
+/// The symbols used as the polynomial's indeterminate don't matter, so long as
+/// it is used consistently throughout the polynomial.
+class Polynomial {
+public:
+  Polynomial() = delete;
+
+  /// Stores `terms` as given, without sorting them or checking for duplicate
+  /// exponents; `fromMonomials` performs both.
+  explicit Polynomial(ArrayRef<Monomial> terms) : terms(terms) {}
+
+  /// Returns a Polynomial from a list of monomials, with terms sorted by
+  /// increasing exponent. Fails if two monomials have the same exponent.
+  static FailureOr<Polynomial> fromMonomials(ArrayRef<Monomial> monomials);
+
+  /// Returns a polynomial with coefficients given by `coeffs`. The value
+  /// coeffs[i] is converted to a monomial with exponent i.
+  static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs);
+
+  /// Returns true if the polynomial has at least one term.
+  explicit operator bool() const { return !terms.empty(); }
+  bool operator==(const Polynomial &other) const {
+    return other.terms == terms;
+  }
+  bool operator!=(const Polynomial &other) const { return !(*this == other); }
+
+  /// Prints the polynomial to `os` with terms joined by " + " and `**` as the
+  /// exponentiation symbol, e.g. `1 + x**1024`.
+  void print(raw_ostream &os) const;
+  /// Prints the polynomial to `os`, joining terms with `separator` and
+  /// writing `exponentiation` between the indeterminate and its exponent.
+  void print(raw_ostream &os, ::llvm::StringRef separator,
+             ::llvm::StringRef exponentiation) const;
+  void dump() const;
+
+  /// Prints the polynomial in a form meant for use inside identifiers, e.g.
+  /// `1_x1024` for `1 + x**1024`. Note that negative coefficients still
+  /// print with a `-`.
+  std::string toIdentifier() const;
+
+  /// Returns the exponent of the last term, which is the degree when terms
+  /// are sorted by increasing exponent (as `fromMonomials` guarantees).
+  unsigned getDegree() const;
+
+  friend ::llvm::hash_code hash_value(const Polynomial &arg);
+
+private:
+  // The monomial terms for this polynomial.
+  SmallVector<Monomial> terms;
+};
+
+// Make Polynomial hashable.
+inline ::llvm::hash_code hash_value(const Polynomial &arg) {
+  return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
+}
+
+// Make Monomial hashable; also used when hashing a Polynomial's terms.
+inline ::llvm::hash_code hash_value(const Monomial &arg) {
+  return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
+                            ::llvm::hash_value(arg.exponent));
+}
+
+inline raw_ostream &operator<<(raw_ostream &os, const Polynomial &polynomial) {
+  polynomial.print(os);
+  return os;
+}
+
+} // namespace polynomial
+} // namespace mlir
+
+#endif // MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIAL_H_

diff  --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
new file mode 100644
index 00000000000000..5d8da8399b01b5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -0,0 +1,168 @@
+//===- Polynomial.td - Polynomial dialect ------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef POLYNOMIAL_OPS
+#define POLYNOMIAL_OPS
+
+include "mlir/IR/BuiltinAttributes.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+def Polynomial_Dialect : Dialect {
+  let name = "polynomial";
+  let cppNamespace = "::mlir::polynomial";
+  let description = [{
+    The Polynomial dialect defines single-variable polynomial types and
+    operations.
+
+    The simplest use of `polynomial` is to represent mathematical operations in
+    a polynomial ring `R[x]`, where `R` is another MLIR type like `i32`.
+
+    More generally, this dialect supports representing polynomial operations in a
+    quotient ring `R[x]/(f(x))` for some statically fixed polynomial `f(x)`.
+    Two polynomials `p(x), q(x)` are considered equal in this ring if they have the
+    same remainder when dividing by `f(x)`. When a modulus is given, ring operations
+    are performed with reductions modulo `f(x)` and relative to the coefficient ring
+    `R`.
+
+    Examples:
+
+    ```mlir
+    // A constant polynomial in a ring with i32 coefficients and no polynomial modulus
+    #ring1 = #polynomial.ring<coefficientType=i32>
+    %a = polynomial.constant <1 + x**2 + -3x**3> : !polynomial.polynomial<#ring1>
+
+    // A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
+    #modulus = #polynomial.polynomial<1 + x**1024>
+    #ring2 = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
+    %b = polynomial.constant <1 + x**2 + -3x**3> : !polynomial.polynomial<#ring2>
+
+    // A constant polynomial in a ring with i32 coefficients, with a polynomial
+    // modulus of (x^1024 + 1) and a coefficient modulus of 17.
+    #ring3 = #polynomial.ring<coefficientType=i32, coefficientModulus=17,
+                              polynomialModulus=#modulus>
+    %c = polynomial.constant <1 + x**2 + -3x**3> : !polynomial.polynomial<#ring3>
+    ```
+  }];
+
+  let useDefaultTypePrinterParser = 1;
+  let useDefaultAttributePrinterParser = 1;
+}
+
+class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<Polynomial_Dialect, name, traits> {
+  let mnemonic = attrMnemonic;
+}
+
+def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
+  let summary = "An attribute containing a single-variable polynomial.";
+  let description = [{
+    A polynomial is written as a sum of monomials separated by `+`, such as
+
+    ```mlir
+    #poly = #polynomial.polynomial<x**1024 + 1>
+    ```
+
+    Exponentiation is written `**`, and a negative coefficient is written as a
+    signed integer term, as in `1 + -3x**2`. All non-constant terms must use
+    the same indeterminate, and no two terms may have the same exponent.
+  }];
+  let parameters = (ins "Polynomial":$polynomial);
+  let hasCustomAssemblyFormat = 1;
+}
+
+def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
+  let summary = "An attribute specifying a polynomial ring.";
+  let description = [{
+    A ring describes the domain in which polynomial arithmetic occurs. The ring
+    attribute in `polynomial` represents the more specific case of polynomials
+    with a single indeterminate; whose coefficients can be represented by
+    another MLIR type (`coefficientType`); and, if the coefficient type is
+    integral, whose coefficients are taken modulo some statically known modulus
+    (`coefficientModulus`).
+
+    Additionally, a polynomial ring can specify an _ideal_, which converts
+    polynomial arithmetic to the analogue of modular integer arithmetic, where
+    each polynomial is represented as its remainder when dividing by the
+    modulus. For single-variable polynomials, an "ideal" is always specified
+    via a single polynomial, which we call `polynomialModulus`.
+
+    An expressive example is polynomials with i32 coefficients, whose
+    coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of
+    `x**1024 - 1`.
+
+    ```mlir
+    #poly_mod = #polynomial.polynomial<-1 + x**1024>
+    #ring = #polynomial.ring<coefficientType=i32,
+                             coefficientModulus=4294967291,
+                             polynomialModulus=#poly_mod>
+
+    %0 = ... : !polynomial.polynomial<#ring>
+    ```
+
+    In this case, the value of a polynomial is always "converted" to a
+    canonical form by applying repeated reductions by setting `x**1024 = 1`
+    and simplifying.
+
+    The coefficient and polynomial modulus parameters are optional, and the
+    coefficient modulus is only allowed if the coefficient type is integral.
+    When present, they appear in the order shown above.
+  }];
+
+  let parameters = (ins
+    "Type": $coefficientType,
+    OptionalParameter<"IntegerAttr">: $coefficientModulus,
+    OptionalParameter<"PolynomialAttr">: $polynomialModulus
+  );
+
+  let hasCustomAssemblyFormat = 1;
+}
+
+class Polynomial_Type<string name, string typeMnemonic>
+    : TypeDef<Polynomial_Dialect, name> {
+  let mnemonic = typeMnemonic;
+}
+
+def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> {
+  let summary = "An element of a polynomial ring.";
+
+  let description = [{
+    A type for polynomials in a polynomial quotient ring, parameterized by a
+    `#polynomial.ring` attribute:
+
+    ```mlir
+    #ring = #polynomial.ring<coefficientType=i32>
+    %0 = ... : !polynomial.polynomial<#ring>
+    ```
+  }];
+
+  let parameters = (ins Polynomial_RingAttr:$ring);
+  let assemblyFormat = "`<` $ring `>`";
+}
+
+class Polynomial_Op<string mnemonic, list<Trait> traits = []> :
+    Op<Polynomial_Dialect, mnemonic, traits # [Pure]>;
+
+class Polynomial_UnaryOp<string mnemonic, list<Trait> traits = []> :
+    Polynomial_Op<mnemonic, traits # [SameOperandsAndResultType]> {
+  let arguments = (ins Polynomial_PolynomialType:$operand);
+  let results = (outs Polynomial_PolynomialType:$result);
+
+  let assemblyFormat = "$operand attr-dict `:` qualified(type($result))";
+}
+
+class Polynomial_BinaryOp<string mnemonic, list<Trait> traits = []> :
+    Polynomial_Op<mnemonic, traits # [SameOperandsAndResultType]> {
+  let arguments = (ins Polynomial_PolynomialType:$lhs, Polynomial_PolynomialType:$rhs);
+  let results = (outs Polynomial_PolynomialType:$result);
+
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($result))";
+}
+
+#endif // POLYNOMIAL_OPS

diff  --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h
new file mode 100644
index 00000000000000..b37d17bb89fb2c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h
@@ -0,0 +1,17 @@
+//===- PolynomialAttributes.h - polynomial dialect attributes ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_
+#define MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_
+
+#include "Polynomial.h"
+#include "PolynomialDialect.h"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h.inc"
+
+#endif // MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_

diff  --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h
new file mode 100644
index 00000000000000..7b7acebe7a93bb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h
@@ -0,0 +1,19 @@
+//===- PolynomialDialect.h - The Polynomial dialect -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_
+#define MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+
+// Generated headers (block clang-format from messing up order)
+#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h.inc"
+
+#endif // MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_H_

diff  --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialOps.h b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialOps.h
new file mode 100644
index 00000000000000..bacaad81ce8e51
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialOps.h
@@ -0,0 +1,21 @@
+//===- PolynomialOps.h - Ops for the Polynomial dialect ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_
+#define MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_
+
+#include "PolynomialDialect.h"
+#include "PolynomialTypes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h.inc"
+
+#endif // MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALOPS_H_

diff  --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h
new file mode 100644
index 00000000000000..2fc68774525476
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h
@@ -0,0 +1,17 @@
+//===- PolynomialTypes.h - Types for the Polynomial dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALTYPES_H_
+#define MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALTYPES_H_
+
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h.inc"
+
+#endif // MLIR_INCLUDE_MLIR_DIALECT_POLYNOMIAL_IR_POLYNOMIALTYPES_H_

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c558dc53cc7fac..c4d788cf8ed316 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -61,6 +61,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.h"
 #include "mlir/Dialect/Quant/QuantOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
@@ -131,6 +132,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   omp::OpenMPDialect,
                   pdl::PDLDialect,
                   pdl_interp::PDLInterpDialect,
+                  polynomial::PolynomialDialect,
                   quant::QuantizationDialect,
                   ROCDL::ROCDLDialect,
                   scf::SCFDialect,

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b1ba5a3bc8817d..a324ce7f9b19f7 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
+add_subdirectory(Polynomial)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
 add_subdirectory(Shape)

diff  --git a/mlir/lib/Dialect/Polynomial/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..7f5b3255d5d900
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
@@ -0,0 +1,22 @@
+# The polynomial dialect library: the Polynomial data class, its attributes,
+# and the dialect and op registration sources listed below.
+add_mlir_dialect_library(MLIRPolynomialDialect
+  Polynomial.cpp
+  PolynomialAttributes.cpp
+  PolynomialDialect.cpp
+  PolynomialOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+    ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Polynomial
+
+  # Tablegen targets that must run before these sources are compiled.
+  DEPENDS
+    MLIRPolynomialIncGen
+    MLIRPolynomialAttributesIncGen
+    MLIRBuiltinAttributesIncGen
+
+  LINK_LIBS PUBLIC
+    MLIRSupport
+    MLIRDialect
+    MLIRIR
+)

diff  --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
new file mode 100644
index 00000000000000..5916ffba78e246
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
@@ -0,0 +1,111 @@
+//===- Polynomial.cpp - MLIR storage type for static Polynomial -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <string>
+
+namespace mlir {
+namespace polynomial {
+
+/// Prints a single term to `os`, writing `exponentiation` between the
+/// indeterminate `x` and the exponent. A coefficient of 1 is omitted on
+/// non-constant terms and an exponent of 1 is not written, so `1 * x**1`
+/// prints as `x`. Coefficients print as signed integers and exponents as
+/// unsigned integers.
+static void printMonomial(raw_ostream &os, const Monomial &term,
+                          StringRef exponentiation) {
+  if (term.coefficient != 1 || term.exponent == 0)
+    term.coefficient.print(os, /*isSigned=*/true);
+
+  // A constant term consists of its coefficient alone.
+  if (term.exponent == 0)
+    return;
+
+  os << "x";
+  if (term.exponent != 1) {
+    os << exponentiation;
+    term.exponent.print(os, /*isSigned=*/false);
+  }
+}
+
+void Monomial::print(raw_ostream &os) const { printMonomial(os, *this, "**"); }
+
+FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) {
+  // A polynomial's terms are canonically stored in order of increasing degree.
+  auto monomialsCopy = llvm::SmallVector<Monomial>(monomials);
+  std::sort(monomialsCopy.begin(), monomialsCopy.end());
+
+  // Ensure non-unique exponents are not present. Since we sorted the list by
+  // exponent, a linear scan of adjacent monomials suffices.
+  if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(),
+                         [](const Monomial &lhs, const Monomial &rhs) {
+                           return lhs.exponent == rhs.exponent;
+                         }) != monomialsCopy.end()) {
+    return failure();
+  }
+
+  return Polynomial(monomialsCopy);
+}
+
+Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
+  llvm::SmallVector<Monomial> monomials;
+  auto size = coeffs.size();
+  monomials.reserve(size);
+  for (size_t i = 0; i < size; i++) {
+    monomials.emplace_back(coeffs[i], i);
+  }
+  auto result = Polynomial::fromMonomials(monomials);
+  // Construction guarantees unique exponents, so the failure mode of
+  // fromMonomials can be bypassed.
+  assert(succeeded(result));
+  return result.value();
+}
+
+void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator,
+                       ::llvm::StringRef exponentiation) const {
+  llvm::interleave(
+      terms, os,
+      [&](const Monomial &term) { printMonomial(os, term, exponentiation); },
+      separator);
+}
+
+void Polynomial::print(raw_ostream &os) const { print(os, " + ", "**"); }
+
+void Polynomial::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+
+std::string Polynomial::toIdentifier() const {
+  std::string result;
+  llvm::raw_string_ostream os(result);
+  print(os, "_", "");
+  return os.str();
+}
+
+unsigned Polynomial::getDegree() const {
+  // A polynomial with no terms has no leading term to read; report degree 0
+  // instead of reading past the end of `terms`.
+  if (terms.empty())
+    return 0;
+  return terms.back().exponent.getZExtValue();
+}
+
+} // namespace polynomial
+} // namespace mlir

diff  --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
new file mode 100644
index 00000000000000..ee09c73bb3c4ae
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -0,0 +1,283 @@
+//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/SMLoc.h"
+#include <string>
+
+namespace mlir {
+namespace polynomial {
+
+void PolynomialAttr::print(AsmPrinter &p) const {
+  p << '<';
+  p << getPolynomial();
+  p << '>';
+}
+
+/// Try to parse a monomial. If successful, populate the fields of the outparam
+/// `monomial` with the results, and the `variable` outparam with the parsed
+/// variable name. Sets `isConstantTerm` to true if the monomial has no
+/// variable, and `shouldParseMore` to true if the monomial is followed by a
+/// '+'.
+///
+/// The coefficient must fit in an `apintBitWidth`-bit signed integer and the
+/// exponent in an `apintBitWidth`-bit unsigned integer. Both are stored with
+/// exactly `apintBitWidth` bits, because the parser may return wider APInts
+/// for long literals and APInt comparisons require matching bit widths.
+static ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
+                                 llvm::StringRef &variable,
+                                 bool &isConstantTerm, bool &shouldParseMore) {
+  isConstantTerm = false;
+  shouldParseMore = false;
+
+  llvm::SMLoc coeffLoc = parser.getCurrentLocation();
+  APInt parsedCoeff(apintBitWidth, 1);
+  auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);
+  // An integer was started but is malformed (e.g. a `-` without digits); the
+  // parser has already reported the error.
+  if (parsedCoeffResult.has_value() && failed(*parsedCoeffResult))
+    return failure();
+  if (!parsedCoeff.isSignedIntN(apintBitWidth))
+    return parser.emitError(coeffLoc)
+           << "coefficient does not fit in a " << apintBitWidth
+           << "-bit signed integer";
+  monomial.coefficient = parsedCoeff.sextOrTrunc(apintBitWidth);
+
+  // A + indicates it's a constant term with more to go, as in `1 + x`.
+  if (succeeded(parser.parseOptionalPlus())) {
+    // If no coefficient was parsed, and there's a +, then it's effectively
+    // parsing an empty string.
+    if (!parsedCoeffResult.has_value())
+      return failure();
+    monomial.exponent = APInt(apintBitWidth, 0);
+    isConstantTerm = true;
+    shouldParseMore = true;
+    return success();
+  }
+
+  // A monomial can be a trailing constant term, as in `x + 1`.
+  if (failed(parser.parseOptionalKeyword(&variable))) {
+    // If neither a coefficient nor a variable was found, then it's effectively
+    // parsing an empty string.
+    if (!parsedCoeffResult.has_value())
+      return failure();
+
+    monomial.exponent = APInt(apintBitWidth, 0);
+    isConstantTerm = true;
+    return success();
+  }
+
+  // Parse exponentiation symbol as `**`. We can't use caret because it's
+  // reserved for basic block identifiers. If no star is present, it's treated
+  // as a monomial with exponent 1.
+  if (succeeded(parser.parseOptionalStar())) {
+    // If there's one * there must be two.
+    if (failed(parser.parseStar()))
+      return failure();
+
+    // If there's a **, then the integer exponent is required.
+    llvm::SMLoc expLoc = parser.getCurrentLocation();
+    APInt parsedExponent(apintBitWidth, 0);
+    if (failed(parser.parseInteger(parsedExponent))) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "found invalid integer exponent");
+      return failure();
+    }
+
+    // The integer parser widens non-negative literals so their top bit is
+    // clear; a set sign bit therefore means a `-` was written.
+    if (parsedExponent.isNegative())
+      return parser.emitError(expLoc, "exponent must be non-negative");
+    if (!parsedExponent.isIntN(apintBitWidth))
+      return parser.emitError(expLoc)
+             << "exponent does not fit in a " << apintBitWidth
+             << "-bit unsigned integer";
+    monomial.exponent = parsedExponent.zextOrTrunc(apintBitWidth);
+  } else {
+    monomial.exponent = APInt(apintBitWidth, 1);
+  }
+
+  if (succeeded(parser.parseOptionalPlus()))
+    shouldParseMore = true;
+  return success();
+}
+
+Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  llvm::SmallVector<Monomial> monomials;
+  llvm::StringSet<> variables;
+
+  while (true) {
+    Monomial parsedMonomial;
+    llvm::StringRef parsedVariableRef;
+    bool isConstantTerm = false;
+    bool shouldParseMore = false;
+    if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
+                             isConstantTerm, shouldParseMore))) {
+      parser.emitError(parser.getCurrentLocation(), "expected a monomial");
+      return {};
+    }
+
+    if (!isConstantTerm)
+      variables.insert(parsedVariableRef);
+    monomials.push_back(parsedMonomial);
+
+    if (shouldParseMore)
+      continue;
+
+    if (succeeded(parser.parseOptionalGreater()))
+      break;
+    parser.emitError(
+        parser.getCurrentLocation(),
+        "expected + and more monomials, or > to end polynomial attribute");
+    return {};
+  }
+
+  if (variables.size() > 1) {
+    // Sort the names so the diagnostic does not depend on hash order.
+    llvm::SmallVector<llvm::StringRef> names(variables.keys());
+    llvm::sort(names);
+    std::string vars = llvm::join(names, ", ");
+    parser.emitError(
+        parser.getCurrentLocation(),
+        "polynomials must have one indeterminate, but there were multiple: " +
+            vars);
+    return {};
+  }
+
+  auto result = Polynomial::fromMonomials(monomials);
+  if (failed(result)) {
+    parser.emitError(parser.getCurrentLocation())
+        << "parsed polynomial must have unique exponents among monomials";
+    return {};
+  }
+  return PolynomialAttr::get(parser.getContext(), result.value());
+}
+
+/// Prints the ring, writing only the optional parameters that are present so
+/// that the output can be read back by RingAttr::parse.
+///
+/// NOTE(review): the `#polynomial.ring` prefix is written explicitly,
+/// presumably because the ring is printed in stripped form inside
+/// `!polynomial.polynomial<...>`. Where the dialect printer already emits the
+/// mnemonic, this prefix may be duplicated; confirm before relying on it.
+void RingAttr::print(AsmPrinter &p) const {
+  p << "#polynomial.ring<coefficientType=" << getCoefficientType();
+  if (IntegerAttr coefficientModulus = getCoefficientModulus())
+    p << ", coefficientModulus=" << coefficientModulus;
+  if (PolynomialAttr polynomialModulus = getPolynomialModulus())
+    p << ", polynomialModulus=" << polynomialModulus;
+  p << '>';
+}
+
+/// Parses
+///   `<coefficientType=T [, coefficientModulus=N [: T]] [, polynomialModulus=P]>`
+/// where the coefficient modulus is only accepted when T is an integer type.
+Attribute RingAttr::parse(AsmParser &parser, Type type) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  if (failed(parser.parseKeyword("coefficientType")))
+    return {};
+
+  if (failed(parser.parseEqual()))
+    return {};
+
+  Type ty;
+  if (failed(parser.parseType(ty)))
+    return {};
+
+  // Both remaining parameters are optional, appear in a fixed order, and are
+  // each introduced by a comma. The non-emitting `parseOptional*` forms keep
+  // an absent parameter from being reported as an error.
+  bool hasMoreParams = succeeded(parser.parseOptionalComma());
+
+  IntegerAttr coefficientModulusAttr = nullptr;
+  if (hasMoreParams &&
+      succeeded(parser.parseOptionalKeyword("coefficientModulus"))) {
+    if (failed(parser.parseEqual()))
+      return {};
+
+    IntegerType iType = dyn_cast<IntegerType>(ty);
+    if (!iType) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "coefficientType must specify an integer type");
+      return {};
+    }
+    llvm::SMLoc modulusLoc = parser.getCurrentLocation();
+    APInt coefficientModulus(iType.getWidth(), 0);
+    auto result = parser.parseInteger(coefficientModulus);
+    if (failed(result)) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "invalid coefficient modulus");
+      return {};
+    }
+    // The parser may return an APInt wider than requested (e.g. for long
+    // literals), while IntegerAttr::get requires the width to match the type.
+    // Accept values representable in the type's width and resize to it.
+    unsigned width = iType.getWidth();
+    if (!coefficientModulus.isIntN(width) &&
+        !coefficientModulus.isSignedIntN(width)) {
+      parser.emitError(modulusLoc,
+                       "coefficient modulus does not fit in coefficientType");
+      return {};
+    }
+    coefficientModulus = coefficientModulus.zextOrTrunc(width);
+
+    // Accept an optional `: type` suffix, as written by RingAttr::print.
+    if (succeeded(parser.parseOptionalColon())) {
+      llvm::SMLoc typeLoc = parser.getCurrentLocation();
+      Type modulusType;
+      if (failed(parser.parseType(modulusType)))
+        return {};
+      if (modulusType != ty) {
+        parser.emitError(typeLoc,
+                         "coefficient modulus type must match coefficientType");
+        return {};
+      }
+    }
+    coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);
+    hasMoreParams = succeeded(parser.parseOptionalComma());
+  }
+
+  PolynomialAttr polyAttr = nullptr;
+  if (hasMoreParams) {
+    if (failed(parser.parseKeyword("polynomialModulus")))
+      return {};
+    if (failed(parser.parseEqual()))
+      return {};
+
+    PolynomialAttr attr;
+    if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
+      return {};
+    polyAttr = attr;
+  }
+
+  if (failed(parser.parseGreater()))
+    return {};
+
+  return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
+                       polyAttr);
+}
+
+} // namespace polynomial
+} // namespace mlir

diff  --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
new file mode 100644
index 00000000000000..a672a59b8a465d
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
@@ -0,0 +1,41 @@
+//===- PolynomialDialect.cpp - Polynomial dialect ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.cpp.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc"
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
+
+/// Registers the dialect's attributes, types and operations with the dialect.
+/// Each class list is pulled from a generated `.cpp.inc` file (tablegen output,
+/// per the usual MLIR convention); the GET_*_LIST macro defined right before
+/// each #include selects the class-list section of that file.
+void PolynomialDialect::initialize() {
+  // Attribute classes from PolynomialAttributes.cpp.inc.
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.cpp.inc"
+      >();
+  // Type classes from PolynomialTypes.cpp.inc.
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc"
+      >();
+  // Operation classes from Polynomial.cpp.inc.
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
+      >();
+}

diff  --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
new file mode 100644
index 00000000000000..96c59a28b8fdce
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -0,0 +1,15 @@
+//===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+// Op classes are defined once, in PolynomialDialect.cpp (GET_OP_CLASSES);
+// defining them here too would duplicate every op's symbols at link time.

diff  --git a/mlir/test/Dialect/Polynomial/attributes.mlir b/mlir/test/Dialect/Polynomial/attributes.mlir
new file mode 100644
index 00000000000000..3973ae3944335e
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/attributes.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+#my_poly = #polynomial.polynomial<y + x**1024>
+// expected-error at below {{polynomials must have one indeterminate, but there were multiple: x, y}}
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+
+// -----
+
+// expected-error at below {{expected integer value}}
+// expected-error at below {{expected a monomial}}
+// expected-error at below {{found invalid integer exponent}}
+#my_poly = #polynomial.polynomial<5 + x**f>
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+
+// -----
+
+#my_poly = #polynomial.polynomial<5 + x**2 + 3x**2>
+// expected-error at below {{parsed polynomial must have unique exponents among monomials}}
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+
+// -----
+
+// expected-error at below {{expected + and more monomials, or > to end polynomial attribute}}
+#my_poly = #polynomial.polynomial<5 + x**2 7>
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+
+// -----
+
+// expected-error at below {{expected a monomial}}
+#my_poly = #polynomial.polynomial<5 + x**2 +>
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+
+
+// -----
+
+#my_poly = #polynomial.polynomial<5 + x**2>
+// expected-error at below {{coefficientType must specify an integer type}}
+#ring1 = #polynomial.ring<coefficientType=f64, coefficientModulus=2837465, polynomialModulus=#my_poly>
+
+// -----
+
+#my_poly = #polynomial.polynomial<5 + x**2>
+// expected-error at below {{expected integer value}}
+// expected-error at below {{invalid coefficient modulus}}
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=x, polynomialModulus=#my_poly>

diff  --git a/mlir/test/Dialect/Polynomial/types.mlir b/mlir/test/Dialect/Polynomial/types.mlir
new file mode 100644
index 00000000000000..64b74d9d36bb1c
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/types.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: func @test_types
+// CHECK-SAME:  !polynomial.polynomial<
+// CHECK-SAME:    #polynomial.ring<
+// CHECK-SAME:       coefficientType=i32,
+// CHECK-SAME:       coefficientModulus=2837465 : i32,
+// CHECK-SAME:       polynomialModulus=#polynomial.polynomial<1 + x**1024>>>
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring1>
+func.func @test_types(%0: !ty) -> !ty {
+  return %0 : !ty
+}
+
+
+// CHECK-LABEL: func @test_non_x_variable_64_bit
+// CHECK-SAME:  !polynomial.polynomial<
+// CHECK-SAME:    #polynomial.ring<
+// CHECK-SAME:       coefficientType=i64,
+// CHECK-SAME:       coefficientModulus=2837465 : i64,
+// CHECK-SAME:       polynomialModulus=#polynomial.polynomial<2 + 4x + x**3>>>
+#my_poly_2 = #polynomial.polynomial<t**3 + 4t + 2>
+#ring2 = #polynomial.ring<coefficientType=i64, coefficientModulus=2837465, polynomialModulus=#my_poly_2>
+!ty2 = !polynomial.polynomial<#ring2>
+func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
+  return %0 : !ty2
+}
+
+
+// CHECK-LABEL: func @test_linear_poly
+// CHECK-SAME:  !polynomial.polynomial<
+// CHECK-SAME:    #polynomial.ring<
+// CHECK-SAME:       coefficientType=i32,
+// CHECK-SAME:       coefficientModulus=12 : i32,
+// CHECK-SAME:       polynomialModulus=#polynomial.polynomial<4x>>>
+#my_poly_3 = #polynomial.polynomial<4x>
+#ring3 = #polynomial.ring<coefficientType=i32, coefficientModulus=12, polynomialModulus=#my_poly_3>
+!ty3 = !polynomial.polynomial<#ring3>
+func.func @test_linear_poly(%0: !ty3) -> !ty3 {
+  return %0 : !ty3
+}


        


More information about the Mlir-commits mailing list