[Mlir-commits] [mlir] [mlir][polynomial] use typed attributes for polynomial.constant op (PR #92818)

Jeremy Kun llvmlistbot at llvm.org
Tue May 21 15:59:54 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/92818

>From 1203b90c4ba7bfa79ab2fefe81ae7f05e96bde1c Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 20 May 2024 13:48:08 -0700
Subject: [PATCH 1/8] [mlir][polynomial] fix polynomial.constant syntax in
 docstrings

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 20 +++++++++----------
 .../Polynomial/IR/PolynomialDialect.td        |  6 +++---
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 3ef899d3376b1..e03d2ec81e9c8 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
     // add two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
     %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
     // subtract two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
     %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
     // multiply two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
     %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
     // multiply two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
     %1 = arith.constant 3 : i32
     %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
     ```
@@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
     ```mlir
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
     %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
     ```
   }];
@@ -286,10 +286,10 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
     ```mlir
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
 
     #float_ring = #polynomial.ring<coefficientType=f32>
-    %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
+    %0 = polynomial.constant {value=#polynomial.float_polynomial<0.5 + 1.3e06 x**2>} : !polynomial.polynomial<#float_ring>
     ```
   }];
   let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
index b0573b3715f78..73783815781cf 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
@@ -33,18 +33,18 @@ def Polynomial_Dialect : Dialect {
     ```mlir
     // A constant polynomial in a ring with i32 coefficients and no polynomial modulus
     #ring = #polynomial.ring<coefficientType=i32>
-    %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
+    %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
 
     // A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
     #modulus = #polynomial.int_polynomial<1 + x**1024>
     #ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
-    %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
+    %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
 
     // A constant polynomial in a ring with i32 coefficients, with a polynomial
     // modulus of (x^1024 + 1) and a coefficient modulus of 17.
     #modulus = #polynomial.int_polynomial<1 + x**1024>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#modulus>
-    %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
+    %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
     ```
   }];
 

>From 29b42317e7b3800ffe4a3b4b0f89f499ce4e51e7 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 20 May 2024 15:53:26 -0700
Subject: [PATCH 2/8] Revert "[mlir][polynomial] fix polynomial.constant syntax
 in docstrings"

This reverts commit 1203b90c4ba7bfa79ab2fefe81ae7f05e96bde1c.
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 20 +++++++++----------
 .../Polynomial/IR/PolynomialDialect.td        |  6 +++---
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index e03d2ec81e9c8..3ef899d3376b1 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
     // add two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
     // subtract two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
     // multiply two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant {value=#polynomial.int_polynomial<x**5 - x + 1>} : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
     // multiply two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1 = arith.constant 3 : i32
     %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
     ```
@@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
     ```mlir
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
     ```
   }];
@@ -286,10 +286,10 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
     ```mlir
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
 
     #float_ring = #polynomial.ring<coefficientType=f32>
-    %0 = polynomial.constant {value=#polynomial.float_polynomial<0.5 + 1.3e06 x**2>} : !polynomial.polynomial<#float_ring>
+    %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
     ```
   }];
   let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
index 73783815781cf..b0573b3715f78 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.td
@@ -33,18 +33,18 @@ def Polynomial_Dialect : Dialect {
     ```mlir
     // A constant polynomial in a ring with i32 coefficients and no polynomial modulus
     #ring = #polynomial.ring<coefficientType=i32>
-    %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
+    %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
 
     // A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
     #modulus = #polynomial.int_polynomial<1 + x**1024>
     #ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
-    %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
+    %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
 
     // A constant polynomial in a ring with i32 coefficients, with a polynomial
     // modulus of (x^1024 + 1) and a coefficient modulus of 17.
     #modulus = #polynomial.int_polynomial<1 + x**1024>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#modulus>
-    %a = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2 - 3x**3>} : polynomial.polynomial<#ring>
+    %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
     ```
   }];
 

>From f6276fe2d81a883676cfc24de152c0f16afdc16d Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 20 May 2024 17:01:42 -0700
Subject: [PATCH 3/8] add typed variants for polynomial.constant op

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 13 ++---
 .../Polynomial/IR/PolynomialAttributes.td     | 54 +++++++++++++++++--
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 15 ++++++
 mlir/test/Dialect/Polynomial/ops.mlir         |  8 +--
 4 files changed, 77 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 3ef899d3376b1..85a9dd6b935d2 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -272,13 +272,14 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
   let hasVerifier = 1;
 }
 
-def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
-  Polynomial_FloatPolynomialAttr,
-  Polynomial_IntPolynomialAttr
+def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
+  Polynomial_TypedFloatPolynomialAttr,
+  Polynomial_TypedIntPolynomialAttr
 ]>;
 
 // Not deriving from Polynomial_Op due to need for custom assembly format
-def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
+def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
+    [Pure, InferTypeOpAdaptor]> {
   let summary = "Define a constant polynomial via an attribute.";
   let description = [{
     Example:
@@ -292,9 +293,9 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
     %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
     ```
   }];
-  let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
+  let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value);
   let results = (outs Polynomial_PolynomialType:$output);
-  let assemblyFormat = "attr-dict `:` type($output)";
+  let assemblyFormat = "attr-dict $value";
 }
 
 def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index e5dbfa7fa21ee..1ea07e21e0076 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -18,7 +18,7 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
 }
 
 def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
-  let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
+  let summary = "an attribute containing a single-variable polynomial with integer coefficients";
   let description = [{
     A polynomial attribute represents a single-variable polynomial with integer
     coefficients, which is used to define the modulus of a `RingAttr`, as well
@@ -41,7 +41,7 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
 }
 
 def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
-  let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
+  let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients";
   let description = [{
     A polynomial attribute represents a single-variable polynomial with double
     precision floating point coefficients.
@@ -62,8 +62,56 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
   let hasCustomAssemblyFormat = 1;
 }
 
+def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
+    "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
+  let summary = "a typed int_polynomial";
+  let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value);
+  let assemblyFormat = "$value `:` $type";
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const IntPolynomial &":$value), [{
+      return $_get(
+        type.getContext(),
+        type,
+        IntPolynomialAttr::get(type.getContext(), value));
+    }]>,
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const Attribute &":$value), [{
+      return $_get(type.getContext(), type, ::llvm::cast<IntPolynomialAttr>(value));
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    // used for constFoldBinaryOp
+    using ValueType = ::mlir::Attribute;
+  }];
+}
+
+def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
+    "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
+  let summary = "a typed float_polynomial";
+  let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value);
+  let assemblyFormat = "$value `:` $type";
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const FloatPolynomial &":$value), [{
+      return $_get(
+        type.getContext(),
+        type,
+        FloatPolynomialAttr::get(type.getContext(), value));
+    }]>,
+    AttrBuilderWithInferredContext<(ins "Type":$type,
+                                        "const Attribute &":$value), [{
+      return $_get(type.getContext(), type, ::llvm::cast<FloatPolynomialAttr>(value));
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    // used for constFoldBinaryOp
+    using ValueType = ::mlir::Attribute;
+  }];
+}
+
 def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
-  let summary = "An attribute specifying a polynomial ring.";
+  let summary = "an attribute specifying a polynomial ring";
   let description = [{
     A ring describes the domain in which polynomial arithmetic occurs. The ring
     attribute in `polynomial` represents the more specific case of polynomials
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1a2439fe810b5..4c2fed6bab312 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -186,6 +186,21 @@ LogicalResult INTTOp::verify() {
   return verifyNTTOp(this->getOperation(), ring, tensorType);
 }
 
+LogicalResult ConstantOp::inferReturnTypes(
+    MLIRContext *context, std::optional<mlir::Location> location,
+    ConstantOp::Adaptor adaptor,
+    llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
+  Attribute operand = adaptor.getValue();
+  if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
+    inferredReturnTypes.push_back(intPoly.getType());
+  } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
+    inferredReturnTypes.push_back(floatPoly.getType());
+  } else {
+    return failure();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd canonicalization patterns
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ff709960c50e9..695b1acf18bd7 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -74,15 +74,15 @@ module {
 
   func.func @test_monic_monomial_mul() {
     %five = arith.constant 5 : index
-    %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
+    %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
     %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
     return
   }
 
   func.func @test_constant() {
-    %0 = polynomial.constant {value=#one_plus_x_squared} : !polynomial.polynomial<ring=#ring1>
-    %1 = polynomial.constant {value=#polynomial.int_polynomial<1 + x**2>} : !polynomial.polynomial<ring=#ring1>
-    %2 = polynomial.constant {value=#polynomial.float_polynomial<1.5 + 0.5 x**2>} : !polynomial.polynomial<ring=#ring2>
+    %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
+    %1 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
+    %2 = polynomial.constant #polynomial.float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
     return
   }
 

>From ef17f2af7f75bd2c98c05720eb8d3f77d652ef43 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 20 May 2024 22:25:14 -0700
Subject: [PATCH 4/8] show broken attempt

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  |  2 +-
 .../Polynomial/IR/PolynomialAttributes.td     | 12 +++-
 .../Polynomial/IR/PolynomialAttributes.cpp    | 64 ++++++++++++-------
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 55 ++++++++++++++++
 mlir/test/Dialect/Polynomial/ops.mlir         |  8 +--
 5 files changed, 111 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 85a9dd6b935d2..a0bd0bb0861bd 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -295,7 +295,7 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
   }];
   let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value);
   let results = (outs Polynomial_PolynomialType:$output);
-  let assemblyFormat = "attr-dict $value";
+  let hasCustomAssemblyFormat = 1;
 }
 
 def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 1ea07e21e0076..3bae6204299d1 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -38,6 +38,11 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
   }];
   let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial);
   let hasCustomAssemblyFormat = 1;
+  let extraClassDeclaration = [{
+    /// A parser which, upon failure to parse, does not emit errors and just returns
+    /// a null attribute.
+    static Attribute parse(AsmParser &parser, Type type, bool optional);
+  }];
 }
 
 def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
@@ -60,6 +65,11 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
   }];
   let parameters = (ins "FloatPolynomial":$polynomial);
   let hasCustomAssemblyFormat = 1;
+  let extraClassDeclaration = [{
+    /// A parser which, upon failure to parse, does not emit errors and just returns
+    /// a null attribute.
+    static Attribute parse(AsmParser &parser, Type type, bool optional);
+  }];
 }
 
 def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
@@ -81,7 +91,6 @@ def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
     }]>
   ];
   let extraClassDeclaration = [{
-    // used for constFoldBinaryOp
     using ValueType = ::mlir::Attribute;
   }];
 }
@@ -105,7 +114,6 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
     }]>
   ];
   let extraClassDeclaration = [{
-    // used for constFoldBinaryOp
     using ValueType = ::mlir::Attribute;
   }];
 }
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 890ce5226c30f..94169b5e93cf8 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -38,10 +38,11 @@ using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
 /// a '+'.
 ///
 template <typename Monomial>
-ParseResult
-parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
-              bool &isConstantTerm, bool &shouldParseMore,
-              ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
+ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
+                          llvm::StringRef &variable, bool &isConstantTerm,
+                          bool &shouldParseMore,
+                          ParseCoefficientFn<Monomial> parseAndStoreCoefficient,
+                          bool optional) {
   OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);
 
   isConstantTerm = false;
@@ -85,8 +86,9 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
     // If there's a **, then the integer exponent is required.
     APInt parsedExponent(apintBitWidth, 0);
     if (failed(parser.parseInteger(parsedExponent))) {
-      parser.emitError(parser.getCurrentLocation(),
-                       "found invalid integer exponent");
+      if (!optional)
+        parser.emitError(parser.getCurrentLocation(),
+                         "found invalid integer exponent");
       return failure();
     }
 
@@ -101,11 +103,12 @@ parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
   return success();
 }
 
-template <typename PolynoimalAttrTy, typename Monomial>
+template <typename Monomial>
 LogicalResult
 parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
                     llvm::StringSet<> &variables,
-                    ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
+                    ParseCoefficientFn<Monomial> parseAndStoreCoefficient,
+                    bool optional) {
   while (true) {
     Monomial parsedMonomial;
     llvm::StringRef parsedVariableRef;
@@ -113,8 +116,9 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
     bool shouldParseMore;
     if (failed(parseMonomial<Monomial>(
             parser, parsedMonomial, parsedVariableRef, isConstantTerm,
-            shouldParseMore, parseAndStoreCoefficient))) {
-      parser.emitError(parser.getCurrentLocation(), "expected a monomial");
+            shouldParseMore, parseAndStoreCoefficient, optional))) {
+      if (!optional)
+        parser.emitError(parser.getCurrentLocation(), "expected a monomial");
       return failure();
     }
 
@@ -130,18 +134,20 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
     if (succeeded(parser.parseOptionalGreater())) {
       break;
     }
-    parser.emitError(
-        parser.getCurrentLocation(),
-        "expected + and more monomials, or > to end polynomial attribute");
+    if (!optional)
+      parser.emitError(
+          parser.getCurrentLocation(),
+          "expected + and more monomials, or > to end polynomial attribute");
     return failure();
   }
 
   if (variables.size() > 1) {
     std::string vars = llvm::join(variables.keys(), ", ");
-    parser.emitError(
-        parser.getCurrentLocation(),
-        "polynomials must have one indeterminate, but there were multiple: " +
-            vars);
+    if (!optional)
+      parser.emitError(
+          parser.getCurrentLocation(),
+          "polynomials must have one indeterminate, but there were multiple: " +
+              vars);
     return failure();
   }
 
@@ -149,13 +155,18 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
 }
 
 Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
+  return IntPolynomialAttr::parse(parser, type, /*optional=*/false);
+}
+
+Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type,
+                                   bool optional) {
   if (failed(parser.parseLess()))
     return {};
 
   llvm::SmallVector<IntMonomial> monomials;
   llvm::StringSet<> variables;
 
-  if (failed(parsePolynomialAttr<IntPolynomialAttr, IntMonomial>(
+  if (failed(parsePolynomialAttr<IntMonomial>(
           parser, monomials, variables,
           [&](IntMonomial &monomial) -> OptionalParseResult {
             APInt parsedCoeff(apintBitWidth, 1);
@@ -163,20 +174,27 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
                 parser.parseOptionalInteger(parsedCoeff);
             monomial.setCoefficient(parsedCoeff);
             return result;
-          }))) {
+          },
+          optional))) {
     return {};
   }
 
   auto result = IntPolynomial::fromMonomials(monomials);
   if (failed(result)) {
-    parser.emitError(parser.getCurrentLocation())
-        << "parsed polynomial must have unique exponents among monomials";
+    if (!optional)
+      parser.emitError(parser.getCurrentLocation())
+          << "parsed polynomial must have unique exponents among monomials";
     return {};
   }
   return IntPolynomialAttr::get(parser.getContext(), result.value());
 }
 
 Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
+  return FloatPolynomialAttr::parse(parser, type, /*optional=*/false);
+}
+
+Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type,
+                                     bool optional) {
   if (failed(parser.parseLess()))
     return {};
 
@@ -191,8 +209,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
     return OptionalParseResult(result);
   };
 
-  if (failed(parsePolynomialAttr<FloatPolynomialAttr, FloatMonomial>(
-          parser, monomials, variables, parseAndStoreCoefficient))) {
+  if (failed(parsePolynomialAttr<FloatMonomial>(
+          parser, monomials, variables, parseAndStoreCoefficient, optional))) {
     return {};
   }
 
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 4c2fed6bab312..c7c61d2ad8190 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -186,6 +186,60 @@ LogicalResult INTTOp::verify() {
   return verifyNTTOp(this->getOperation(), ring, tensorType);
 }
 
+ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Using the built-in parser.parseAttribute requires the full
+  // #polynomial.typed_int_polynomial syntax, which is excessive.
+  // Instead we manually parse the components.
+  Type type;
+  parser.parseOptionalAttribute();
+
+  IntPolynomialAttr intPolyAttr;
+  parser.parseOptionalAttribute(intPolyAttr);
+  if (intPolyAttr) {
+    if (parser.parseColon() || parser.parseType(type))
+      return failure();
+
+    result.addAttribute("value",
+                        TypedIntPolynomialAttr::get(type, intPolyAttr));
+    result.addTypes(type);
+    return success();
+  }
+
+  Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr, /*optional=*/true);
+  if (floatPolyAttr) {
+    if (parser.parseColon() || parser.parseType(type))
+      return failure();
+    result.addAttribute("value",
+                        TypedFloatPolynomialAttr::get(type, intPolyAttr));
+    result.addTypes(type);
+    return success();
+  }
+
+  // In the worst case, still accept the verbose versions.
+  TypedIntPolynomialAttr typedIntPolyAttr;
+  ParseResult res = parser.parseAttribute<TypedIntPolynomialAttr>(
+      typedIntPolyAttr, "value", result.attributes);
+  if (succeeded(res)) {
+    result.addTypes(typedIntPolyAttr.getType());
+    return success();
+  }
+
+  TypedFloatPolynomialAttr typedFloatPolyAttr;
+  res = parser.parseAttribute<TypedFloatPolynomialAttr>(
+      typedFloatPolyAttr, "value", result.attributes);
+  if (succeeded(res)) {
+    result.addTypes(typedFloatPolyAttr.getType());
+    return success();
+  }
+
+  return failure();
+}
+
+void ConstantOp::print(OpAsmPrinter &p) {
+  p << " ";
+  p.printAttribute(getValue());
+}
+
 LogicalResult ConstantOp::inferReturnTypes(
     MLIRContext *context, std::optional<mlir::Location> location,
     ConstantOp::Adaptor adaptor,
@@ -196,6 +250,7 @@ LogicalResult ConstantOp::inferReturnTypes(
   } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
     inferredReturnTypes.push_back(floatPoly.getType());
   } else {
+    assert(false && "unexpected attribute type");
     return failure();
   }
   return success();
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 695b1acf18bd7..cfe3446a1dccf 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -74,15 +74,15 @@ module {
 
   func.func @test_monic_monomial_mul() {
     %five = arith.constant 5 : index
-    %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
+    %0 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
     %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
     return
   }
 
   func.func @test_constant() {
-    %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<ring=#ring1>
-    %1 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
-    %2 = polynomial.constant #polynomial.float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
+    %0 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
+    %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
+    %2 = polynomial.constant <1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
     return
   }
 

>From 74470e38e3b380c0ae1fe6951ada2fe8a189d532 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 21 May 2024 15:42:17 -0700
Subject: [PATCH 5/8] use int/float keywords

---
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 61 +++++++++++--------
 mlir/test/Dialect/Polynomial/ops.mlir         | 12 ++--
 2 files changed, 44 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index c7c61d2ad8190..38e7db85a1e97 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -189,37 +189,38 @@ LogicalResult INTTOp::verify() {
 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
   // Using the built-in parser.parseAttribute requires the full
   // #polynomial.typed_int_polynomial syntax, which is excessive.
-  // Instead we manually parse the components.
+  // Instead we parse an `int` or `float` keyword to pick the polynomial kind.
   Type type;
-  parser.parseOptionalAttribute();
-
-  IntPolynomialAttr intPolyAttr;
-  parser.parseOptionalAttribute(intPolyAttr);
-  if (intPolyAttr) {
-    if (parser.parseColon() || parser.parseType(type))
-      return failure();
-
-    result.addAttribute("value",
-                        TypedIntPolynomialAttr::get(type, intPolyAttr));
-    result.addTypes(type);
-    return success();
+  if (succeeded(parser.parseOptionalKeyword("float"))) {
+    Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr);
+    if (floatPolyAttr) {
+      if (parser.parseColon() || parser.parseType(type))
+        return failure();
+      result.addAttribute("value",
+                          TypedFloatPolynomialAttr::get(type, floatPolyAttr));
+      result.addTypes(type);
+      return success();
+    }
   }
 
-  Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr, /*optional=*/true);
-  if (floatPolyAttr) {
-    if (parser.parseColon() || parser.parseType(type))
-      return failure();
-    result.addAttribute("value",
-                        TypedFloatPolynomialAttr::get(type, intPolyAttr));
-    result.addTypes(type);
-    return success();
+  if (succeeded(parser.parseOptionalKeyword("int"))) {
+    Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr);
+    if (intPolyAttr) {
+      if (parser.parseColon() || parser.parseType(type))
+        return failure();
+
+      result.addAttribute("value",
+                          TypedIntPolynomialAttr::get(type, intPolyAttr));
+      result.addTypes(type);
+      return success();
+    }
   }
 
   // In the worst case, still accept the verbose versions.
   TypedIntPolynomialAttr typedIntPolyAttr;
-  ParseResult res = parser.parseAttribute<TypedIntPolynomialAttr>(
+  OptionalParseResult res = parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
       typedIntPolyAttr, "value", result.attributes);
-  if (succeeded(res)) {
+  if (res.has_value() && succeeded(res.value())) {
     result.addTypes(typedIntPolyAttr.getType());
     return success();
   }
@@ -227,7 +228,7 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
   TypedFloatPolynomialAttr typedFloatPolyAttr;
   res = parser.parseAttribute<TypedFloatPolynomialAttr>(
       typedFloatPolyAttr, "value", result.attributes);
-  if (succeeded(res)) {
+  if (res.has_value() && succeeded(res.value())) {
     result.addTypes(typedFloatPolyAttr.getType());
     return success();
   }
@@ -237,7 +238,17 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
 
 void ConstantOp::print(OpAsmPrinter &p) {
   p << " ";
-  p.printAttribute(getValue());
+  if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) {
+    p << "int";
+    intPoly.getValue().print(p);
+  } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
+    p << "float";
+    floatPoly.getValue().print(p);
+  } else {
+    assert(false && "unexpected attribute type");
+  }
+  p << " : ";
+  p.printType(getOutput().getType());
 }
 
 LogicalResult ConstantOp::inferReturnTypes(
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index cfe3446a1dccf..4716e37ff8852 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -74,15 +74,19 @@ module {
 
   func.func @test_monic_monomial_mul() {
     %five = arith.constant 5 : index
-    %0 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
+    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
     %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<ring=#ring1>, index) -> !polynomial.polynomial<ring=#ring1>
     return
   }
 
   func.func @test_constant() {
-    %0 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
-    %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<ring=#ring1>
-    %2 = polynomial.constant <1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
+    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
+    %1 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<ring=#ring1>
+    %2 = polynomial.constant float<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
+
+    // Test verbose fallbacks
+    %verb0 = polynomial.constant #polynomial.typed_int_polynomial<1 + x**2> : !polynomial.polynomial<ring=#ring1>
+    %verb2 = polynomial.constant #polynomial.typed_float_polynomial<1.5 + 0.5 x**2> : !polynomial.polynomial<ring=#ring2>
     return
   }
 

>From 431bf8af1e851d8261e50a06413bf9955dafa97d Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 21 May 2024 15:50:32 -0700
Subject: [PATCH 6/8] update docs one last time

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 25 +++++++++----------
 .../Polynomial/IR/PolynomialAttributes.td     | 18 +++++++++++++
 2 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index a0bd0bb0861bd..f99cbccd243ec 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -52,8 +52,8 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
     // add two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -76,8 +76,8 @@ def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
     // subtract two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -101,8 +101,8 @@ def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
     // multiply two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
-    %1 = polynomial.constant #polynomial.int_polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant int<x**5 - x + 1> : !polynomial.polynomial<#ring>
     %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
     ```
   }];
@@ -126,7 +126,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
     // multiply two polynomials modulo x^1024 - 1
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
     %1 = arith.constant 3 : i32
     %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
     ```
@@ -157,7 +157,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
     ```mlir
     #poly = #polynomial.int_polynomial<x**1024 - 1>
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %0 = polynomial.constant int<1 + x**2> : !polynomial.polynomial<#ring>
     %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
     ```
   }];
@@ -285,12 +285,11 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
     Example:
 
     ```mlir
-    #poly = #polynomial.int_polynomial<x**1024 - 1>
-    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536:i32, polynomialModulus=#poly>
-    %0 = polynomial.constant #polynomial.int_polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    !int_poly_ty = !polynomial.polynomial<ring=<coefficientType=i32>>
+    %0 = polynomial.constant int<1 + x**2> : !int_poly_ty
 
-    #float_ring = #polynomial.ring<coefficientType=f32>
-    %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
+    !float_poly_ty = !polynomial.polynomial<ring=<coefficientType=f32>>
+    %1 = polynomial.constant float<0.5 + 1.3e06 x**2> : !float_poly_ty
     ```
   }];
   let arguments = (ins Polynomial_AnyTypedPolynomialAttr:$value);
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 3bae6204299d1..5298542faac9a 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -75,6 +75,15 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
 def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
     "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> {
   let summary = "a typed int_polynomial";
+  let description = [{
+    Example:
+
+    ```mlir
+    !poly_ty = !polynomial.polynomial<ring=<coefficientType=i32>>
+    #poly = int<1 x**7 + 4> : !poly_ty
+    #poly_verbose = #polynomial.typed_int_polynomial<1 x**7 + 4> : !poly_ty
+    ```
+  }];
   let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value);
   let assemblyFormat = "$value `:` $type";
   let builders = [
@@ -98,6 +107,15 @@ def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
 def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
     "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> {
   let summary = "a typed float_polynomial";
+  let description = [{
+    Example:
+
+    ```mlir
+    !poly_ty = !polynomial.polynomial<ring=<coefficientType=f32>>
+    #poly = float<1.4 x**7 + 4.5> : !poly_ty
+    #poly_verbose = #polynomial.typed_float_polynomial<1.4 x**7 + 4.5> : !poly_ty
+    ```
+  }];
   let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value);
   let assemblyFormat = "$value `:` $type";
   let builders = [

>From f9019bcddeeae49c77f2d82250c2e765a32ae716 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 21 May 2024 15:53:47 -0700
Subject: [PATCH 7/8] remove optional parse option

---
 .../Polynomial/IR/PolynomialAttributes.td     | 10 ---
 .../Polynomial/IR/PolynomialAttributes.cpp    | 61 +++++++------------
 2 files changed, 21 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 5298542faac9a..655020adf808b 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -38,11 +38,6 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom
   }];
   let parameters = (ins "::mlir::polynomial::IntPolynomial":$polynomial);
   let hasCustomAssemblyFormat = 1;
-  let extraClassDeclaration = [{
-    /// A parser which, upon failure to parse, does not emit errors and just returns
-    /// a null attribute.
-    static Attribute parse(AsmParser &parser, Type type, bool optional);
-  }];
 }
 
 def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
@@ -65,11 +60,6 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p
   }];
   let parameters = (ins "FloatPolynomial":$polynomial);
   let hasCustomAssemblyFormat = 1;
-  let extraClassDeclaration = [{
-    /// A parser which, upon failure to parse, does not emit errors and just returns
-    /// a null attribute.
-    static Attribute parse(AsmParser &parser, Type type, bool optional);
-  }];
 }
 
 def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 94169b5e93cf8..cc7d3172b1a1d 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -38,11 +38,10 @@ using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType &)>;
 /// a '+'.
 ///
 template <typename Monomial>
-ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
-                          llvm::StringRef &variable, bool &isConstantTerm,
-                          bool &shouldParseMore,
-                          ParseCoefficientFn<Monomial> parseAndStoreCoefficient,
-                          bool optional) {
+ParseResult
+parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable,
+              bool &isConstantTerm, bool &shouldParseMore,
+              ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
   OptionalParseResult parsedCoeffResult = parseAndStoreCoefficient(monomial);
 
   isConstantTerm = false;
@@ -86,9 +85,8 @@ ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
     // If there's a **, then the integer exponent is required.
     APInt parsedExponent(apintBitWidth, 0);
     if (failed(parser.parseInteger(parsedExponent))) {
-      if (!optional)
-        parser.emitError(parser.getCurrentLocation(),
-                         "found invalid integer exponent");
+      parser.emitError(parser.getCurrentLocation(),
+                       "found invalid integer exponent");
       return failure();
     }
 
@@ -107,8 +105,7 @@ template <typename Monomial>
 LogicalResult
 parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
                     llvm::StringSet<> &variables,
-                    ParseCoefficientFn<Monomial> parseAndStoreCoefficient,
-                    bool optional) {
+                    ParseCoefficientFn<Monomial> parseAndStoreCoefficient) {
   while (true) {
     Monomial parsedMonomial;
     llvm::StringRef parsedVariableRef;
@@ -116,9 +113,8 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
     bool shouldParseMore;
     if (failed(parseMonomial<Monomial>(
             parser, parsedMonomial, parsedVariableRef, isConstantTerm,
-            shouldParseMore, parseAndStoreCoefficient, optional))) {
-      if (!optional)
-        parser.emitError(parser.getCurrentLocation(), "expected a monomial");
+            shouldParseMore, parseAndStoreCoefficient))) {
+      parser.emitError(parser.getCurrentLocation(), "expected a monomial");
       return failure();
     }
 
@@ -134,20 +130,18 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
     if (succeeded(parser.parseOptionalGreater())) {
       break;
     }
-    if (!optional)
-      parser.emitError(
-          parser.getCurrentLocation(),
-          "expected + and more monomials, or > to end polynomial attribute");
+    parser.emitError(
+        parser.getCurrentLocation(),
+        "expected + and more monomials, or > to end polynomial attribute");
     return failure();
   }
 
   if (variables.size() > 1) {
     std::string vars = llvm::join(variables.keys(), ", ");
-    if (!optional)
-      parser.emitError(
-          parser.getCurrentLocation(),
-          "polynomials must have one indeterminate, but there were multiple: " +
-              vars);
+    parser.emitError(
+        parser.getCurrentLocation(),
+        "polynomials must have one indeterminate, but there were multiple: " +
+            vars);
     return failure();
   }
 
@@ -155,11 +149,6 @@ parsePolynomialAttr(AsmParser &parser, llvm::SmallVector<Monomial> &monomials,
 }
 
 Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type) {
-  return IntPolynomialAttr::parse(parser, type, /*optional=*/false);
-}
-
-Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type,
-                                   bool optional) {
   if (failed(parser.parseLess()))
     return {};
 
@@ -174,27 +163,19 @@ Attribute IntPolynomialAttr::parse(AsmParser &parser, Type type,
                 parser.parseOptionalInteger(parsedCoeff);
             monomial.setCoefficient(parsedCoeff);
             return result;
-          },
-          optional))) {
+          }))) {
     return {};
   }
 
   auto result = IntPolynomial::fromMonomials(monomials);
   if (failed(result)) {
-    if (!optional)
-      parser.emitError(parser.getCurrentLocation())
-          << "parsed polynomial must have unique exponents among monomials";
+    parser.emitError(parser.getCurrentLocation())
+        << "parsed polynomial must have unique exponents among monomials";
     return {};
   }
   return IntPolynomialAttr::get(parser.getContext(), result.value());
 }
-
 Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
-  return FloatPolynomialAttr::parse(parser, type, /*optional=*/false);
-}
-
-Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type,
-                                     bool optional) {
   if (failed(parser.parseLess()))
     return {};
 
@@ -209,8 +190,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type,
     return OptionalParseResult(result);
   };
 
-  if (failed(parsePolynomialAttr<FloatMonomial>(
-          parser, monomials, variables, parseAndStoreCoefficient, optional))) {
+  if (failed(parsePolynomialAttr<FloatMonomial>(parser, monomials, variables,
+                                                parseAndStoreCoefficient))) {
     return {};
   }
 

>From 0da56f3f6270069c589cadd6e920a927de536f35 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Tue, 21 May 2024 15:59:41 -0700
Subject: [PATCH 8/8] clang-format

---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 38e7db85a1e97..d0a25fd9288b9 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -218,8 +218,9 @@ ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // In the worst case, still accept the verbose versions.
   TypedIntPolynomialAttr typedIntPolyAttr;
-  OptionalParseResult res = parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
-      typedIntPolyAttr, "value", result.attributes);
+  OptionalParseResult res =
+      parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
+          typedIntPolyAttr, "value", result.attributes);
   if (res.has_value() && succeeded(res.value())) {
     result.addTypes(typedIntPolyAttr.getType());
     return success();



More information about the Mlir-commits mailing list