[Mlir-commits] [mlir] [mlir][polynomial] ensure primitive root calculation doesn't overflow (PR #93368)

Jeremy Kun llvmlistbot at llvm.org
Thu May 30 11:35:40 PDT 2024

https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/93368

>From 2fc35e1f7d8b445beb9e77a748a19df6a3ad19e5 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 24 May 2024 21:33:35 -0700
Subject: [PATCH] [mlir][polynomial] ensure primitive root calculation doesn't
 overflow

 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 20 +++++++++++--------
 mlir/test/Dialect/Polynomial/ops.mlir         | 10 ++++++++++
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3117721a94152..0df09bfcc3981 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -114,16 +114,20 @@ LogicalResult MulScalarOp::verify() {
 /// Test if a value is a primitive nth root of unity modulo cmod.
 bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
                                const APInt &cmod) {
-  // Root bitwidth may be 1 less then cmod.
-  APInt r = APInt(root).zext(cmod.getBitWidth());
-  assert(r.ule(cmod) && "root must be less than cmod");
-  unsigned upperBound = n.getZExtValue();
+  // The first or subsequent multiplications may overflow the input bit width,
+  // so widen the operands enough to ensure the products do not overflow.
+  unsigned requiredBitWidth =
+      std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
+  APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
+  APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
+  assert(r.ule(cmodExt) && "root must be less than cmod");
+  uint64_t upperBound = n.getZExtValue();
   APInt a = r;
   for (size_t k = 1; k < upperBound; k++) {
     if (a.isOne())
       return false;
-    a = (a * r).urem(cmod);
+    a = (a * r).urem(cmodExt);
   return a.isOne();
@@ -170,9 +174,9 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
     if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
       return op->emitOpError()
              << "provided root " << rootValue.getZExtValue()
-             << " is not a primitive root "
-             << "of unity mod " << cmod.getZExtValue()
-             << ", with the specified degree " << rootDegree.getZExtValue();
+             << " is not a primitive root " << "of unity mod "
+             << cmod.getZExtValue() << ", with the specified degree "
+             << rootDegree.getZExtValue();
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 8c134ab789d60..faeb68a8b2c09 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -18,6 +18,11 @@
 #ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
+#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
+#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
+#ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
+!ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>
 module {
   func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
     %c0 = arith.constant 0 : index
@@ -95,6 +100,11 @@ module {
+  func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
+    %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
+    return
+  }
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
     %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty

More information about the Mlir-commits mailing list