[Mlir-commits] [mlir] polynomial: Add basic ops (PR #89525)
    Jeremy Kun 
    llvmlistbot at llvm.org
       
    Thu Apr 25 10:46:31 PDT 2024
    
    
  
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/89525
>From a22b4c668e6bf3349ebff853fc3ae266c21ad4cb Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 20 Apr 2024 18:48:38 -0700
Subject: [PATCH 01/16] add basic polynomial ops
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.h   |   2 +
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 128 +++++++++++++++++-
 .../Polynomial/IR/PolynomialDialect.cpp       |   9 ++
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   |  84 +++++++++++-
 mlir/test/Dialect/Polynomial/ops.mlir         |  75 ++++++++++
 mlir/test/Dialect/Polynomial/ops_errors.mlir  |  13 ++
 6 files changed, 302 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/Dialect/Polynomial/ops.mlir
 create mode 100644 mlir/test/Dialect/Polynomial/ops_errors.mlir
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index 39b05b9d3ad14b..fa767649f649b6 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -102,6 +102,8 @@ class Polynomial {
 
   unsigned getDegree() const;
 
+  ArrayRef<Monomial> getTerms() const { return terms; }
+
   friend ::llvm::hash_code hash_value(const Polynomial &arg);
 
 private:
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 5d8da8399b01b5..89a1bd8a5bb30f 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -131,23 +131,137 @@ def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> {
   let assemblyFormat = "`<` $ring `>`";
 }
 
+def PolynomialLike: TypeOrContainer<Polynomial_PolynomialType, "polynomial-like">;
+
 class Polynomial_Op<string mnemonic, list<Trait> traits = []> :
-    Op<Polynomial_Dialect, mnemonic, traits # [Pure]>;
+    Op<Polynomial_Dialect, mnemonic, traits # [Pure]> {
+  let assemblyFormat = [{
+    operands attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results))
+  }];
+}
 
 class Polynomial_UnaryOp<string mnemonic, list<Trait> traits = []> :
     Polynomial_Op<mnemonic, traits # [SameOperandsAndResultType]> {
   let arguments = (ins Polynomial_PolynomialType:$operand);
   let results = (outs Polynomial_PolynomialType:$result);
-
-  let assemblyFormat = "$operand attr-dict `:` qualified(type($result))";
 }
 
 class Polynomial_BinaryOp<string mnemonic, list<Trait> traits = []> :
-    Polynomial_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins Polynomial_PolynomialType:$lhs, Polynomial_PolynomialType:$rhs);
-  let results = (outs Polynomial_PolynomialType:$result);
+    Polynomial_Op<mnemonic, !listconcat(traits, [SameOperandsAndResultType, ElementwiseMappable])> {
+  let arguments = (ins PolynomialLike:$lhs, PolynomialLike:$rhs);
+  let results = (outs PolynomialLike:$result);
+  let assemblyFormat = "operands attr-dict `:` qualified(type($result))";
+}
+
+def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
+  let summary = "Addition operation between polynomials.";
+}
+
+def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
+  let summary = "Subtraction operation between polynomials.";
+}
+
+def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
+  let summary = "Multiplication operation between polynomials.";
+}
+
+def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
+      ElementwiseMappable, AllTypesMatch<["polynomial", "output"]>]> {
+  let summary = "Multiplication by a scalar of the coefficient ring.";
+
+  let arguments = (ins
+    PolynomialLike:$polynomial,
+    AnyInteger:$scalar
+  );
+
+  let results = (outs
+    PolynomialLike:$output
+  );
+
+  let assemblyFormat = "operands attr-dict `:` qualified(type($polynomial)) `,` type($scalar)";
+}
+
+def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
+  let summary = "Compute the leading term of the polynomial.";
+  let description = [{
+    The degree of a polynomial is the largest $k$ for which the coefficient
+    $a_k$ of $x^k$ is nonzero. The leading term is the term $a_k x^k$, which
+    this op represents as a pair of results.
+  }];
+  let arguments = (ins Polynomial_PolynomialType:$input);
+  let results = (outs Index:$degree, AnyInteger:$coefficient);
+  let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `->` `(` type($degree) `,` type($coefficient) `)`";
+}
+
+def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
+  let summary = "Create a polynomial that consists of a single monomial.";
+  let arguments = (ins AnyInteger:$coefficient, Index:$degree);
+  let results = (outs Polynomial_PolynomialType:$output);
+}
+
+def Polynomial_MonomialMulOp: Polynomial_Op<"monomial_mul", [AllTypesMatch<["input", "output"]>]> {
+  let summary = "Multiply a polynomial by a monic monomial.";
+  let description = [{
+    In the ring of polynomials mod $x^n - 1$, `monomial_mul` can be interpreted
+    as a cyclic shift of the coefficients of the polynomial. For some rings,
+    this results in optimized lowerings that involve rotations and rescaling
+    of the coefficients of the input.
+  }];
+  let arguments = (ins Polynomial_PolynomialType:$input, Index:$monomialDegree);
+  let results = (outs Polynomial_PolynomialType:$output);
+  let hasVerifier = 1;
+}
+
+def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor"> {
+  let summary = "Creates a polynomial from integer coefficients stored in a tensor.";
+  let description = [{
+    `polynomial.from_tensor` creates a polynomial value from a tensor of coefficients.
+    The input tensor must list the coefficients in degree-increasing order.
+
+    The input one-dimensional tensor may have size at most the degree of the
+    ring's ideal generator polynomial, with smaller dimension implying that
+    all higher-degree terms have coefficient zero.
+  }];
+  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let results = (outs Polynomial_PolynomialType:$output);
+
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` qualified(type($output))";
+
+  let builders = [
+    // Builder that infers coefficient modulus from tensor bit width,
+    // and uses whatever input ring is provided by the caller.
+    OpBuilder<(ins "::mlir::Value":$input, "RingAttr":$ring)>
+  ];
+  let hasVerifier = 1;
+}
+
+def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor"> {
+  let summary = "Creates a tensor containing the coefficients of a polynomial.";
+  let description = [{
+    `polynomial.to_tensor` creates a tensor value containing the coefficients of the
+    input polynomial. The output tensor contains the coefficients in
+    degree-increasing order.
+
+    Operations that act on the coefficients of a polynomial, such as extracting
+    a specific coefficient or extracting a range of coefficients, should be
+    implemented by composing `to_tensor` with the relevant `tensor` dialect
+    ops.
+
+    The output tensor has shape equal to the degree of the ring's ideal
+    generator polynomial, including zeroes.
+  }];
+  let arguments = (ins Polynomial_PolynomialType:$input);
+  let results = (outs RankedTensorOf<[AnyInteger]>:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+
+  let hasVerifier = 1;
+}
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($result))";
+def Polynomial_ConstantOp : Polynomial_Op<"constant"> {
+  let summary = "Define a constant polynomial via an attribute.";
+  let arguments = (ins Polynomial_PolynomialAttr:$input);
+  let results = (outs Polynomial_PolynomialType:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($output))";
 }
 
 #endif // POLYNOMIAL_OPS
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
index a672a59b8a465d..825b80d70f8032 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
@@ -8,9 +8,18 @@
 
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
 #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 96c59a28b8fdce..077583150177b9 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -6,10 +6,90 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
 #include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APInt.h"
 
 using namespace mlir;
 using namespace mlir::polynomial;
 
-#define GET_OP_CLASSES
-#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
+void FromTensorOp::build(OpBuilder &builder, OperationState &result,
+                         Value input, RingAttr ring) {
+  // The result ring, including its coefficient modulus, is taken as given by
+  // the caller; nothing is derived from the input's element bit width.
+  assert(isa<TensorType>(input.getType()) &&
+         "from_tensor builder expects a tensor-typed input");
+  Type resultType = PolynomialType::get(builder.getContext(), ring);
+  build(builder, result, resultType, input);
+}
+
+LogicalResult FromTensorOp::verify() {
+  auto tensorShape = getInput().getType().getShape();
+  auto ring = getOutput().getType().getRing();
+  auto polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
+  bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
+  if (!compatible) {
+    return emitOpError()
+           << "input type " << getInput().getType()
+           << " does not match output type " << getOutput().getType()
+           << ". The input type must be a tensor of shape [d] where d "
+              "is at most the degree of the polynomialModulus of "
+              "the output type's ring attribute.";
+  }
+
+  APInt coefficientModulus = ring.getCoefficientModulus().getValue();
+  unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
+  unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
+
+  if (inputBitWidth > cmodBitWidth) {
+    return emitOpError() << "input tensor element type "
+                         << getInput().getType().getElementType()
+                         << " is too large to fit in the coefficients of "
+                         << getOutput().getType()
+                         << ". The input tensor's elements must be rescaled"
+                            " to fit before using from_tensor.";
+  }
+
+  return success();
+}
+
+LogicalResult ToTensorOp::verify() {
+  auto tensorShape = getOutput().getType().getShape();
+  auto polyDegree = getInput()
+                        .getType()
+                        .getRing()
+                        .getPolynomialModulus()
+                        .getPolynomial()
+                        .getDegree();
+  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+
+  return compatible
+             ? success()
+             : emitOpError()
+                   << "input type " << getInput().getType()
+                   << " does not match output type " << getOutput().getType()
+                   << ". The input type must be a tensor of shape [d] where d "
+                      "is exactly the degree of the polynomialModulus of "
+                      "the output type's ring attribute.";
+}
+
+LogicalResult MonomialMulOp::verify() {
+  auto ring = getInput().getType().getRing();
+  auto idealTerms = ring.getPolynomialModulus().getPolynomial().getTerms();
+  bool compatible =
+      idealTerms.size() == 2 &&
+      (idealTerms[0].coefficient == -1 && idealTerms[0].exponent == 0) &&
+      (idealTerms[1].coefficient == 1);
+
+  return compatible ? success()
+                    : emitOpError()
+                          << "ring type " << ring
+                          << " is not supported yet. The ring "
+                             "must be of the form (x^n - 1) for some n";
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
new file mode 100644
index 00000000000000..dae14471344ec3
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// Syntax test: mlir-opt must accept every op below; FileCheck only matches from_tensor/to_tensor.
+
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#my_poly_2 = #polynomial.polynomial<2>
+#my_poly_3 = #polynomial.polynomial<3x>
+#my_poly_4 = #polynomial.polynomial<t**3 + 4t + 2>
+#ring1 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly>
+#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
+
+#ideal = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
+!poly_ty = !polynomial.polynomial<#ring>
+
+module {
+  func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
+    %c0 = arith.constant 0 : index
+    %two = arith.constant 2 : i16
+    %five = arith.constant 5 : i16
+    %coeffs1 = tensor.from_elements %two, %two, %five : tensor<3xi16>
+    %coeffs2 = tensor.from_elements %five, %five, %two : tensor<3xi16>
+
+    %poly1 = polynomial.from_tensor %coeffs1 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
+    %poly2 = polynomial.from_tensor %coeffs2 : tensor<3xi16> -> !polynomial.polynomial<#ring1>
+
+    %3 = polynomial.mul %poly1, %poly2 : !polynomial.polynomial<#ring1>
+
+    return %3 : !polynomial.polynomial<#ring1>
+  }
+
+  func.func @test_elementwise(%p0 : !polynomial.polynomial<#ring1>, %p1: !polynomial.polynomial<#ring1>) {
+    %tp0 = tensor.from_elements %p0, %p1 : tensor<2x!polynomial.polynomial<#ring1>>
+    %tp1 = tensor.from_elements %p1, %p0 : tensor<2x!polynomial.polynomial<#ring1>>
+
+    %c = arith.constant 2 : i32
+    %mul_const_sclr = polynomial.mul_scalar %tp0, %c : tensor<2x!polynomial.polynomial<#ring1>>, i32
+
+    %add = polynomial.add %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+    %sub = polynomial.sub %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+    %mul = polynomial.mul %tp0, %tp1 : tensor<2x!polynomial.polynomial<#ring1>>
+
+    return
+  }
+
+  func.func @test_to_from_tensor(%p0 : !polynomial.polynomial<#ring1>) {
+    %c0 = arith.constant 0 : index
+    %two = arith.constant 2 : i16
+    %coeffs1 = tensor.from_elements %two, %two : tensor<2xi16>
+    // CHECK: from_tensor
+    %poly = polynomial.from_tensor %coeffs1 : tensor<2xi16> -> !polynomial.polynomial<#ring1>
+    // CHECK: to_tensor
+    %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<#ring1> -> tensor<1024xi16>
+
+    return
+  }
+
+  func.func @test_degree(%p0 : !polynomial.polynomial<#ring1>) {
+    %0, %1 = polynomial.leading_term %p0 : !polynomial.polynomial<#ring1> -> (index, i32)
+    return
+  }
+
+  func.func @test_monomial() {
+    %deg = arith.constant 1023 : index
+    %five = arith.constant 5 : i16
+    %0 = polynomial.monomial %five, %deg : (i16, index) -> !polynomial.polynomial<#ring1>
+    return
+  }
+
+  func.func @test_constant() {
+    %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
+    %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
+    return
+  }
+}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
new file mode 100644
index 00000000000000..e536571eb8f3ee
--- /dev/null
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt --verify-diagnostics %s
+
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+module {
+  func.func @test_from_tensor_too_large_coeffs() {
+    %two = arith.constant 2 : i32
+    %coeffs1 = tensor.from_elements %two, %two : tensor<2xi32>
+    // expected-error at below {{is too large to fit in the coefficients}}
+    %poly = polynomial.from_tensor %coeffs1 : tensor<2xi32> -> !polynomial.polynomial<#ring>
+    return
+  }
+}
>From dadda267dd1c1152e0bd81215efc06ceb46b384a Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 10:49:32 -0700
Subject: [PATCH 02/16] simplify assembly format
---
 mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 89a1bd8a5bb30f..eeb83c7e4d6964 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -135,9 +135,7 @@ def PolynomialLike: TypeOrContainer<Polynomial_PolynomialType, "polynomial-like"
 
 class Polynomial_Op<string mnemonic, list<Trait> traits = []> :
     Op<Polynomial_Dialect, mnemonic, traits # [Pure]> {
-  let assemblyFormat = [{
-    operands attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results))
-  }];
+  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
 class Polynomial_UnaryOp<string mnemonic, list<Trait> traits = []> :
>From 8810618d0b5a5b81f027bffde474d25edb7a69e6 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 10:50:47 -0700
Subject: [PATCH 03/16] use qualified type in builder
---
 mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index eeb83c7e4d6964..d30a210f698303 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -228,7 +228,7 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
   let builders = [
     // Builder that infers coefficient modulus from tensor bit width,
     // and uses whatever input ring is provided by the caller.
-    OpBuilder<(ins "::mlir::Value":$input, "RingAttr":$ring)>
+    OpBuilder<(ins "::mlir::Value":$input, "::mlir::polynomial::RingAttr":$ring)>
   ];
   let hasVerifier = 1;
 }
>From 8c0f83bbcecbff2f20622ccf5892302fd393b1e3 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 10:52:44 -0700
Subject: [PATCH 04/16] expand auto
---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 077583150177b9..2ae815d5ecf147 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -30,9 +30,9 @@ void FromTensorOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult FromTensorOp::verify() {
-  auto tensorShape = getInput().getType().getShape();
-  auto ring = getOutput().getType().getRing();
-  auto polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
+  ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
+  RingAttr ring = getOutput().getType().getRing();
+  unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
   bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
   if (!compatible) {
     return emitOpError()
@@ -60,8 +60,8 @@ LogicalResult FromTensorOp::verify() {
 }
 
 LogicalResult ToTensorOp::verify() {
-  auto tensorShape = getOutput().getType().getShape();
-  auto polyDegree = getInput()
+  ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
+  unsigned polyDegree = getInput()
                         .getType()
                         .getRing()
                         .getPolynomialModulus()
@@ -80,8 +80,8 @@ LogicalResult ToTensorOp::verify() {
 }
 
 LogicalResult MonomialMulOp::verify() {
-  auto ring = getInput().getType().getRing();
-  auto idealTerms = ring.getPolynomialModulus().getPolynomial().getTerms();
+  RingAttr ring = getInput().getType().getRing();
+  ArrayRef<Monomial> idealTerms = ring.getPolynomialModulus().getPolynomial().getTerms();
   bool compatible =
       idealTerms.size() == 2 &&
       (idealTerms[0].coefficient == -1 && idealTerms[0].exponent == 0) &&
>From a7b525043fab9fc8ba0d7a3e375973472fdedc74 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 10:56:38 -0700
Subject: [PATCH 05/16] attachNotes to error messages
---
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 72 +++++++++++--------
 mlir/test/Dialect/Polynomial/ops_errors.mlir  |  1 +
 2 files changed, 42 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 2ae815d5ecf147..e6d45a7066ba93 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -35,12 +35,14 @@ LogicalResult FromTensorOp::verify() {
   unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
   bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
   if (!compatible) {
-    return emitOpError()
-           << "input type " << getInput().getType()
-           << " does not match output type " << getOutput().getType()
-           << ". The input type must be a tensor of shape [d] where d "
-              "is at most the degree of the polynomialModulus of "
-              "the output type's ring attribute.";
+    InFlightDiagnostic diag = emitOpError()
+                              << "input type " << getInput().getType()
+                              << " does not match output type "
+                              << getOutput().getType();
+    diag.attachNote() << "The input type must be a tensor of shape [d] where d "
+                         "is at most the degree of the polynomialModulus of "
+                         "the output type's ring attribute.";
+    return diag;
   }
 
   APInt coefficientModulus = ring.getCoefficientModulus().getValue();
@@ -48,12 +50,14 @@ LogicalResult FromTensorOp::verify() {
   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
 
   if (inputBitWidth > cmodBitWidth) {
-    return emitOpError() << "input tensor element type "
-                         << getInput().getType().getElementType()
-                         << " is too large to fit in the coefficients of "
-                         << getOutput().getType()
-                         << ". The input tensor's elements must be rescaled"
-                            " to fit before using from_tensor.";
+    InFlightDiagnostic diag = emitOpError()
+                              << "input tensor element type "
+                              << getInput().getType().getElementType()
+                              << " is too large to fit in the coefficients of "
+                              << getOutput().getType();
+    diag.attachNote() << "The input tensor's elements must be rescaled"
+                         " to fit before using from_tensor.";
+    return diag;
   }
 
   return success();
@@ -62,34 +66,40 @@ LogicalResult FromTensorOp::verify() {
 LogicalResult ToTensorOp::verify() {
   ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
   unsigned polyDegree = getInput()
-                        .getType()
-                        .getRing()
-                        .getPolynomialModulus()
-                        .getPolynomial()
-                        .getDegree();
+                            .getType()
+                            .getRing()
+                            .getPolynomialModulus()
+                            .getPolynomial()
+                            .getDegree();
   bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
 
-  return compatible
-             ? success()
-             : emitOpError()
-                   << "input type " << getInput().getType()
-                   << " does not match output type " << getOutput().getType()
-                   << ". The input type must be a tensor of shape [d] where d "
-                      "is exactly the degree of the polynomialModulus of "
-                      "the output type's ring attribute.";
+  if (!compatible) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "input type " << getInput().getType()
+                              << " does not match output type "
+                              << getOutput().getType();
+    diag.attachNote() << "The output type must be a tensor of shape [d] where "
+                         "d is exactly the degree of the polynomialModulus of "
+                         "the input type's ring attribute.";
+    return diag;
+  }
+  return success();
 }
 
 LogicalResult MonomialMulOp::verify() {
   RingAttr ring = getInput().getType().getRing();
-  ArrayRef<Monomial> idealTerms = ring.getPolynomialModulus().getPolynomial().getTerms();
+  ArrayRef<Monomial> idealTerms =
+      ring.getPolynomialModulus().getPolynomial().getTerms();
   bool compatible =
       idealTerms.size() == 2 &&
       (idealTerms[0].coefficient == -1 && idealTerms[0].exponent == 0) &&
       (idealTerms[1].coefficient == 1);
 
-  return compatible ? success()
-                    : emitOpError()
-                          << "ring type " << ring
-                          << " is not supported yet. The ring "
-                             "must be of the form (x^n - 1) for some n";
+  if (!compatible) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "unsupported ring type: " << ring;
+    diag.attachNote() << "The ring must be of the form (x^n - 1) for some n";
+    return diag;
+  }
+  return success();
 }
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index e536571eb8f3ee..635d54bda45b3d 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -7,6 +7,7 @@ module {
     %two = arith.constant 2 : i32
     %coeffs1 = tensor.from_elements %two, %two : tensor<2xi32>
     // expected-error at below {{is too large to fit in the coefficients}}
+    // expected-note at below {{rescaled to fit}}
     %poly = polynomial.from_tensor %coeffs1 : tensor<2xi32> -> !polynomial.polynomial<#ring>
     return
   }
>From d8b75f9ed59ea244537bc75589f85a5c6abe8a52 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 10:58:48 -0700
Subject: [PATCH 06/16] early return on success where applicable
---
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 34 +++++++++----------
 1 file changed, 16 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index e6d45a7066ba93..07bb6d32836660 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -73,17 +73,16 @@ LogicalResult ToTensorOp::verify() {
                             .getDegree();
   bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
 
-  if (!compatible) {
-    InFlightDiagnostic diag = emitOpError()
-                              << "input type " << getInput().getType()
-                              << " does not match output type "
-                              << getOutput().getType();
-    diag.attachNote() << "The output type must be a tensor of shape [d] where "
-                         "d is exactly the degree of the polynomialModulus of "
-                         "the input type's ring attribute.";
-    return diag;
-  }
-  return success();
+  if (compatible)
+    return success();
+
+  InFlightDiagnostic diag =
+      emitOpError() << "input type " << getInput().getType()
+                    << " does not match output type " << getOutput().getType();
+  diag.attachNote() << "The output type must be a tensor of shape [d] where "
+                       "d is exactly the degree of the polynomialModulus of "
+                       "the input type's ring attribute.";
+  return diag;
 }
 
 LogicalResult MonomialMulOp::verify() {
@@ -95,11 +94,10 @@ LogicalResult MonomialMulOp::verify() {
       (idealTerms[0].coefficient == -1 && idealTerms[0].exponent == 0) &&
       (idealTerms[1].coefficient == 1);
 
-  if (!compatible) {
-    InFlightDiagnostic diag = emitOpError()
-                              << "unsupported ring type: " << ring;
-    diag.attachNote() << "The ring must be of the form (x^n - 1) for some n";
-    return diag;
-  }
-  return success();
+  if (compatible)
+    return success();
+
+  InFlightDiagnostic diag = emitOpError() << "unsupported ring type: " << ring;
+  diag.attachNote() << "The ring must be of the form (x^n - 1) for some n";
+  return diag;
 }
>From 8ab2154225720fef81d01098caed3ce6b09327bf Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 11:51:50 -0700
Subject: [PATCH 07/16] make a pass over documentation
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 200 ++++++++++++++++--
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   |   2 +-
 mlir/test/Dialect/Polynomial/ops.mlir         |   7 +
 3 files changed, 185 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index d30a210f698303..38866d6879b5b8 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -35,18 +35,18 @@ def Polynomial_Dialect : Dialect {
 
     ```mlir
     // A constant polynomial in a ring with i32 coefficients and no polynomial modulus
-    #ring = #polynomial.ring<ctype=i32>
+    #ring = #polynomial.ring<coefficientType=i32>
     %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
 
     // A constant polynomial in a ring with i32 coefficients, modulo (x^1024 + 1)
     #modulus = #polynomial.polynomial<1 + x**1024>
-    #ring = #polynomial.ring<ctype=i32, ideal=#modulus>
+    #ring = #polynomial.ring<coefficientType=i32, polynomialModulus=#modulus>
     %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
 
     // A constant polynomial in a ring with i32 coefficients, with a polynomial
     // modulus of (x^1024 + 1) and a coefficient modulus of 17.
     #modulus = #polynomial.polynomial<1 + x**1024>
-    #ring = #polynomial.ring<ctype=i32, cmod=17, ideal=#modulus>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17, polynomialModulus=#modulus>
     %a = polynomial.constant <1 + x**2 - 3x**3> : polynomial.polynomial<#ring>
     ```
   }];
@@ -63,7 +63,21 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
 def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
   let summary = "An attribute containing a single-variable polynomial.";
   let description = [{
-     #poly = #polynomial.poly<x**1024 + 1>
+    A polynomial attribute represents a single-variable polynomial, which
+    is used to define the modulus of a `RingAttr`, as well as to define constants
+    and perform constant folding for `polynomial` ops.
+
+    The polynomial must be expressed as a list of monomial terms, with addition
+    or subtraction between them. The choice of variable name is arbitrary, but
+    must be consistent across all the monomials used to define a single
+    attribute. The order of monomial terms is arbitrary, but each monomial degree
+    must occur at most once.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.polynomial<x**1024 + 1>
+    ```
   }];
   let parameters = (ins "Polynomial":$polynomial);
   let hasCustomAssemblyFormat = 1;
@@ -79,10 +93,10 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     integral, whose coefficients are taken modulo some statically known modulus
     (`coefficientModulus`).
 
-    Additionally, a polynomial ring can specify an _ideal_, which converts
+    Additionally, a polynomial ring can specify an _polynomialModulus_, which converts
     polynomial arithmetic to the analogue of modular integer arithmetic, where
     each polynomial is represented as its remainder when dividing by the
-    modulus. For single-variable polynomials, an "ideal" is always specificed
+    modulus. For single-variable polynomials, an "polynomialModulus" is always specificed
     via a single polynomial, which we call `polynomialModulus`.
 
     An expressive example is polynomials with i32 coefficients, whose
@@ -122,11 +136,9 @@ class Polynomial_Type<string name, string typeMnemonic>
 
 def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial"> {
   let summary = "An element of a polynomial ring.";
-
   let description = [{
     A type for polynomials in a polynomial quotient ring.
   }];
-
   let parameters = (ins Polynomial_RingAttr:$ring);
   let assemblyFormat = "`<` $ring `>`";
 }
@@ -153,29 +165,107 @@ class Polynomial_BinaryOp<string mnemonic, list<Trait> traits = []> :
 
 def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
   let summary = "Addition operation between polynomials.";
+  let description = [{
+    Performs polynomial addition on the operands. The operands may be single
+    polynomials or containers of identically-typed polynomials, i.e., polynomials
+    from the same underlying ring with the same coefficient types.
+
+    Addition is defined to occur in the ring defined by the ring attribute of
+    the two operands, meaning the addition is taken modulo the coefficientModulus
+    and the polynomialModulus of the ring.
+
+    Example:
+
+    ```mlir
+    // add two polynomials modulo x^1024 - 1
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring>
+    ```
+  }];
 }
 
 def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> {
   let summary = "Subtraction operation between polynomials.";
+  let description = [{
+    Performs polynomial subtraction on the operands. The operands may be single
+    polynomials or containers of identically-typed polynomials, i.e., polynomials
+    from the same underlying ring with the same coefficient types.
+
+    Subtraction is defined to occur in the ring defined by the ring attribute of
+    the two operands, meaning the subtraction is taken modulo the coefficientModulus
+    and the polynomialModulus of the ring.
+
+    Example:
+
+    ```mlir
+    // subtract two polynomials modulo x^1024 - 1
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %2 = polynomial.sub %0, %1 : !polynomial.polynomial<#ring>
+    ```
+  }];
 }
 
 def Polynomial_MulOp : Polynomial_BinaryOp<"mul", [Commutative]> {
   let summary = "Multiplication operation between polynomials.";
+  let description = [{
+    Performs polynomial multiplication on the operands. The operands may be single
+    polynomials or containers of identically-typed polynomials, i.e., polynomials
+    from the same underlying ring with the same coefficient types.
+
+    Multiplication is defined to occur in the ring defined by the ring attribute of
+    the two operands, meaning the multiplication is taken modulo the coefficientModulus
+    and the polynomialModulus of the ring.
+
+    Example:
+
+    ```mlir
+    // multiply two polynomials modulo x^1024 - 1
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = polynomial.constant #polynomial.polynomial<x**5 - x + 1> : !polynomial.polynomial<#ring>
+    %2 = polynomial.mul %0, %1 : !polynomial.polynomial<#ring>
+    ```
+  }];
 }
 
 def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
       ElementwiseMappable, AllTypesMatch<["polynomial", "output"]>]> {
   let summary = "Multiplication by a scalar of the field.";
+  let description = [{
+    Multiplies the polynomial operand's coefficients by a given scalar value.
+    The operation is defined to occur in the ring defined by the ring attribute
+    of the polynomial operand, meaning the multiplication is taken modulo the
+    coefficientModulus of the ring.
+
+    The `scalar` input must have the same type as the polynomial ring's
+    coefficientType.
+
+    Example:
+
+    ```mlir
+    // multiply a polynomial by a scalar, in a ring modulo x^1024 - 1
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1 = arith.constant 3 : i32
+    %2 = polynomial.mul_scalar %0, %1 : (!polynomial.polynomial<#ring>, i32) -> !polynomial.polynomial<#ring>
+    ```
+  }];
 
   let arguments = (ins
     PolynomialLike:$polynomial,
     AnyInteger:$scalar
   );
-
   let results = (outs
     PolynomialLike:$output
   );
-
   let assemblyFormat = "operands attr-dict `:` qualified(type($polynomial)) `,` type($scalar)";
 }
 
@@ -183,8 +273,19 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
   let summary = "Compute the leading term of the polynomial.";
   let description = [{
     The degree of a polynomial is the largest $k$ for which the coefficient
-    $a_k$ of $x^k$ is nonzero. The leading term is the term $a_k x^k$, which
-    this op represents as a pair of results.
+    `a_k` of `x^k` is nonzero. The leading term is the term `a_k * x^k`, which
+    this op represents as a pair of results. The first is the degree `k` as an
+    index, and the second is the coefficient, whose type matches the
+    coefficient type of the polynomial's ring attribute.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    %1, %2 = polynomial.leading_term %0 : !polynomial.polynomial<#ring> -> (index, i32)
+    ```
   }];
   let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs Index:$degree, AnyInteger:$coefficient);
@@ -193,17 +294,38 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
 
 def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
   let summary = "Create a polynomial that consists of a single monomial.";
+  let description = [{
+    Construct a polynomial that consists of a single monomial term, from its
+    degree and coefficient as dynamic inputs.
+
+    The coefficient type of the output polynomial's ring attribute must match
+    the `coefficient` input type.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %deg = arith.constant 1023 : index
+    %five = arith.constant 5 : i32
+    %0 = polynomial.monomial %five, %deg : (i32, index) -> !polynomial.polynomial<#ring>
+    ```
+  }];
   let arguments = (ins AnyInteger:$coefficient, Index:$degree);
   let results = (outs Polynomial_PolynomialType:$output);
 }
 
-def Polynomial_MonomialMulOp: Polynomial_Op<"monomial_mul", [AllTypesMatch<["input", "output"]>]> {
+def Polynomial_MonomialMulOp: Polynomial_Op<"monic_monomial_mul", [AllTypesMatch<["input", "output"]>]> {
   let summary = "Multiply a polynomial by a monic monomial.";
   let description = [{
-    In the ring of polynomials mod $x^n - 1$, `monomial_mul` can be interpreted
-    as a cyclic shift of the coefficients of the polynomial. For some rings,
-    this results in optimized lowerings that involve rotations and rescaling
-    of the coefficients of the input.
+    Multiply a polynomial by a monic monomial, meaning a polynomial of the form
+    `1 * x^k` for an index operand `k`.
+
+    In some special rings of polynomials, such as a ring of polynomials
+    modulo `x^n - 1`, `monic_monomial_mul` can be interpreted as a cyclic shift of
+    the coefficients of the polynomial. For some rings, this results in
+    optimized lowerings that involve rotations and rescaling of the
+    coefficients of the input.
   }];
   let arguments = (ins Polynomial_PolynomialType:$input, Index:$monomialDegree);
   let results = (outs Polynomial_PolynomialType:$output);
@@ -217,8 +339,19 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
     The input tensor must list the coefficients in degree-increasing order.
 
     The input one-dimensional tensor may have size at most the degree of the
-    ring's ideal generator polynomial, with smaller dimension implying that
+    ring's polynomialModulus, with smaller dimension implying that
     all higher-degree terms have coefficient zero.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %two = arith.constant 2 : i32
+    %five = arith.constant 5 : i32
+    %coeffs = tensor.from_elements %two, %two, %five : tensor<3xi32>
+    %poly = polynomial.from_tensor %coeffs : tensor<3xi32> -> !polynomial.polynomial<#ring>
+    ```
   }];
   let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
   let results = (outs Polynomial_PolynomialType:$output);
@@ -236,17 +369,29 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
 def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
   let summary = "Creates a tensor containing the coefficients of a polynomial.";
   let description = [{
-    `polynomial.to_tensor` creates a tensor value containing the coefficients of the
-    input polynomial. The output tensor contains the coefficients in
-    degree-increasing order.
+    `polynomial.to_tensor` creates a dense tensor value containing the
+    coefficients of the input polynomial. The output tensor contains the
+    coefficients in degree-increasing order.
 
     Operations that act on the coefficients of a polynomial, such as extracting
     a specific coefficient or extracting a range of coefficients, should be
     implemented by composing `to_tensor` with the relevant `tensor` dialect
     ops.
 
-    The output tensor has shape equal to the degree of the ring's ideal
-    generator polynomial, including zeroes.
+    The output tensor has shape equal to the degree of the polynomial ring
+    attribute's polynomialModulus, including zeroes.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %two = arith.constant 2 : i32
+    %five = arith.constant 5 : i32
+    %coeffs = tensor.from_elements %two, %two, %five : tensor<3xi32>
+    %poly = polynomial.from_tensor %coeffs : tensor<3xi32> -> !polynomial.polynomial<#ring>
+    %tensor = polynomial.to_tensor %poly : !polynomial.polynomial<#ring> -> tensor<1024xi32>
+    ```
   }];
   let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
@@ -257,6 +402,15 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
 
 def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
   let summary = "Define a constant polynomial via an attribute.";
+  let description = [{
+    Example:
+
+    ```mlir
+    #poly = #polynomial.polynomial<x**1024 - 1>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
+    %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
+    ```
+  }];
   let arguments = (ins Polynomial_PolynomialAttr:$input);
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($output))";
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 07bb6d32836660..1387b69ab50772 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -85,7 +85,7 @@ LogicalResult ToTensorOp::verify() {
   return diag;
 }
 
-LogicalResult MonomialMulOp::verify() {
+LogicalResult MonicMonomialMulOp::verify() {
   RingAttr ring = getInput().getType().getRing();
   ArrayRef<Monomial> idealTerms =
       ring.getPolynomialModulus().getPolynomial().getTerms();
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index dae14471344ec3..ea1b279fa1ff96 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -67,6 +67,13 @@ module {
     return
   }
 
+  func.func @test_monic_monomial_mul() {
+    %five = arith.constant 5 : index
+    %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
+    %1 = polynomial.monic_monomial_mul %0, %five : (!polynomial.polynomial<#ring1>, index) -> !polynomial.polynomial<#ring1>
+    return
+  }
+
   func.func @test_constant() {
     %0 = polynomial.constant #one_plus_x_squared : !polynomial.polynomial<#ring1>
     %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
>From ae3c4ece366265e6ed5672a994827891289b82c7 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 13:57:02 -0700
Subject: [PATCH 08/16] fix monic_monomial_mul C++ op name and remove verifier
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td    |  7 +++----
 .../lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 17 -----------------
 2 files changed, 3 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 38866d6879b5b8..5f042b0ae1f2b1 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -315,7 +315,7 @@ def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
   let results = (outs Polynomial_PolynomialType:$output);
 }
 
-def Polynomial_MonomialMulOp: Polynomial_Op<"monic_monomial_mul", [AllTypesMatch<["input", "output"]>]> {
+def Polynomial_MonicMonomialMulOp: Polynomial_Op<"monic_monomial_mul", [AllTypesMatch<["input", "output"]>]> {
   let summary = "Multiply a polynomial by a monic monomial.";
   let description = [{
     Multiply a polynomial by a monic monomial, meaning a polynomial of the form
@@ -327,9 +327,8 @@ def Polynomial_MonomialMulOp: Polynomial_Op<"monic_monomial_mul", [AllTypesMatch
     optimized lowerings that involve rotations and rescaling of the
     coefficients of the input.
   }];
-  let arguments = (ins Polynomial_PolynomialType:$input, Index:$monomialDegree);
-  let results = (outs Polynomial_PolynomialType:$output);
-  let hasVerifier = 1;
+  let arguments = (ins PolynomialLike:$input, Index:$monomialDegree);
+  let results = (outs PolynomialLike:$output);
 }
 
 def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1387b69ab50772..e804ed9c999e14 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -84,20 +84,3 @@ LogicalResult ToTensorOp::verify() {
                        "the output type's ring attribute.";
   return diag;
 }
-
-LogicalResult MonicMonomialMulOp::verify() {
-  RingAttr ring = getInput().getType().getRing();
-  ArrayRef<Monomial> idealTerms =
-      ring.getPolynomialModulus().getPolynomial().getTerms();
-  bool compatible =
-      idealTerms.size() == 2 &&
-      (idealTerms[0].coefficient == -1 && idealTerms[0].exponent == 0) &&
-      (idealTerms[1].coefficient == 1);
-
-  if (compatible)
-    return success();
-
-  InFlightDiagnostic diag = emitOpError() << "unsupported ring type: " << ring;
-  diag.attachNote() << "The ring must be of the form (x^n - 1) for some n";
-  return diag;
-}
>From cfbb9b99425ec9e7f5860b250b4b4e2daecf97f7 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 13:57:19 -0700
Subject: [PATCH 09/16] remove unimplemented monomial print method
---
 mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h | 3 ---
 1 file changed, 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
index fa767649f649b6..3325a6fa3f9fcf 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h
@@ -51,9 +51,6 @@ class Monomial {
     return (exponent.ult(other.exponent));
   }
 
-  // Prints polynomial to 'os'.
-  void print(raw_ostream &os) const;
-
   friend ::llvm::hash_code hash_value(const Monomial &arg);
 
 public:
>From 2d0de65f78bd6c6a8db36f5ac90351d485bc24c8 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 13:58:32 -0700
Subject: [PATCH 10/16] add additional test for polynomial type parser
---
 mlir/test/Dialect/Polynomial/types.mlir | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
diff --git a/mlir/test/Dialect/Polynomial/types.mlir b/mlir/test/Dialect/Polynomial/types.mlir
index 64b74d9d36bb1c..00296a36e890f9 100644
--- a/mlir/test/Dialect/Polynomial/types.mlir
+++ b/mlir/test/Dialect/Polynomial/types.mlir
@@ -40,3 +40,17 @@ func.func @test_non_x_variable_64_bit(%0: !ty2) -> !ty2 {
 func.func @test_linear_poly(%0: !ty3) -> !ty3 {
   return %0 : !ty3
 }
+
+// CHECK-LABEL: func @test_negative_leading_1
+// CHECK-SAME:  !polynomial.polynomial<
+// CHECK-SAME:    #polynomial.ring<
+// CHECK-SAME:       coefficientType=i32,
+// CHECK-SAME:       coefficientModulus=2837465 : i32,
+// CHECK-SAME:       polynomialModulus=#polynomial.polynomial<-1 + x**1024>>>
+#my_poly_4 = #polynomial.polynomial<-1 + x**1024>
+#ring4 = #polynomial.ring<coefficientType=i32, coefficientModulus=2837465, polynomialModulus=#my_poly_4>
+!ty4 = !polynomial.polynomial<#ring4>
+func.func @test_negative_leading_1(%0: !ty4) -> !ty4 {
+  return %0 : !ty4
+}
+
>From d92ef4df69457954cfb91522e7a82f92b0211d71 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 14:46:31 -0700
Subject: [PATCH 11/16] add verifier for mul_scalar op
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  |  3 +-
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 20 +++++++++++
 mlir/test/Dialect/Polynomial/ops_errors.mlir  | 33 +++++++++++++------
 3 files changed, 45 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 5f042b0ae1f2b1..81c9abc29c1599 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -255,7 +255,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
     #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=65536, polynomialModulus=#poly>
     %0 = polynomial.constant #polynomial.polynomial<1 + x**2> : !polynomial.polynomial<#ring>
     %1 = arith.constant 3 : i32
-    %2 = polynomial.mul_scalar %0, %1 : (!polynomial.polynomial<#ring>, i32) -> !polynomial.polynomial<#ring>
+    %2 = polynomial.mul_scalar %0, %1 : !polynomial.polynomial<#ring>, i32
     ```
   }];
 
@@ -267,6 +267,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
     PolynomialLike:$output
   );
   let assemblyFormat = "operands attr-dict `:` qualified(type($polynomial)) `,` type($scalar)";
+  let hasVerifier = 1;
 }
 
 def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index e804ed9c999e14..427abefebccbac 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -84,3 +84,23 @@ LogicalResult ToTensorOp::verify() {
                        "the output type's ring attribute.";
   return diag;
 }
+
+LogicalResult MulScalarOp::verify() {
+  Type argType = getPolynomial().getType();
+  PolynomialType polyType;
+
+  if (auto shapedPolyType = dyn_cast<ShapedType>(argType)) {
+    polyType = dyn_cast<PolynomialType>(shapedPolyType.getElementType());
+  } else {
+    polyType = cast<PolynomialType>(argType);
+  }
+
+  Type coefficientType = polyType.getRing().getCoefficientType();
+
+  if (coefficientType != getScalar().getType())
+    return emitOpError() << "polynomial coefficient type " << coefficientType
+                         << " does not match scalar type "
+                         << getScalar().getType();
+
+  return success();
+}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index 635d54bda45b3d..b0cf612e78d8d9 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,14 +1,27 @@
-// RUN: mlir-opt --verify-diagnostics %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s
 
 #my_poly = #polynomial.polynomial<1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
-module {
-  func.func @test_from_tensor_too_large_coeffs() {
-    %two = arith.constant 2 : i32
-    %coeffs1 = tensor.from_elements %two, %two : tensor<2xi32>
-    // expected-error@below {{is too large to fit in the coefficients}}
-    // expected-note@below {{rescaled to fit}}
-    %poly = polynomial.from_tensor %coeffs1 : tensor<2xi32> -> !polynomial.polynomial<#ring>
-    return
-  }
+!ty = !polynomial.polynomial<#ring>
+
+func.func @test_from_tensor_too_large_coeffs() {
+  %two = arith.constant 2 : i32
+  %coeffs1 = tensor.from_elements %two, %two : tensor<2xi32>
+  // expected-error@below {{is too large to fit in the coefficients}}
+  // expected-note@below {{rescaled to fit}}
+  %poly = polynomial.from_tensor %coeffs1 : tensor<2xi32> -> !ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring>
+
+func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
+  %scalar = arith.constant 2 : i32  // should be i16
+  // expected-error@below {{polynomial coefficient type 'i16' does not match scalar type 'i32'}}
+  %poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
+  return %poly : !ty
 }
>From 942d3e569f6ac6153417a04200f189765f841846 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Mon, 22 Apr 2024 14:57:35 -0700
Subject: [PATCH 12/16] remove qualified
---
 .../include/mlir/Dialect/Polynomial/IR/Polynomial.td | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 81c9abc29c1599..a7311d6d68d31e 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -160,7 +160,7 @@ class Polynomial_BinaryOp<string mnemonic, list<Trait> traits = []> :
     Polynomial_Op<mnemonic, !listconcat(traits, [Pure, SameOperandsAndResultType, ElementwiseMappable])> {
   let arguments = (ins PolynomialLike:$lhs, PolynomialLike:$rhs);
   let results = (outs PolynomialLike:$result);
-  let assemblyFormat = "operands attr-dict `:` qualified(type($result))";
+  let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
 def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> {
@@ -266,7 +266,7 @@ def Polynomial_MulScalarOp : Polynomial_Op<"mul_scalar", [
   let results = (outs
     PolynomialLike:$output
   );
-  let assemblyFormat = "operands attr-dict `:` qualified(type($polynomial)) `,` type($scalar)";
+  let assemblyFormat = "operands attr-dict `:` type($polynomial) `,` type($scalar)";
   let hasVerifier = 1;
 }
 
@@ -290,7 +290,7 @@ def Polynomial_LeadingTermOp: Polynomial_Op<"leading_term"> {
   }];
   let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs Index:$degree, AnyInteger:$coefficient);
-  let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `->` `(` type($degree) `,` type($coefficient) `)`";
+  let assemblyFormat = "operands attr-dict `:` type($input) `->` `(` type($degree) `,` type($coefficient) `)`";
 }
 
 def Polynomial_MonomialOp: Polynomial_Op<"monomial"> {
@@ -356,7 +356,7 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
   let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
   let results = (outs Polynomial_PolynomialType:$output);
 
-  let assemblyFormat = "$input attr-dict `:` type($input) `->` qualified(type($output))";
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
 
   let builders = [
     // Builder that infers coefficient modulus from tensor bit width,
@@ -395,7 +395,7 @@ def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
   }];
   let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
-  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
 
   let hasVerifier = 1;
 }
@@ -413,7 +413,7 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
   }];
   let arguments = (ins Polynomial_PolynomialAttr:$input);
   let results = (outs Polynomial_PolynomialType:$output);
-  let assemblyFormat = "$input attr-dict `:` qualified(type($output))";
+  let assemblyFormat = "$input attr-dict `:` type($output)";
 }
 
 #endif // POLYNOMIAL_OPS
>From 9d7a333452265bd61501733082313cac28af3cc9 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <kun.jeremy at gmail.com>
Date: Tue, 23 Apr 2024 09:52:44 -0700
Subject: [PATCH 13/16] an -> a
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
 mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index a7311d6d68d31e..d3e3ac55677f86 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -93,7 +93,7 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     integral, whose coefficients are taken modulo some statically known modulus
     (`coefficientModulus`).
 
-    Additionally, a polynomial ring can specify an _polynomialModulus_, which converts
+    Additionally, a polynomial ring can specify a _polynomialModulus_, which converts
     polynomial arithmetic to the analogue of modular integer arithmetic, where
     each polynomial is represented as its remainder when dividing by the
     modulus. For single-variable polynomials, an "polynomialModulus" is always specificed
>From 17a775d4a035a1ce49bb30d13de9ea99cd23d42e Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 25 Apr 2024 10:11:45 -0700
Subject: [PATCH 14/16] decapitalize and remove period
---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 427abefebccbac..4e2ffae6148a4d 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -39,9 +39,9 @@ LogicalResult FromTensorOp::verify() {
                               << "input type " << getInput().getType()
                               << " does not match output type "
                               << getOutput().getType();
-    diag.attachNote() << "The input type must be a tensor of shape [d] where d "
+    diag.attachNote() << "the input type must be a tensor of shape [d] where d "
                          "is at most the degree of the polynomialModulus of "
-                         "the output type's ring attribute.";
+                         "the output type's ring attribute";
     return diag;
   }
 
@@ -55,8 +55,8 @@ LogicalResult FromTensorOp::verify() {
                               << getInput().getType().getElementType()
                               << " is too large to fit in the coefficients of "
                               << getOutput().getType();
-    diag.attachNote() << "The input tensor's elements must be rescaled"
-                         " to fit before using from_tensor.";
+    diag.attachNote() << "the input tensor's elements must be rescaled"
+                         " to fit before using from_tensor";
     return diag;
   }
 
@@ -79,9 +79,9 @@ LogicalResult ToTensorOp::verify() {
   InFlightDiagnostic diag =
       emitOpError() << "input type " << getInput().getType()
                     << " does not match output type " << getOutput().getType();
-  diag.attachNote() << "The input type must be a tensor of shape [d] where d "
+  diag.attachNote() << "he input type must be a tensor of shape [d] where d "
                        "is at most the degree of the polynomialModulus of "
-                       "the output type's ring attribute.";
+                       "the output type's ring attribute";
   return diag;
 }
 
>From e9ac29d267c782adce6474e36ec510bef9dc7833 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 25 Apr 2024 10:12:34 -0700
Subject: [PATCH 15/16] dyn_cast -> cast
---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 4e2ffae6148a4d..d86e977a273248 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -90,7 +90,7 @@ LogicalResult MulScalarOp::verify() {
   PolynomialType polyType;
 
   if (auto shapedPolyType = dyn_cast<ShapedType>(argType)) {
-    polyType = dyn_cast<PolynomialType>(shapedPolyType.getElementType());
+    polyType = cast<PolynomialType>(shapedPolyType.getElementType());
   } else {
     polyType = cast<PolynomialType>(argType);
   }
>From 937917eab5c8f0cd5c5c97893ce17ca00c224c4f Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 25 Apr 2024 10:45:43 -0700
Subject: [PATCH 16/16] test remaining diagnostics
---
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   |  4 +--
 mlir/test/Dialect/Polynomial/ops_errors.mlir  | 26 +++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index d86e977a273248..8e2bb5f27dc6cc 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -79,9 +79,9 @@ LogicalResult ToTensorOp::verify() {
   InFlightDiagnostic diag =
       emitOpError() << "input type " << getInput().getType()
                     << " does not match output type " << getOutput().getType();
-  diag.attachNote() << "he input type must be a tensor of shape [d] where d "
+  diag.attachNote() << "the output type must be a tensor of shape [d] where d "
                        "is at most the degree of the polynomialModulus of "
-                       "the output type's ring attribute";
+                       "the input type's ring attribute";
   return diag;
 }
 
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index b0cf612e78d8d9..2b04486024d254 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -15,6 +15,32 @@ func.func @test_from_tensor_too_large_coeffs() {
 
 // -----
 
+#my_poly = #polynomial.polynomial<1 + x**4>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring>
+func.func @test_from_tensor_wrong_tensor_type() {
+  %two = arith.constant 2 : i32
+  %coeffs1 = tensor.from_elements %two, %two, %two, %two, %two : tensor<5xi32>
+  // expected-error at below {{input type 'tensor<5xi32>' does not match output type 'polynomial.polynomial<#polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#polynomial.polynomial<1 + x**4>>>'}}
+  // expected-note at below {{at most the degree of the polynomialModulus of the output type's ring attribute}}
+  %poly = polynomial.from_tensor %coeffs1 : tensor<5xi32> -> !ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<1 + x**4>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#my_poly>
+!ty = !polynomial.polynomial<#ring>
+func.func @test_to_tensor_wrong_output_tensor_type(%arg0 : !ty) {
+  // expected-error at below {{input type 'polynomial.polynomial<#polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#polynomial.polynomial<1 + x**4>>>' does not match output type 'tensor<5xi32>'}}
+  // expected-note at below {{at most the degree of the polynomialModulus of the input type's ring attribute}}
+  %tensor = polynomial.to_tensor %arg0 : !ty -> tensor<5xi32>
+  return
+}
+
+// -----
+
 #my_poly = #polynomial.polynomial<1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
 !ty = !polynomial.polynomial<#ring>
    
    
More information about the Mlir-commits
mailing list