[Mlir-commits] [mlir] [mlir][polynomial] ensure primitive root calculation doesn't overflow (PR #93368)

Jeremy Kun llvmlistbot at llvm.org
Tue May 28 09:23:41 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/93368

>From 168b8d83753e0d2ec78e3be7aaa68df11e7a8ac5 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:20:49 -0700
Subject: [PATCH 1/6] move primitive root attr to ntt/intt ops

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 15 +++++--
 .../Polynomial/IR/PolynomialAttributes.td     | 33 ++++++++++++---
 .../IR/PolynomialCanonicalization.td          | 42 ++++++++++---------
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 40 +++++++++---------
 4 files changed, 79 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index f99cbccd243ec..755396c8b9023 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -277,7 +277,6 @@ def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
   Polynomial_TypedIntPolynomialAttr
 ]>;
 
-// Not deriving from Polynomial_Op due to need for custom assembly format
 def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
     [Pure, InferTypeOpAdaptor]> {
   let summary = "Define a constant polynomial via an attribute.";
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
 
       `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
 
-    The choice of primitive root is determined by subsequent lowerings.
+    The choice of primitive root may be optionally specified.
   }];
-  let arguments = (ins Polynomial_PolynomialType:$input);
+  let arguments = (ins
+    Polynomial_PolynomialType:$input,
+    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+  );
   let results = (outs RankedTensorOf<[AnyInteger]>:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
     output polynomial at powers of a primitive `n`-th root of unity (see
     `polynomial.ntt`). The ring of the polynomial is taken from the required
     encoding attribute of the tensor.
+
+    The choice of primitive root may be optionally specified.
   }];
-  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let arguments = (
+    ins RankedTensorOf<[AnyInteger]>:$input,
+    OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+  );
   let results = (outs Polynomial_PolynomialType:$output);
   let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 655020adf808b..2d3ed60a35fd9 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -166,24 +166,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
-    OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
+    OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
   );
   let assemblyFormat = "`<` struct(params) `>`";
   let builders = [
     AttrBuilderWithInferredContext<
         (ins "::mlir::Type":$coefficientTy,
               CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
-              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
-              CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+              CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
       return $_get(
         coefficientTy.getContext(),
         coefficientTy,
         coefficientModulusAttr,
-        polynomialModulusAttr,
-        primitiveRootAttr);
+        polynomialModulusAttr);
     }]>,
   ];
 }
 
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+  let summary = "an attribute containing an integer and its degree as a root of unity";
+  let description = [{
+    A primitive root attribute stores an integer root `value` and an integer
+    `degree`, corresponding to a primitive root of unity of the given degree in
+    an unspecified ring.
+
+    This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
+    to specify the root of unity used in lowering the transform.
+
+    Example:
+
+    ```mlir
+    #root = #polynomial.primitive_root<value=123 : i32, degree=7 : index>
+    ```
+  }];
+  let parameters = (ins
+    "::mlir::IntegerAttr":$value,
+    "::mlir::IntegerAttr":$degree
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
 #endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index e37bcf76a20f2..93ea6e4e43698 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -17,6 +17,8 @@ include "mlir/IR/PatternBase.td"
 
 defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
+def Equal : Constraint<CPred<"$0 == $1">>;
+
 // Get a -1 integer attribute of the same type as the polynomial SSA value's
 // ring coefficient type.
 def getMinusOne
@@ -31,15 +33,15 @@ def SubAsAdd : Pat<
       (Arith_ConstantOp (getMinusOne $g))))>;
 
 def INTTAfterNTT : Pat<
-  (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+  (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
   (replaceWithValue $poly),
-  []
+  [(Equal $r1, $r2)]
 >;
 
 def NTTAfterINTT : Pat<
-  (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+  (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
   (replaceWithValue $tensor),
-  []
+  [(Equal $r1, $r2)]
 >;
 
 // NTTs are expensive, and addition in coefficient or NTT domain should be
@@ -47,35 +49,35 @@ def NTTAfterINTT : Pat<
 // ntt(a) + ntt(b) -> ntt(a + b)
 def NTTOfAdd : Pat<
   (Arith_AddIOp
-    (Polynomial_NTTOp $p1),
-    (Polynomial_NTTOp $p2),
+    (Polynomial_NTTOp $p1, $r1),
+    (Polynomial_NTTOp $p2, $r2),
     $overflow),
-  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
-  []
+  (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
+  [(Equal $r1, $r2)]
 >;
 // intt(a) + intt(b) -> intt(a + b)
 def INTTOfAdd : Pat<
   (Polynomial_AddOp
-    (Polynomial_INTTOp $t1),
-    (Polynomial_INTTOp $t2)),
-  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
-  []
+    (Polynomial_INTTOp $t1, $r1),
+    (Polynomial_INTTOp $t2, $r2)),
+  (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
+  [(Equal $r1, $r2)]
 >;
 // repeated for sub
 def NTTOfSub : Pat<
   (Arith_SubIOp
-    (Polynomial_NTTOp $p1),
-    (Polynomial_NTTOp $p2),
+    (Polynomial_NTTOp $p1, $r1),
+    (Polynomial_NTTOp $p2, $r2),
     $overflow),
-  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
-  []
+  (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
+  [(Equal $r1, $r2)]
 >;
 def INTTOfSub : Pat<
   (Polynomial_SubOp
-    (Polynomial_INTTOp $t1),
-    (Polynomial_INTTOp $t2)),
-  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
-  []
+    (Polynomial_INTTOp $t1, $r1),
+    (Polynomial_INTTOp $t2, $r2)),
+  (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
+  [(Equal $r1, $r2)]
 >;
 
 #endif  // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3d302797ce513..a39c3872ad6d5 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() {
 }
 
 /// Test if a value is a primitive nth root of unity modulo cmod.
-bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
                                const APInt &cmod) {
   // Root bitwidth may be 1 less then cmod.
   APInt r = APInt(root).zext(cmod.getBitWidth());
   assert(r.ule(cmod) && "root must be less than cmod");
+  unsigned upperBound = n.getZExtValue();
 
   APInt a = r;
-  for (size_t k = 1; k < n; k++) {
+  for (size_t k = 1; k < upperBound; k++) {
     if (a.isOne())
       return false;
     a = (a * r).urem(cmod);
@@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
 /// Verify that the types involved in an NTT or INTT operation are
 /// compatible.
 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
-                                 RankedTensorType tensorType) {
+                                 RankedTensorType tensorType,
+                                 std::optional<PrimitiveRootAttr> root) {
   Attribute encoding = tensorType.getEncoding();
   if (!encoding) {
     return op->emitOpError()
@@ -157,33 +159,29 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     return diag;
   }
 
-  if (!ring.getPrimitiveRoot()) {
-    return op->emitOpError()
-           << "ring type " << ring << " does not provide a primitive root "
-           << "of unity, which is required to express an NTT";
-  }
-
-  if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
-                                 ring.getCoefficientModulus().getValue())) {
-    return op->emitOpError()
-           << "ring type " << ring << " has a primitiveRoot attribute '"
-           << ring.getPrimitiveRoot()
-           << "' that is not a primitive root of the coefficient ring";
+  if (root.has_value()) {
+    APInt rootValue = root.value().getValue().getValue();
+    APInt rootDegree = root.value().getDegree().getValue();
+    APInt cmod = ring.getCoefficientModulus().getValue();
+    if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
+      return op->emitOpError()
+             << "provided root " << rootValue.getZExtValue() << " is not a primitive root "
+             << "of unity mod " << cmod.getZExtValue() << ", with the specified degree "
+             << rootDegree.getZExtValue();
+    }
   }
 
   return success();
 }
 
 LogicalResult NTTOp::verify() {
-  auto ring = getInput().getType().getRing();
-  auto tensorType = getOutput().getType();
-  return verifyNTTOp(this->getOperation(), ring, tensorType);
+  return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
+                     getOutput().getType(), getRoot());
 }
 
 LogicalResult INTTOp::verify() {
-  auto tensorType = getInput().getType();
-  auto ring = getOutput().getType().getRing();
-  return verifyNTTOp(this->getOperation(), ring, tensorType);
+  return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
+                     getInput().getType(), getRoot());
 }
 
 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {

>From 0e368428c0d72007047cbc4f1c17e7d3f97e4e65 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:29:56 -0700
Subject: [PATCH 2/6] update tests

---
 .../Dialect/Polynomial/canonicalization.mlir  | 35 +++++++++---------
 mlir/test/Dialect/Polynomial/ops.mlir         |  8 ++---
 mlir/test/Dialect/Polynomial/ops_errors.mlir  | 36 ++++++-------------
 3 files changed, 33 insertions(+), 46 deletions(-)

diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 489d9ec2720d6..354b76e3d9669 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt -canonicalize %s | FileCheck %s
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 !tensor_ty = tensor<8xi32, #ntt_ring>
 
@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
   // CHECK-NOT: polynomial.ntt
   // CHECK-NOT: polynomial.intt
   // CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]]  : [[T]]
-  %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
-  %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+  %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
   // CHECK-NOT: polynomial.intt
   // CHECK-NOT: polynomial.ntt
   // CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
-  %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
-  %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+  %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %t2 = arith.addi %t1, %t1 : !tensor_ty
   // CHECK: return %[[RESULT]] : [[T]]
   return %t2 : !tensor_ty
@@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
 func.func @test_canonicalize_fold_add_through_ntt(
     %poly0 : !ntt_poly_ty,
     %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %a_plus_b = arith.addi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
   return %out : !ntt_poly_ty
 }
 
@@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
 func.func @test_canonicalize_fold_add_through_intt(
     %tensor0 : !tensor_ty,
     %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
   return %out : !tensor_ty
 }
 
@@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
 func.func @test_canonicalize_fold_sub_through_ntt(
     %poly0 : !ntt_poly_ty,
     %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
-  %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
-  %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+  %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
   %a_plus_b = arith.subi %0, %1 : !tensor_ty
-  %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
   return %out : !ntt_poly_ty
 }
 
@@ -94,9 +95,9 @@ func.func @test_canonicalize_fold_sub_through_ntt(
 func.func @test_canonicalize_fold_sub_through_intt(
     %tensor0 : !tensor_ty,
     %tensor1 : !tensor_ty) -> !tensor_ty {
-  %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
-  %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+  %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+  %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
   %a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
-  %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+  %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
   return %out : !tensor_ty
 }
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 4716e37ff8852..8c134ab789d60 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -11,11 +11,11 @@
 #one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
 
 #ideal = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 #ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
 module {
@@ -91,12 +91,12 @@ module {
   }
 
   func.func @test_ntt(%0 : !ntt_poly_ty) {
-    %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
     return
   }
 
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
-    %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
     return
   }
 }
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index af8e4aa5da862..f22b14897e98a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -55,28 +55,28 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error at below {{expects a ring encoding to be provided to the tensor}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_ntt
 // CHECK-NOT: polynomial.ntt
 func.func @test_invalid_ntt(%0 : !poly_ty) {
   // expected-error at below {{tensor encoding is not a ring attribute}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+  %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
   return
 }
 
@@ -84,21 +84,21 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
 #ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
 func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
   // expected-error at below {{not equivalent to the polynomial ring}}
-  %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1024xi32, #ring1> -> !poly_ty
   return
 }
 
 // -----
 
 #my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
@@ -106,21 +106,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
 func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
   // expected-error at below {{does not match output type}}
   // expected-note at below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
-  %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
-  return
-}
-
-// -----
-
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-!poly_ty = !polynomial.polynomial<ring=#ring>
-
-// CHECK-NOT: @test_invalid_ntt
-// CHECK-NOT: polynomial.ntt
-func.func @test_invalid_ntt(%0 : !poly_ty) {
-  // expected-error at below {{does not provide a primitive root of unity, which is required to express an NTT}}
-  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
   return
 }
 
@@ -128,13 +114,13 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
 
 #my_poly = #polynomial.int_polynomial<-1 + x**8>
 // A valid root is 31
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=32:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
 !poly_ty = !polynomial.polynomial<ring=#ring>
 
 // CHECK-NOT: @test_invalid_intt
 // CHECK-NOT: polynomial.intt
 func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
-  // expected-error at below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
-  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+  // expected-error at below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}}
+  %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=32:i16, degree=8:index>} : tensor<8xi32, #ring> -> !poly_ty
   return
 }

>From 3ea18c40f76820315727aaa8743a3a84504bf956 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:42:39 -0700
Subject: [PATCH 3/6] add test for root equality

---
 mlir/test/Dialect/Polynomial/canonicalization.mlir | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 354b76e3d9669..b79938627e415 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -101,3 +101,17 @@ func.func @test_canonicalize_fold_sub_through_intt(
   %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
   return %out : !tensor_ty
 }
+
+
+// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
+// CHECK: arith.addi
+func.func @test_canonicalize_do_not_fold_different_roots(
+    %poly0 : !ntt_poly_ty,
+    %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+  %0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+  %1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+  %a_plus_b = arith.addi %0, %1 : !tensor_ty
+  %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+  return %out : !ntt_poly_ty
+}
+

>From 5d4619e5a088f3cf6e675b17d41bf162dcc8252d Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 12:33:55 -0700
Subject: [PATCH 4/6] clang format

---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index a39c3872ad6d5..76e837404c6b5 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -165,9 +165,10 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     APInt cmod = ring.getCoefficientModulus().getValue();
     if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
       return op->emitOpError()
-             << "provided root " << rootValue.getZExtValue() << " is not a primitive root "
-             << "of unity mod " << cmod.getZExtValue() << ", with the specified degree "
-             << rootDegree.getZExtValue();
+             << "provided root " << rootValue.getZExtValue()
+             << " is not a primitive root "
+             << "of unity mod " << cmod.getZExtValue()
+             << ", with the specified degree " << rootDegree.getZExtValue();
     }
   }
 

>From 30f51bf7455b2ed65f5391bcbeaae89925695e06 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 14:27:33 -0700
Subject: [PATCH 5/6] [mlir][polynomial] verify from_tensor coeff type

Use the coefficient type to verify if a tensor fits in a polynomial
ring. Downstream we had originally not specified the coefficient type
and just used the implied bit width of the coefficient modulus.  Now
that we specify the type, this is simpler.
---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 5 +----
 mlir/test/Dialect/Polynomial/ops_errors.mlir     | 2 +-
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 76e837404c6b5..3719979177215 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -47,11 +47,8 @@ LogicalResult FromTensorOp::verify() {
     return diag;
   }
 
-  APInt coefficientModulus = ring.getCoefficientModulus().getValue();
-  unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
-
-  if (inputBitWidth > cmodBitWidth) {
+  if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
     InFlightDiagnostic diag = emitOpError()
                               << "input tensor element type "
                               << getInput().getType().getElementType()
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index f22b14897e98a..4937e17027afa 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s
 
 #my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+#ring = #polynomial.ring<coefficientType=i16>
 !ty = !polynomial.polynomial<ring=#ring>
 
 func.func @test_from_tensor_too_large_coeffs() {

>From 68a75f6e65fd5a16456511853df72220a0714d19 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 24 May 2024 21:33:35 -0700
Subject: [PATCH 6/6] [mlir][polynomial] ensure primitive root calculation
 doesn't overflow

---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 14 ++++++++++----
 mlir/test/Dialect/Polynomial/ops.mlir            | 10 ++++++++++
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3719979177215..18ed2d175ee29 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/APInt.h"
+#include <iostream>
 
 using namespace mlir;
 using namespace mlir::polynomial;
@@ -107,16 +108,21 @@ LogicalResult MulScalarOp::verify() {
 /// Test if a value is a primitive nth root of unity modulo cmod.
 bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
                                const APInt &cmod) {
+  // a * r needs up to twice the active bits of its operands. That width may be
+  // narrower than the inputs' APInt widths, so use zextOrTrunc, not zext.
+  unsigned requiredBitWidth =
+      std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
   // Root bitwidth may be 1 less then cmod.
-  APInt r = APInt(root).zext(cmod.getBitWidth());
-  assert(r.ule(cmod) && "root must be less than cmod");
-  unsigned upperBound = n.getZExtValue();
+  APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
+  APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
+  assert(r.ule(cmodExt) && "root must be less than cmod");
+  uint64_t upperBound = n.getZExtValue();
 
   APInt a = r;
   for (size_t k = 1; k < upperBound; k++) {
     if (a.isOne())
       return false;
-    a = (a * r).urem(cmod);
+    a = (a * r).urem(cmodExt);
   }
   return a.isOne();
 }
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 8c134ab789d60..faeb68a8b2c09 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -18,6 +18,11 @@
 #ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
+#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
+#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
+#ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
+!ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>
+
 module {
   func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
     %c0 = arith.constant 0 : index
@@ -95,6 +100,11 @@ module {
     return
   }
 
+  func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
+    %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
+    return
+  }
+
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
     %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
     return



More information about the Mlir-commits mailing list