[Mlir-commits] [mlir] Add constant propagation for polynomial ops (PR #91655)

Jeremy Kun llvmlistbot at llvm.org
Tue May 14 12:43:30 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/91655

>From 85d9358fbbf766d49f5216dc3660404e279247cd Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Wed, 8 May 2024 13:20:04 -0700
Subject: [PATCH 1/3] implement add for polynomial data structure

- use CRTP for base classes
- Add unit test
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.h   | 69 ++++++++++++++-----
 mlir/unittests/Dialect/CMakeLists.txt         |  1 +
 .../Dialect/Polynomial/CMakeLists.txt         |  8 +++
 .../Dialect/Polynomial/PolynomialMathTest.cpp | 43 ++++++++++++
 4 files changed, 103 insertions(+), 18 deletions(-)
 create mode 100644 mlir/unittests/Dialect/Polynomial/CMakeLists.txt
 create mode 100644 mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 7f44c29a98707..47ca07c1d47c3 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -30,7 +30,7 @@ namespace polynomial {
 /// would want to specify 128-bit polynomials statically in the source code.
 constexpr unsigned apintBitWidth = 64;
 
-template <typename CoefficientType>
+template <class Derived, typename CoefficientType>
 class MonomialBase {
 public:
   MonomialBase(const CoefficientType &coeff, const APInt &expo)
@@ -55,12 +55,21 @@ class MonomialBase {
     return (exponent.ult(other.exponent));
   }
 
+  Derived add(const Derived &other) {
+    assert(exponent == other.exponent);
+    CoefficientType newCoeff = coefficient + other.coefficient;
+    Derived result;
+    result.setCoefficient(newCoeff);
+    result.setExponent(exponent);
+    return result;
+  }
+
   virtual bool isMonic() const = 0;
   virtual void
   coefficientToString(llvm::SmallString<16> &coeffString) const = 0;
 
-  template <typename T>
-  friend ::llvm::hash_code hash_value(const MonomialBase<T> &arg);
+  template <class D, typename T>
+  friend ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg);
 
 protected:
   CoefficientType coefficient;
@@ -69,7 +78,7 @@ class MonomialBase {
 
 /// A class representing a monomial of a single-variable polynomial with integer
 /// coefficients.
-class IntMonomial : public MonomialBase<APInt> {
+class IntMonomial : public MonomialBase<IntMonomial, APInt> {
 public:
   IntMonomial(int64_t coeff, uint64_t expo)
       : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
@@ -77,7 +86,7 @@ class IntMonomial : public MonomialBase<APInt> {
   IntMonomial()
       : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {}
 
-  ~IntMonomial() = default;
+  ~IntMonomial() override = default;
 
   bool isMonic() const override { return coefficient == 1; }
 
@@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase<APInt> {
 
 /// A class representing a monomial of a single-variable polynomial with integer
 /// coefficients.
-class FloatMonomial : public MonomialBase<APFloat> {
+class FloatMonomial : public MonomialBase<FloatMonomial, APFloat> {
 public:
   FloatMonomial(double coeff, uint64_t expo)
       : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
 
   FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {}
 
-  ~FloatMonomial() = default;
+  ~FloatMonomial() override = default;
 
   bool isMonic() const override { return coefficient == APFloat(1.0); }
 
@@ -104,7 +113,7 @@ class FloatMonomial : public MonomialBase<APFloat> {
   }
 };
 
-template <typename Monomial>
+template <class Derived, typename Monomial>
 class PolynomialBase {
 public:
   PolynomialBase() = delete;
@@ -149,6 +158,30 @@ class PolynomialBase {
     }
   }
 
+  Derived add(const Derived &other) {
+    SmallVector<Monomial> newTerms;
+    auto it1 = terms.begin();
+    auto it2 = other.terms.begin();
+    while (it1 != terms.end() || it2 != other.terms.end()) {
+      if (it1 == terms.end()) {
+        newTerms.emplace_back(*it2);
+        it2++;
+        continue;
+      }
+
+      if (it2 == other.terms.end()) {
+        newTerms.emplace_back(*it1);
+        it1++;
+        continue;
+      }
+
+      newTerms.emplace_back(it1->add(*it2));
+      it1++;
+      it2++;
+    }
+    return Derived(newTerms);
+  }
+
   // Prints polynomial to 'os'.
   void print(raw_ostream &os) const { print(os, " + ", "**"); }
 
@@ -168,8 +201,8 @@ class PolynomialBase {
 
   ArrayRef<Monomial> getTerms() const { return terms; }
 
-  template <typename T>
-  friend ::llvm::hash_code hash_value(const PolynomialBase<T> &arg);
+  template <class D, typename T>
+  friend ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg);
 
 private:
   // The monomial terms for this polynomial.
@@ -179,7 +212,7 @@ class PolynomialBase {
 /// A single-variable polynomial with integer coefficients.
 ///
 /// Eg: x^1024 + x + 1
-class IntPolynomial : public PolynomialBase<IntMonomial> {
+class IntPolynomial : public PolynomialBase<IntPolynomial, IntMonomial> {
 public:
   explicit IntPolynomial(ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
 
@@ -196,7 +229,7 @@ class IntPolynomial : public PolynomialBase<IntMonomial> {
 /// A single-variable polynomial with double coefficients.
 ///
 /// Eg: 1.0 x^1024 + 3.5 x + 1e-05
-class FloatPolynomial : public PolynomialBase<FloatMonomial> {
+class FloatPolynomial : public PolynomialBase<FloatPolynomial, FloatMonomial> {
 public:
   explicit FloatPolynomial(ArrayRef<FloatMonomial> terms)
       : PolynomialBase(terms) {}
@@ -212,20 +245,20 @@ class FloatPolynomial : public PolynomialBase<FloatMonomial> {
 };
 
 // Make Polynomials hashable.
-template <typename T>
-inline ::llvm::hash_code hash_value(const PolynomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const PolynomialBase<D, T> &arg) {
   return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end());
 }
 
-template <typename T>
-inline ::llvm::hash_code hash_value(const MonomialBase<T> &arg) {
+template <class D, typename T>
+inline ::llvm::hash_code hash_value(const MonomialBase<D, T> &arg) {
   return llvm::hash_combine(::llvm::hash_value(arg.coefficient),
                             ::llvm::hash_value(arg.exponent));
 }
 
-template <typename T>
+template <class D, typename T>
 inline raw_ostream &operator<<(raw_ostream &os,
-                               const PolynomialBase<T> &polynomial) {
+                               const PolynomialBase<D, T> &polynomial) {
   polynomial.print(os);
   return os;
 }
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 13393569f36fe..90a75d5a46ad9 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(Index)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
 add_subdirectory(OpenACC)
+add_subdirectory(Polynomial)
 add_subdirectory(SCF)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/Polynomial/CMakeLists.txt b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
new file mode 100644
index 0000000000000..807deeca41c06
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIRPolynomialTests
+  PolynomialMathTest.cpp
+)
+target_link_libraries(MLIRPolynomialTests
+  PRIVATE
+  MLIRIR
+  MLIRPolynomialDialect
+)
diff --git a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
new file mode 100644
index 0000000000000..485c2b64e4f21
--- /dev/null
+++ b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
@@ -0,0 +1,43 @@
+//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
+  IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3});
+  IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7});
+  EXPECT_EQ(expected, x.add(y));
+}
+
+TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
+  IntPolynomial x = IntPolynomial::fromCoefficients({1, 2});
+  IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 4});
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}
+
+TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5});
+  FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8});
+  EXPECT_EQ(expected, x.add(y));
+}
+
+TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) {
+  FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5});
+  FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5});
+  FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5});
+  EXPECT_EQ(expected, x.add(y));
+  EXPECT_EQ(expected, y.add(x));
+}

>From 528778e8605b02f5e64053809972f4fe77f433e7 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Wed, 8 May 2024 18:50:12 -0700
Subject: [PATCH 2/3] fold add via constBinaryFold

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.h   | 14 +++++
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 53 ++++++++++++++++++-
 .../Polynomial/IR/PolynomialDialect.cpp       | 14 +++++
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 38 +++++++++++++
 mlir/test/Dialect/Polynomial/folding.mlir     | 23 ++++++++
 .../Dialect/Polynomial/PolynomialMathTest.cpp |  5 +-
 6 files changed, 144 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/Polynomial/folding.mlir

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 47ca07c1d47c3..e14cef51185e0 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -175,6 +175,20 @@ class PolynomialBase {
         continue;
       }
 
+      // Terms are sorted by ascending exponent; emit the lower-degree term and
+      // advance only its iterator. Equal exponents fall through to add().
+      if (it1->getExponent().ult(it2->getExponent())) {
+        newTerms.emplace_back(*it1);
+        it1++;
+        continue;
+      }
+
+      if (it2->getExponent().ult(it1->getExponent())) {
+        newTerms.emplace_back(*it2);
+        it2++;
+        continue;
+      }
+
       newTerms.emplace_back(it1->add(*it2));
       it1++;
       it2++;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index ae8484501a50d..9b57a71ea6a80 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -53,6 +53,7 @@ def Polynomial_Dialect : Dialect {
 
   let useDefaultTypePrinterParser = 1;
   let useDefaultAttributePrinterParser = 1;
+  let hasConstantMaterializer = 1;
 }
 
 class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
@@ -83,6 +84,30 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
   let hasCustomAssemblyFormat = 1;
 }
 
+def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
+    "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
+  let summary = "A typed variant of int_polynomial for constant folding.";
+  let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const IntPolynomial &":$value), [{
+      return $_get(
+        type.getContext(),
+        type,
+        IntPolynomialAttr::get(type.getContext(), value));
+    }]>,
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const Attribute &":$value), [{
+      return $_get(type.getContext(), type, ::llvm::cast<IntPolynomialAttr>(value));
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    // used for constFoldBinaryOp
+    using ValueType = ::mlir::Attribute;
+  }];
+}
+
 def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
   let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
   let description = [{
@@ -105,6 +130,30 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
   let hasCustomAssemblyFormat = 1;
 }
 
+def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
+    "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
+  let summary = "A typed variant of float_polynomial for constant folding.";
+  let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const FloatPolynomial &":$value), [{
+      return $_get(
+        type.getContext(),
+        type,
+        FloatPolynomialAttr::get(type.getContext(), value));
+    }]>,
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const Attribute &":$value), [{
+      return $_get(type.getContext(), type, ::llvm::cast<FloatPolynomialAttr>(value));
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    // used for constFoldBinaryOp
+    using ValueType = ::mlir::Attribute;
+  }];
+}
+
 def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let summary = "An attribute specifying a polynomial ring.";
   let description = [{
@@ -221,6 +270,7 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
     %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
+  let hasFolder = 1;
 }
 
 def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
@@ -441,7 +491,7 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
 ]>;
 
 // Not deriving from Polynomial_Op due to need for custom assembly format
-def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
+def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure, ConstantLike]> {
   let summary = "Define a constant polynomial via an attribute.";
   let description = [{
     Example:
@@ -458,6 +508,7 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
   let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "attr-dict `:` type($output)";
+  let hasFolder = 1;
 }
 
 def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
index 825b80d70f803..05cc9fd8bbc58 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
@@ -48,3 +48,17 @@ void PolynomialDialect::initialize() {
 #include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
       >();
 }
+
+Operation *PolynomialDialect::materializeConstant(OpBuilder &builder,
+                                                  Attribute value, Type type,
+                                                  Location loc) {
+  auto intPoly = dyn_cast<TypedIntPolynomialAttr>(value);
+  auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(value);
+  if (!intPoly && !floatPoly)
+    return nullptr;
+
+  Type ty = intPoly ? intPoly.getType() : floatPoly.getType();
+  Attribute valueAttr =
+      intPoly ? (Attribute)intPoly.getValue() : (Attribute)floatPoly.getValue();
+  return builder.create<ConstantOp>(loc, ty, valueAttr);
+}
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 12010de348237..f63d67d8dae0a 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -7,10 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Support/LogicalResult.h"
@@ -19,6 +22,41 @@
 using namespace mlir;
 using namespace mlir::polynomial;
 
+OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) {
+  // Dispatch on the stored attribute rather than the ring's (optional)
+  // polynomial modulus; non-scalar (e.g. tensor) results are not folded here.
+  auto ty = dyn_cast<PolynomialType>(getOutput().getType());
+  if (!ty)
+    return {};
+  if (auto floatPoly = dyn_cast<FloatPolynomialAttr>(getValue()))
+    return TypedFloatPolynomialAttr::get(ty, floatPoly.getPolynomial());
+  if (auto intPoly = dyn_cast<IntPolynomialAttr>(getValue()))
+    return TypedIntPolynomialAttr::get(ty, intPoly.getPolynomial());
+  return {};
+}
+
+OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) {
+  auto lhsElements = dyn_cast<ShapedType>(getLhs().getType());
+  PolynomialType elementType = cast<PolynomialType>(
+      lhsElements ? lhsElements.getElementType() : getLhs().getType());
+  MLIRContext *context = getContext();
+
+  if (isa<FloatType>(elementType.getRing().getCoefficientType()))
+    return constFoldBinaryOp<TypedFloatPolynomialAttr>(
+      adaptor.getOperands(), elementType, [&](Attribute a, const Attribute &b) {
+        return FloatPolynomialAttr::get(
+            context, cast<FloatPolynomialAttr>(a).getPolynomial().add(
+                         cast<FloatPolynomialAttr>(b).getPolynomial()));
+      });
+
+  return constFoldBinaryOp<TypedIntPolynomialAttr>(
+      adaptor.getOperands(), elementType, [&](Attribute a, const Attribute &b) {
+        return IntPolynomialAttr::get(
+            context, cast<IntPolynomialAttr>(a).getPolynomial().add(
+                         cast<IntPolynomialAttr>(b).getPolynomial()));
+      });
+}
+
 void FromTensorOp::build(OpBuilder &builder, OperationState &result,
                          Value input, RingAttr ring) {
   TensorType tensorType = dyn_cast<TensorType>(input.getType());
diff --git a/mlir/test/Dialect/Polynomial/folding.mlir b/mlir/test/Dialect/Polynomial/folding.mlir
new file mode 100644
index 0000000000000..c1545a32376e9
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/folding.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt --sccp --canonicalize %s | FileCheck %s
+
+// Tests for folding
+
+#my_poly = #polynomial.int_polynomial<1 + x**1024>
+#poly_3t = #polynomial.int_polynomial<3t>
+#poly_t3_plus_4t_plus_2 = #polynomial.int_polynomial<t**3 + 4t + 2>
+#modulus = #polynomial.int_polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#modulus, primitiveRoot=193>
+!poly_ty = !polynomial.polynomial<ring=#ring>
+
+// CHECK-LABEL: test_fold_add
+// CHECK-NEXT: polynomial.constant {value = #polynomial.int_polynomial<2 + 7x + x**3>}
+// CHECK-NEXT: return
+func.func @test_fold_add() -> !poly_ty {
+  %0 = polynomial.constant {value=#poly_3t} : !poly_ty
+  %1 = polynomial.constant {value=#poly_t3_plus_4t_plus_2} : !poly_ty
+  %2 = polynomial.add %0, %1 : !poly_ty
+  return %2 : !poly_ty
+}
+
+// Test elementwise folding of add
+// Test float folding of add
diff --git a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
index 485c2b64e4f21..95906ad42588e 100644
--- a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
+++ b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp
@@ -20,9 +20,10 @@ TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) {
 }
 
 TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) {
-  IntPolynomial x = IntPolynomial::fromCoefficients({1, 2});
+  IntMonomial term2t = IntMonomial(2, 1);
+  IntPolynomial x = IntPolynomial::fromMonomials({term2t}).value();
   IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4});
-  IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 4});
+  IntPolynomial expected = IntPolynomial::fromCoefficients({2, 5, 4});
   EXPECT_EQ(expected, x.add(y));
   EXPECT_EQ(expected, y.add(x));
 }

>From 4c6c10dea5dfb4cdf0d6ca875b11b23bfadbfc36 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 14 May 2024 10:13:57 -0700
Subject: [PATCH 3/3] try constraints

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 27 +++++++++++--
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   |  1 -
 mlir/test/Dialect/Polynomial/folding.mlir     | 39 +++++++++++++++++--
 3 files changed, 58 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 9b57a71ea6a80..14186c563beb8 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -62,7 +62,7 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
 }
 
 def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
-  let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
+  let summary = "an attribute containing a single-variable polynomial with integer coefficients";
   let description = [{
     A polynomial attribute represents a single-variable polynomial with integer
     coefficients, which is used to define the modulus of a `RingAttr`, as well
@@ -109,7 +109,7 @@ def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
 }
 
 def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
-  let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
+  let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients";
   let description = [{
     A polynomial attribute represents a single-variable polynomial with double
     precision floating point coefficients.
@@ -489,6 +489,25 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
   Polynomial_FloatPolynomialAttr,
   Polynomial_IntPolynomialAttr
 ]>;
+def Polynomial_PolynomialElementsAttr :
+    ElementsAttrBase<And<[CPred<"::llvm::isa<::mlir::ElementsAttr>($_self)">,
+                          CPred<[{
+                              ::llvm::isa<::mlir::polynomial::PolynomialType>(
+                                  ::llvm::cast<::mlir::ElementsAttr>($_self)
+                                      .getShapedType()
+                                      .getElementType())
+                              }]>]>,
+                     "an elements attribute containing polynomial attributes"> {
+  let storageType = [{ ::mlir::ElementsAttr }];
+  let returnType = [{ ::mlir::ElementsAttr }];
+  let convertFromStorage = "$_self";
+}
+
+def Polynomial_PolynomialOrElementsAttr : AnyAttrOf<[
+  Polynomial_FloatPolynomialAttr,
+  Polynomial_IntPolynomialAttr,
+  Polynomial_PolynomialElementsAttr,
+]>;
 
 // Not deriving from Polynomial_Op due to need for custom assembly format
 def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure, ConstantLike]> {
@@ -505,8 +524,8 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure, ConstantLi
     %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
     ```
   }];
-  let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
-  let results = (outs Polynomial_PolynomialType:$output);
+  let arguments = (ins Polynomial_PolynomialOrElementsAttr:$value);
+  let results = (outs PolynomialLike:$output);
   let assemblyFormat = "attr-dict `:` type($output)";
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index f63d67d8dae0a..8cbc3b4615140 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
diff --git a/mlir/test/Dialect/Polynomial/folding.mlir b/mlir/test/Dialect/Polynomial/folding.mlir
index c1545a32376e9..3e52a108644ae 100644
--- a/mlir/test/Dialect/Polynomial/folding.mlir
+++ b/mlir/test/Dialect/Polynomial/folding.mlir
@@ -2,11 +2,9 @@
 
 // Tests for folding
 
-#my_poly = #polynomial.int_polynomial<1 + x**1024>
 #poly_3t = #polynomial.int_polynomial<3t>
 #poly_t3_plus_4t_plus_2 = #polynomial.int_polynomial<t**3 + 4t + 2>
-#modulus = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#modulus, primitiveRoot=193>
+#ring = #polynomial.ring<coefficientType=i32>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-LABEL: test_fold_add
@@ -19,5 +17,38 @@ func.func @test_fold_add() -> !poly_ty {
   return %2 : !poly_ty
 }
 
+// CHECK-LABEL: test_fold_add_elementwise
+// CHECK-NEXT: polynomial.constant {value = dense<
+// CHECK-SAME:  #polynomial.typed_int_polynomial<type=
+// CHECK-SAME:     value = <2 + 7x + x**3>>,
+// CHECK-SAME:  #polynomial.typed_int_polynomial<type=
+// CHECK-SAME:     value = <2 + 7x + x**3>>,
+// CHECK-SAME: ]>}
+// CHECK-NEXT: return
+#typed_poly1 = #polynomial.typed_int_polynomial<type=!poly_ty, value=<3t>>
+#typed_poly2 = #polynomial.typed_int_polynomial<type=!poly_ty, value=<t**3 + 4t + 2>>
+!tensor_ty = tensor<2x!poly_ty>
+func.func @test_fold_add_elementwise() -> !tensor_ty {
+  %0 = polynomial.constant {value=[#typed_poly1, #typed_poly2]} : !tensor_ty
+  %1 = polynomial.constant {value=[#typed_poly2, #typed_poly1]} : !tensor_ty
+  %2 = polynomial.add %0, %1 : !tensor_ty
+  return %2 : !tensor_ty
+}
+
+
+#fpoly_1 = #polynomial.float_polynomial<3.5t>
+#fpoly_2 = #polynomial.float_polynomial<1.0t**3 + 1.25t + 2.0>
+#fring = #polynomial.ring<coefficientType=f32>
+!fpoly_ty = !polynomial.polynomial<ring=#fring>
+
+// CHECK-LABEL: test_fold_add_float
+// CHECK-NEXT: polynomial.constant {value = #polynomial.float_polynomial<2 + 4.75x + x**3>}
+// CHECK-NEXT: return
+func.func @test_fold_add_float() -> !fpoly_ty {
+  %0 = polynomial.constant {value=#fpoly_1} : !fpoly_ty
+  %1 = polynomial.constant {value=#fpoly_2} : !fpoly_ty
+  %2 = polynomial.add %0, %1 : !fpoly_ty
+  return %2 : !fpoly_ty
+}
+
 // Test elementwise folding of add
-// Test float folding of add



More information about the Mlir-commits mailing list