[Mlir-commits] [mlir] Upstream polynomial.ntt and polynomial.intt (PR #90992)

Jeremy Kun llvmlistbot at llvm.org
Fri May 3 11:39:13 PDT 2024


https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/90992

These two ops represent the number-theoretic transform (NTT) and its inverse: `polynomial.ntt` maps a polynomial to a tensor of its evaluations at successive powers of a primitive root of unity, and `polynomial.intt` maps such a tensor back to a polynomial.

To support this, a new optional attribute is added to the ring attribute to specify the primitive root of unity used for the NTT. A verifier shared by both ops is added to ensure the chosen root is a primitive nth root of unity, where n is the degree of the ring's polynomial modulus.

>From e6df81ebba5daff2d8ae3437c3a02cb6ef9051ee Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 2 May 2024 16:41:44 -0700
Subject: [PATCH] Upstream polynomial.ntt and polynomial.intt

---
 .../mlir/Dialect/Polynomial/IR/Polynomial.td  | 52 ++++++++++-
 .../Polynomial/IR/PolynomialAttributes.cpp    | 21 ++++-
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 77 ++++++++++++++++
 mlir/test/Dialect/Polynomial/ops.mlir         | 16 +++-
 mlir/test/Dialect/Polynomial/ops_errors.mlir  | 87 +++++++++++++++++++
 5 files changed, 250 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index d3e3ac55677f86..93cf363c316c6f 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -123,9 +123,18 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
   let parameters = (ins
     "Type": $coefficientType,
     OptionalParameter<"IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"PolynomialAttr">: $polynomialModulus
+    OptionalParameter<"PolynomialAttr">: $polynomialModulus,
+    OptionalParameter<"IntegerAttr">: $primitiveRoot
   );
 
+  let builders = [
+    AttrBuilder<
+        (ins "::mlir::Type":$coefficientTy,
+             "IntegerAttr":$coefficientModulusAttr,
+             "PolynomialAttr":$polynomialModulusAttr), [{
+      return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
+    }]>
+  ];
   let hasCustomAssemblyFormat = 1;
 }
 
@@ -416,4 +425,45 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
   let assemblyFormat = "$input attr-dict `:` type($output)";
 }
 
+def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
+  let summary = "Computes point-value tensor representation of a polynomial.";
+  let description = [{
+    `polynomial.ntt` computes the forward integer Number Theoretic Transform
+    (NTT) on the input polynomial. It returns a tensor containing a point-value
+    representation of the input polynomial. The output tensor has shape `[n]`,
+    where `n` is the degree of the ring's `polynomialModulus`. The polynomial's
+    RingAttr is embedded as the encoding attribute of the output tensor.
+
+    Given an input polynomial `F(x)` over a ring whose `polynomialModulus` has
+    degree `n`, and a primitive `n`-th root of unity `omega_n`, the output is
+    the list of `n` evaluations
+
+      `f[k] = F(omega_n^k) ; k = {0, ..., n-1}`
+
+    `omega_n` is the ring's required `primitiveRoot`, checked by the verifier.
+  }];
+  let arguments = (ins Polynomial_PolynomialType:$input);
+  let results = (outs RankedTensorOf<[AnyInteger]>:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
+def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
+  let summary = "Computes the reverse integer Number Theoretic Transform (INTT).";
+  let description = [{
+    `polynomial.intt` computes the reverse integer Number Theoretic Transform
+    (INTT) on the input tensor. This is the inverse operation of the
+    `polynomial.ntt` operation.
+
+    The input tensor is interpreted as a point-value representation of the
+    output polynomial at powers of a primitive `n`-th root of unity (see
+    `polynomial.ntt`). The ring of the polynomial is taken from the required
+    encoding attribute of the tensor.
+  }];
+  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let results = (outs Polynomial_PolynomialType:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
 #endif // POLYNOMIAL_OPS
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index f1ec2be72a33ab..2c414acc85ec0c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -202,11 +202,30 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
     polyAttr = attr;
   }
 
+  Polynomial poly = polyAttr.getPolynomial();
+  APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
+  bool hasRoot = succeeded(parser.parseOptionalComma());
+  IntegerAttr rootAttr = nullptr;
+  if (hasRoot) {
+    if (failed(parser.parseKeyword("primitiveRoot")))
+      return {};
+
+    if (failed(parser.parseEqual()))
+      return {};
+
+    auto result = parser.parseInteger(root);
+    if (failed(result)) {
+      parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
+      return {};
+    }
+    rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+  }
+
   if (failed(parser.parseGreater()))
     return {};
 
   return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
-                       polyAttr);
+                       polyAttr, rootAttr);
 }
 
 } // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 8e2bb5f27dc6cc..1d5c7be4b6752a 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -104,3 +104,80 @@ LogicalResult MulScalarOp::verify() {
 
   return success();
 }
+
+// Returns true iff root^n == 1 (mod cmod) and root^k != 1 for all 0 < k < n.
+bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+                               const APInt &cmod) {
+  // Multiply at twice cmod's width: APInt `*` wraps at the operand width.
+  const APInt m = cmod.zext(2 * cmod.getBitWidth());
+  const APInt r = root.zext(m.getBitWidth());
+  assert(r.ult(m) && "root must be less than cmod");
+  APInt a = r;
+  for (unsigned k = 1; k < n; ++k) {
+    if (a.isOne())
+      return false;
+    a = (a * r).urem(m);
+  }
+  return a.isOne();
+}
+
+// Shared verifier for polynomial.ntt / polynomial.intt. In order, requires:
+// a tensor encoding; that it is a RingAttr equal to `ring`; shape [d] with d
+// the degree of the ring's polynomialModulus; and a primitiveRoot accepted by
+// isPrimitiveNthRootOfUnity. Emits an op error for the first failed check.
+static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
+                                 RankedTensorType tensorType) {
+  Attribute tensorEncoding = tensorType.getEncoding();
+  if (!tensorEncoding)
+    return op->emitOpError()
+           << "a ring encoding was not provided to the tensor";
+
+  auto ringFromTensor = dyn_cast<RingAttr>(tensorEncoding);
+  if (!ringFromTensor)
+    return op->emitOpError()
+           << "the provided tensor encoding is not a ring attribute";
+  if (ringFromTensor != ring)
+    return op->emitOpError()
+           << "encoded ring type " << ringFromTensor
+           << " is not equivalent to the polynomial ring " << ring;
+
+  auto degree = ring.getPolynomialModulus().getPolynomial().getDegree();
+  ArrayRef<int64_t> dims = tensorType.getShape();
+  if (dims.size() != 1 || dims[0] != degree) {
+    InFlightDiagnostic diag = op->emitOpError();
+    diag << "tensor type " << tensorType << " does not match output type "
+         << ring;
+    diag.attachNote() << "the tensor must have shape [d] where d "
+                         "is exactly the degree of the polynomialModulus of "
+                         "the polynomial type's ring attribute";
+    return diag;
+  }
+
+  IntegerAttr rootAttr = ring.getPrimitiveRoot();
+  if (!rootAttr)
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a primitive root "
+           << "of unity, which is required to express an NTT";
+
+  const bool rootIsPrimitive = isPrimitiveNthRootOfUnity(
+      rootAttr.getValue(), degree, ring.getCoefficientModulus().getValue());
+  if (!rootIsPrimitive)
+    return op->emitOpError()
+           << "ring type " << ring << " has a primitiveRoot attribute '"
+           << rootAttr
+           << "' that is not a primitive root of the coefficient ring";
+
+  return success();
+}
+
+LogicalResult NTTOp::verify() {
+  // The ring is read from the input type; the tensor is the output type.
+  return verifyNTTOp(getOperation(), getInput().getType().getRing(),
+                     getOutput().getType());
+}
+
+LogicalResult INTTOp::verify() {
+  // The tensor is the input type; the ring is read from the output type.
+  return verifyNTTOp(getOperation(), getOutput().getType().getRing(),
+                     getInput().getType());
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ea1b279fa1ff96..a29cfc2e9cc549 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -10,9 +10,13 @@
 #one_plus_x_squared = #polynomial.polynomial<1 + x**2>
 
 #ideal = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
 !poly_ty = !polynomial.polynomial<#ring>
 
+#ntt_poly = #polynomial.polynomial<-1 + x**8>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+
 module {
   func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
     %c0 = arith.constant 0 : index
@@ -79,4 +83,14 @@ module {
     %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
     return
   }
+
+  func.func @test_ntt(%poly : !ntt_poly_ty) {
+    %evals = polynomial.ntt %poly : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    return
+  }
+
+  func.func @test_intt(%evals : tensor<8xi32, #ntt_ring>) {
+    %poly = polynomial.intt %evals : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    return
+  }
 }
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index c34a7de30e5fe5..9029017256be3a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -51,3 +51,90 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
   %poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
   return %poly : !ty
 }
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@+1 {{a ring encoding was not provided to the tensor}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@+1 {{tensor encoding is not a ring attribute}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
+  // expected-error@+1 {{not equivalent to the polynomial ring}}
+  %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
+  // expected-error@below {{does not match output type}}
+  // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
+  %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@+1 {{does not provide a primitive root of unity, which is required to express an NTT}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**8>
+// A valid root is 31
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
+  // expected-error@below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
+  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+  return
+}



More information about the Mlir-commits mailing list