[Mlir-commits] [mlir] [mlir][polynomial] ensure primitive root calculation doesn't overflow (PR #93368)

Jeremy Kun llvmlistbot at llvm.org
Thu May 30 07:24:23 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/93368

>From af5567b40170bc9d669a07d9d83b84ce290ed9d4 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 14:27:33 -0700
Subject: [PATCH 1/2] [mlir][polynomial] verify from_tensor coeff type

Use the coefficient type to verify whether a tensor's element type fits
in a polynomial ring. Downstream we had originally not specified the
coefficient type and just used the implied bit width of the coefficient
modulus. Now that we specify the type, the check is simpler.
---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 5 +----
 mlir/test/Dialect/Polynomial/ops_errors.mlir     | 2 +-
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 76e837404c6b5..3719979177215 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -47,11 +47,8 @@ LogicalResult FromTensorOp::verify() {
     return diag;
   }
 
-  APInt coefficientModulus = ring.getCoefficientModulus().getValue();
-  unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
-
-  if (inputBitWidth > cmodBitWidth) {
+  if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
     InFlightDiagnostic diag = emitOpError()
                               << "input tensor element type "
                               << getInput().getType().getElementType()
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index f22b14897e98a..4937e17027afa 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s
 
 #my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+#ring = #polynomial.ring<coefficientType=i16>
 !ty = !polynomial.polynomial<ring=#ring>
 
 func.func @test_from_tensor_too_large_coeffs() {

>From a277dedf5abe59233eca27f246d0d43f872df187 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 24 May 2024 21:33:35 -0700
Subject: [PATCH 2/2] [mlir][polynomial] ensure primitive root calculation
 doesn't overflow

---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 13 +++++++++----
 mlir/test/Dialect/Polynomial/ops.mlir            | 10 ++++++++++
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3719979177215..451db8168c60f 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -107,16 +107,21 @@ LogicalResult MulScalarOp::verify() {
 /// Test if a value is a primitive nth root of unity modulo cmod.
 bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
                                const APInt &cmod) {
+  // The first or subsequent multiplications may overflow the input bit width,
+  // so widen the operands to ensure the products do not overflow.
+  unsigned requiredBitWidth =
+      std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
   // Root bitwidth may be 1 less then cmod.
-  APInt r = APInt(root).zext(cmod.getBitWidth());
-  assert(r.ule(cmod) && "root must be less than cmod");
-  unsigned upperBound = n.getZExtValue();
+  APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
+  APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
+  assert(r.ule(cmodExt) && "root must be less than cmod");
+  uint64_t upperBound = n.getZExtValue();
 
   APInt a = r;
   for (size_t k = 1; k < upperBound; k++) {
     if (a.isOne())
       return false;
-    a = (a * r).urem(cmod);
+    a = (a * r).urem(cmodExt);
   }
   return a.isOne();
 }
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 8c134ab789d60..faeb68a8b2c09 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -18,6 +18,11 @@
 #ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
 !ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
 
+#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
+#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
+#ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
+!ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>
+
 module {
   func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
     %c0 = arith.constant 0 : index
@@ -95,6 +100,11 @@ module {
     return
   }
 
+  func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
+    %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
+    return
+  }
+
   func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
     %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
     return



More information about the Mlir-commits mailing list