[Mlir-commits] [mlir] [mlir][polynomial] Move primitive root attr to ring attr (PR #111931)

Hongren Zheng llvmlistbot at llvm.org
Thu Oct 10 19:06:22 PDT 2024


https://github.com/ZenithalHourlyRate created https://github.com/llvm/llvm-project/pull/111931

Related to https://github.com/llvm/llvm-project/pull/93227
and https://github.com/google/heir/issues/993

When ntt/intt ops are emitted as the result of a pattern rewrite,
the primitive root attr must be provided somehow, and it is
convenient for it to be carried by the ring attr.

To use a different primitive root for the same polynomial,
to_tensor/tensor.cast/from_tensor should be enough to change
the primitiveRoot attribute in RingAttr.

Cc @j2kun

>From 7e6e545d3b9c69d9c3bbd2eeb0d107b087a85c4b Mon Sep 17 00:00:00 2001
From: Zenithal <i at zenithal.me>
Date: Fri, 11 Oct 2024 02:01:45 +0000
Subject: [PATCH] [mlir][polynomial] Move primitive root attr to ring attr

Related to https://github.com/llvm/llvm-project/pull/93227
and https://github.com/google/heir/issues/993

When ntt/intt ops are emitted as the result of a pattern rewrite,
the primitive root attr must be provided somehow, and it is
convenient for it to be carried by the ring attr.

To use a different primitive root for the same polynomial,
to_tensor/tensor.cast/from_tensor should be enough to change
the primitiveRoot attribute in RingAttr.
---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 20 +++----
 .../Polynomial/IR/PolynomialAttributes.td     | 56 ++++++++++---------
 .../Polynomial/IR/PolynomialAttributes.cpp    |  3 +-
 .../IR/PolynomialCanonicalization.td          | 12 ++--
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 17 +++---
 .../Dialect/Polynomial/canonicalization.mlir  | 10 ++--
 mlir/test/Dialect/Polynomial/ops.mlir         | 11 ++--
 mlir/test/Dialect/Polynomial/ops_errors.mlir  | 37 +++++++++---
 8 files changed, 94 insertions(+), 72 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 755396c8b90235..63f9ff1def4e19 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -311,12 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
 
       `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
 
-    The choice of primitive root may be optionally specified.
+    The choice of primitive root is specified by the primitiveRoot parameter of RingAttr.
+    Its degree determines the transform performed: an n-th primitive root
+    corresponds to cyclic convolution, and a 2n-th primitive root corresponds
+    to negacyclic convolution.
   }];
-  let arguments = (ins
-    Polynomial_PolynomialType:$input,
-    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
-  );
+  let arguments = (ins Polynomial_PolynomialType:$input);
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
@@ -335,12 +335,12 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
     `polynomial.ntt`). The ring of the polynomial is taken from the required
     encoding attribute of the tensor.
 
-    The choice of primitive root may be optionally specified.
+    The choice of primitive root is specified by the primitiveRoot parameter of RingAttr.
+    Its degree determines the inverse transform performed: an n-th primitive
+    root corresponds to cyclic convolution, and a 2n-th primitive root
+    corresponds to negacyclic convolution.
   }];
-  let arguments = (
-    ins RankedTensorOf<[AnyInteger]>:$input,
-    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
-  );
+  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 7d59add3d37c2b..00c9239fc6369d 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -126,6 +126,26 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
   }];
 }
 
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+  let summary = "an attribute containing an integer and its degree as a root of unity";
+  let description = [{
+    A primitive root attribute stores an integer root `value` and an integer
+    `degree`, corresponding to a primitive root of unity of the given degree in
+    an unspecified ring.
+
+    Example:
+
+    ```mlir
+    #root = #polynomial.primitive_root<value=123 : i32, degree=7 : index>
+    ```
+  }];
+  let parameters = (ins
+    "::mlir::IntegerAttr":$value,
+    "::mlir::IntegerAttr":$degree
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let summary = "an attribute specifying a polynomial ring";
   let description = [{
@@ -142,6 +162,9 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     modulus. For single-variable polynomials, an "polynomialModulus" is always specificed
     via a single polynomial, which we call `polynomialModulus`.
 
+    A _primitiveRoot_ (an n-th or 2n-th primitive root of unity) must be specified
+    for the ntt/intt ops and the mul-to-ntt/intt optimization to work.
+
     An expressive example is polynomials with i32 coefficients, whose
     coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of
     `x**1024 - 1`.
@@ -177,7 +200,8 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
+    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
+    OptionalParameter<"::mlir::polynomial::PrimitiveRootAttr">: $primitiveRoot
   );
   let genVerifyDecl = 1;
   let assemblyFormat = "`<` struct(params) `>`";
@@ -185,38 +209,16 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
     AttrBuilderWithInferredContext<
         (ins "::mlir::Type":$coefficientTy,
               CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
-              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
+              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
+              CArg<"::mlir::polynomial::PrimitiveRootAttr", "nullptr"> :$primitiveRootAttr), [{
       return $_get(
         coefficientTy.getContext(),
         coefficientTy,
         coefficientModulusAttr,
-        polynomialModulusAttr);
+        polynomialModulusAttr,
+        primitiveRootAttr);
     }]>,
   ];
 }
 
-def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
-  let summary = "an attribute containing an integer and its degree as a root of unity";
-  let description = [{
-    A primitive root attribute stores an integer root `value` and an integer
-    `degree`, corresponding to a primitive root of unity of the given degree in
-    an unspecified ring.
-
-    This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
-    to specify the root of unity used in lowering the transform.
-
-    Example:
-
-    ```mlir
-    #poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
-    ```
-  }];
-  let parameters = (ins
-    "::mlir::IntegerAttr":$value,
-    "::mlir::IntegerAttr":$degree
-  );
-  let assemblyFormat = "`<` struct(params) `>`";
-}
-
-
 #endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index cd7789a2e9531c..f3f6afdee9950c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -206,7 +206,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
 LogicalResult
 RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
                  Type coefficientType, IntegerAttr coefficientModulus,
-                 IntPolynomialAttr polynomialModulus) {
+                 IntPolynomialAttr polynomialModulus,
+                 PrimitiveRootAttr primitiveRoot) {
   if (coefficientModulus) {
     auto coeffIntType = llvm::dyn_cast<IntegerType>(coefficientType);
     if (!coeffIntType) {
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index 28c45e6846380c..a26b34e29d561f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -14,8 +14,6 @@ include "mlir/Dialect/Polynomial/IR/Polynomial.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/PatternBase.td"
 
-def Equal : Constraint<CPred<"$0 == $1">>;
-
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -30,15 +28,13 @@ def SubAsAdd : Pat<
       (Arith_ConstantOp (getMinusOne $g))))>;
 
 def INTTAfterNTT : Pat<
-  (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
-  (replaceWithValue $poly),
-  [(Equal $r1, $r2)]
+  (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+  (replaceWithValue $poly)
 >;
 
 def NTTAfterINTT : Pat<
-  (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
-  (replaceWithValue $tensor),
-  [(Equal $r1, $r2)]
+  (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+  (replaceWithValue $tensor)
 >;
 
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 460ef17167e801..30a6a004c50aff 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -134,8 +134,7 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
 /// Verify that the types involved in an NTT or INTT operation are
 /// compatible.
 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
-                                 RankedTensorType tensorType,
-                                 std::optional<PrimitiveRootAttr> root) {
+                                 RankedTensorType tensorType) {
   Attribute encoding = tensorType.getEncoding();
   if (!encoding) {
     return op->emitOpError()
@@ -166,9 +165,10 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     return diag;
   }
 
-  if (root.has_value()) {
-    APInt rootValue = root.value().getValue().getValue();
-    APInt rootDegree = root.value().getDegree().getValue();
+  auto root = ring.getPrimitiveRoot();
+  if (root) {
+    APInt rootValue = root.getValue().getValue();
+    APInt rootDegree = root.getDegree().getValue();
     APInt cmod = ring.getCoefficientModulus().getValue();
     if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
       return op->emitOpError()
@@ -177,6 +177,9 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
              << "of unity mod " << cmod.getZExtValue()
              << ", with the specified degree " << rootDegree.getZExtValue();
     }
+  } else {
+    return op->emitOpError()
+           << "primitive root not provided but ntt/intt op called";
   }
 
   return success();
@@ -184,12 +187,12 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
 
 LogicalResult NTTOp::verify() {
   return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
-                     getOutput().getType(), getRoot());
+                     getOutput().getType());
 }
 
 LogicalResult INTTOp::verify() {
   return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
-                     getInput().getType(), getRoot());
+                     getInput().getType());
 }
 
 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index c0ee514daab645..5a517a5e1ed9b4 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -canonicalize %s | FileCheck %s
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 #root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=#root>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 !tensor_ty = tensor<8xi32, #ntt_ring>
 
@@ -11,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
   // CHECK-NOT: polynomial.ntt
   // CHECK-NOT: polynomial.intt
   // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]]  : [[T]]
-  %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
-  %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
   %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %p2 : !ntt_poly_ty
@@ -24,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
   // CHECK-NOT: polynomial.intt
   // CHECK-NOT: polynomial.ntt
   // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
-  %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
-  %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
   %t2 = arith.addi %t1, %t1 : !tensor_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %t2 : !tensor_ty
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index faeb68a8b2c093..4998730c80c7ea 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -15,12 +15,13 @@
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#ntt_ring_root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=#ntt_ring_root>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
 #ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
-#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
 #ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
+#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2, primitiveRoot=#ntt_ring_2_root>
 !ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>
 
 module {
@@ -96,17 +97,17 @@ module {
   }
 
   func.func @test_ntt(%0 : !ntt_poly_ty) {
-    %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
     return
   }
 
   func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
-    %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
+    %1 = polynomial.ntt %0 : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
     return
   }
 
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
-    %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
     return
   }
 }
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index 4937e17027afaa..003967e3f4228c 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -55,36 +55,39 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error at below {{expects a ring encoding to be provided to the tensor}}
-  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error at below {{tensor encoding is not a ring attribute}}
-  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -98,7 +101,8 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -106,7 +110,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
   // expected-error at below {{does not match output type}}
   // expected-note at below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
-  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
+  %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
   return
 }
 
@@ -114,13 +118,28 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**8>
 // A valid root is 31
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+#root = #polynomial.primitive_root<value=32:i32, degree=8:index>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
 func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
   // expected-error at below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}}
-  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=32:i16, degree=8:index>} : tensor<8xi32, #ring> -> !poly_ty
+  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.int_polynomial<-1 + x**8>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<ring=#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
+  // expected-error at below {{primitive root not provided but ntt/intt op called}}
+  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
   return
 }



More information about the Mlir-commits mailing list