[Mlir-commits] [mlir] [polynomial] Move primitive root attribute to ntt/intt ops. (PR #93227)

Jeremy Kun llvmlistbot at llvm.org
Thu May 23 11:45:31 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/93227

>From 0a7431a678d8bf7ef5ed6c729ff9e1e4f2d60f69 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:20:49 -0700
Subject: [PATCH 1/3] move primitive root attr to ntt/intt ops

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 17 +++++---
 .../Polynomial/IR/PolynomialAttributes.td     | 33 ++++++++++++---
 .../IR/PolynomialCanonicalization.td          | 42 ++++++++++---------
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 40 +++++++++---------
 4 files changed, 80 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index f99cbccd243ec..7c5365af5e06f 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -277,8 +277,7 @@ def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
   Polynomial_TypedIntPolynomialAttr
 ]>;
 
-// Not deriving from Polynomial_Op due to need for custom assembly format
-def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
+def Polynomial_ConstantOp : Polynomial_Op<"constant",
     [Pure, InferTypeOpAdaptor]> {
   let summary = "Define a constant polynomial via an attribute.";
   let description = [{
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
 
       `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
 
-    The choice of primitive root is determined by subsequent lowerings.
+    The choice of primitive root may be optionally specified.
   }];
-  let arguments = (ins Polynomial_PolynomialType:$input);
+  let arguments = (ins
+    Polynomial_PolynomialType:$input,
+    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+  );
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
     output polynomial at powers of a primitive `n`-th root of unity (see
     `polynomial.ntt`). The ring of the polynomial is taken from the required
     encoding attribute of the tensor.
+
+    The choice of primitive root may be optionally specified.
   }];
-  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let arguments = (
+    ins RankedTensorOf<[AnyInteger]>:$input,
+    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+  );
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 655020adf808b..2d3ed60a35fd9 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -166,24 +166,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
-    OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
+    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
   );
   let assemblyFormat = "`<` struct(params) `>`";
   let builders = [
     AttrBuilderWithInferredContext<
         (ins "::mlir::Type":$coefficientTy,
               CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
-              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
-              CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
       return $_get(
         coefficientTy.getContext(),
         coefficientTy,
         coefficientModulusAttr,
-        polynomialModulusAttr,
-        primitiveRootAttr);
+        polynomialModulusAttr);
     }]>,
   ];
 }
 
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+  let summary = "an attribute containing an integer and its degree as a root of unity";
+  let description = [{
+    A primitive root attribute stores an integer root `value` and an integer
+    `degree`, corresponding to a primitive root of unity of the given degree in
+    an unspecified ring.
+
+    This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
+    to specify the root of unity used in lowering the transform.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.primitive_root<value=123 : i32, degree=7 : index>
+    ```
+  }];
+  let parameters = (ins
+    "::mlir::IntegerAttr":$value,
+    "::mlir::IntegerAttr":$degree
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
 #endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index e37bcf76a20f2..93ea6e4e43698 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -17,6 +17,8 @@ include "mlir/IR/PatternBase.td"
 
 defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
+def Equal : Constraint<CPred<"$0 == $1">>;
+
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -31,15 +33,15 @@ def SubAsAdd : Pat<
       (Arith_ConstantOp (getMinusOne $g))))>;
 
 def INTTAfterNTT : Pat<
-  (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+  (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
   (replaceWithValue $poly),
-  []
+  [(Equal $r1, $r2)]
 >;
 
 def NTTAfterINTT : Pat<
-  (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+  (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
   (replaceWithValue $tensor),
-  []
+  [(Equal $r1, $r2)]
 >;
 
 // NTTs are expensive, and addition in coefficient or NTT domain should be
@@ -47,35 +49,35 @@ def NTTAfterINTT : Pat<
 // ntt(a) + ntt(b) -> ntt(a + b)
 def NTTOfAdd : Pat<
   (Arith_AddIOp
-    (Polynomial_NTTOp $p1),
-    (Polynomial_NTTOp $p2),
+    (Polynomial_NTTOp $p1, $r1),
+    (Polynomial_NTTOp $p2, $r2),
     $overflow),
-  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
-  []
+  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
+  [(Equal $r1, $r2)]
 >;
 // intt(a) + intt(b) -> intt(a + b)
 def INTTOfAdd : Pat<
   (Polynomial_AddOp
-    (Polynomial_INTTOp $t1),
-    (Polynomial_INTTOp $t2)),
-  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
-  []
+    (Polynomial_INTTOp $t1, $r1),
+    (Polynomial_INTTOp $t2, $r2)),
+  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
+  [(Equal $r1, $r2)]
 >;
 // repeated for sub
 def NTTOfSub : Pat<
   (Arith_SubIOp
-    (Polynomial_NTTOp $p1),
-    (Polynomial_NTTOp $p2),
+    (Polynomial_NTTOp $p1, $r1),
+    (Polynomial_NTTOp $p2, $r2),
     $overflow),
-  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
-  []
+  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
+  [(Equal $r1, $r2)]
 >;
 def INTTOfSub : Pat<
   (Polynomial_SubOp
-    (Polynomial_INTTOp $t1),
-    (Polynomial_INTTOp $t2)),
-  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
-  []
+    (Polynomial_INTTOp $t1, $r1),
+    (Polynomial_INTTOp $t2, $r2)),
+  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
+  [(Equal $r1, $r2)]
 >;
 
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3d302797ce513..a39c3872ad6d5 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() {
 }
 
 /// Test if a value is a primitive nth root of unity modulo cmod.
-bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
                                const APInt &cmod) {
   // Root bitwidth may be 1 less then cmod.
   APInt r = APInt(root).zext(cmod.getBitWidth());
   assert(r.ule(cmod) && "root must be less than cmod");
+  unsigned upperBound = n.getZExtValue();
 
   APInt a = r;
-  for (size_t k = 1; k < n; k++) {
+  for (size_t k = 1; k < upperBound; k++) {
     if (a.isOne())
       return false;
     a = (a * r).urem(cmod);
@@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
 /// Verify that the types involved in an NTT or INTT operation are
 /// compatible.
 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
-                                 RankedTensorType tensorType) {
+                                 RankedTensorType tensorType,
+                                 std::optional<PrimitiveRootAttr> root) {
   Attribute encoding = tensorType.getEncoding();
   if (!encoding) {
     return op->emitOpError()
@@ -157,33 +159,29 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     return diag;
   }
 
-  if (!ring.getPrimitiveRoot()) {
-    return op->emitOpError()
-           << "ring type " << ring << " does not provide a primitive root "
-           << "of unity, which is required to express an NTT";
-  }
-
-  if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
-                                 ring.getCoefficientModulus().getValue())) {
-    return op->emitOpError()
-           << "ring type " << ring << " has a primitiveRoot attribute '"
-           << ring.getPrimitiveRoot()
-           << "' that is not a primitive root of the coefficient ring";
+  if (root.has_value()) {
+    APInt rootValue = root.value().getValue().getValue();
+    APInt rootDegree = root.value().getDegree().getValue();
+    APInt cmod = ring.getCoefficientModulus().getValue();
+    if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
+      return op->emitOpError()
+             << "provided root " << rootValue.getZExtValue() << " is not a primitive root "
+             << "of unity mod " << cmod.getZExtValue() << ", with the specified degree "
+             << rootDegree.getZExtValue();
+    }
   }
 
   return success();
 }
 
 LogicalResult NTTOp::verify() {
-  auto ring = getInput().getType().getRing();
-  auto tensorType = getOutput().getType();
-  return verifyNTTOp(this->getOperation(), ring, tensorType);
+  return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
+                     getOutput().getType(), getRoot());
 }
 
 LogicalResult INTTOp::verify() {
-  auto tensorType = getInput().getType();
-  auto ring = getOutput().getType().getRing();
-  return verifyNTTOp(this->getOperation(), ring, tensorType);
+  return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
+                     getInput().getType(), getRoot());
 }
 
 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {

>From cd7eedb41343a5c4afff1ccfd473868e0cc3b829 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:29:56 -0700
Subject: [PATCH 2/3] update tests

---
 .../Dialect/Polynomial/canonicalization.mlir  | 35 +++++++++---------
 mlir/test/Dialect/Polynomial/ops.mlir         |  8 ++---
 mlir/test/Dialect/Polynomial/ops_errors.mlir  | 36 ++++++-------------
 3 files changed, 33 insertions(+), 46 deletions(-)

diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 489d9ec2720d6..354b76e3d9669 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt -canonicalize %s | FileCheck %s
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 !tensor_ty = tensor<8xi32, #ntt_ring>
 
@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
   // CHECK-NOT: polynomial.ntt
   // CHECK-NOT: polynomial.intt
   // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]]  : [[T]]
-  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
-  %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+  %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
   // CHECK-NOT: polynomial.intt
   // CHECK-NOT: polynomial.ntt
   // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
-  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
-  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %t2 = arith.addi %t1, %t1 : !tensor_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %t2 : !tensor_ty
@@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
 func.func @test_canonicalize_fold_add_through_ntt(
     %poly0 : !ntt_poly_ty,
     %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %a_plus_b = arith.addi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
   return %out : !ntt_poly_ty
 }
 
@@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
 func.func @test_canonicalize_fold_add_through_intt(
     %tensor0 : !tensor_ty,
     %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
   return %out : !tensor_ty
 }
 
@@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
 func.func @test_canonicalize_fold_sub_through_ntt(
     %poly0 : !ntt_poly_ty,
     %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %a_plus_b = arith.subi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
   return %out : !ntt_poly_ty
 }
 
@@ -94,9 +95,9 @@ func.func @test_canonicalize_fold_sub_through_ntt(
 func.func @test_canonicalize_fold_sub_through_intt(
     %tensor0 : !tensor_ty,
     %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
   return %out : !tensor_ty
 }
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 4716e37ff8852..8c134ab789d60 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -11,11 +11,11 @@
 #one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
 
 #ideal = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
 module {
@@ -91,12 +91,12 @@ module {
   }
 
   func.func @test_ntt(%0 : !ntt_poly_ty) {
-    %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
     return
   }
 
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
-    %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
     return
   }
 }
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index af8e4aa5da862..f22b14897e98a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -55,28 +55,28 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error at below {{expects a ring encoding to be provided to the tensor}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error at below {{tensor encoding is not a ring attribute}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
   return
 }
 
@@ -84,21 +84,21 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
 func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
   // expected-error at below {{not equivalent to the polynomial ring}}
-  %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1024xi32, #ring1> -> !poly_ty
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -106,21 +106,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
   // expected-error at below {{does not match output type}}
   // expected-note at below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
-  %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
-  return
-}
-
-// -----
-
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-!poly_ty = !polynomial.polynomial<ring=#ring>
-
-// CHECK-NOT: @test_invalid_ntt
-// CHECK-NOT: polynomial.ntt
-func.func @test_invalid_ntt(%0 : !poly_ty) {
-  // expected-error at below {{does not provide a primitive root of unity, which is required to express an NTT}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
   return
 }
 
@@ -128,13 +114,13 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**8>
 // A valid root is 31
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=32:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
 func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
-  // expected-error at below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
-  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+  // expected-error at below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}}
+  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=32:i16, degree=8:index>} : tensor<8xi32, #ring> -> !poly_ty
   return
 }

>From 9f289ea479641ae08f709ae0bccafed9b14cfbd5 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:42:39 -0700
Subject: [PATCH 3/3] add test for root equality

---
 mlir/test/Dialect/Polynomial/canonicalization.mlir | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 354b76e3d9669..b79938627e415 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -101,3 +101,17 @@ func.func @test_canonicalize_fold_sub_through_intt(
   %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
   return %out : !tensor_ty
 }
+
+
+// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
+// CHECK: arith.addi
+func.func @test_canonicalize_do_not_fold_different_roots(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+  %a_plus_b = arith.addi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+



More information about the Mlir-commits mailing list