[Mlir-commits] [mlir] 624c9fc - Upstream polynomial.ntt and polynomial.intt (#90992)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun May 5 09:44:40 PDT 2024


Author: Jeremy Kun
Date: 2024-05-05T09:44:35-07:00
New Revision: 624c9fc87fb23c6eb89a28e95fa0cf72f5c39d94

URL: https://github.com/llvm/llvm-project/commit/624c9fc87fb23c6eb89a28e95fa0cf72f5c39d94
DIFF: https://github.com/llvm/llvm-project/commit/624c9fc87fb23c6eb89a28e95fa0cf72f5c39d94.diff

LOG: Upstream polynomial.ntt and polynomial.intt (#90992)

These two ops represent a number-theoretic transform of a polynomial to
a tensor of evaluations of the polynomial at successive powers of a
primitive root of unity, and the corresponding inverse transform.

To support this, a new optional `primitiveRoot` parameter is added to the
ring attribute to specify the primitive root of unity used for the NTT.
Verifiers for both ops ensure that the chosen root is a primitive n-th
root of unity modulo the coefficient modulus, where n is the degree of the
polynomial modulus.

---------

Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
    mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
    mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
    mlir/test/Dialect/Polynomial/ops.mlir
    mlir/test/Dialect/Polynomial/ops_errors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index d3e3ac55677f86..ed1f4ce8b7e599 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -79,7 +79,7 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
     #poly = #polynomial.polynomial<x**1024 + 1>
     ```
   }];
-  let parameters = (ins "Polynomial":$polynomial);
+  let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
   let hasCustomAssemblyFormat = 1;
 }
 
@@ -122,10 +122,19 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
 
   let parameters = (ins
     "Type": $coefficientType,
-    OptionalParameter<"IntegerAttr">: $coefficientModulus,
-    OptionalParameter<"PolynomialAttr">: $polynomialModulus
+    OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
+    OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
+    OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
   );
 
+  let builders = [
+    AttrBuilder<
+        (ins "::mlir::Type":$coefficientTy,
+             "::mlir::IntegerAttr":$coefficientModulusAttr,
+             "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
+      return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
+    }]>
+  ];
   let hasCustomAssemblyFormat = 1;
 }
 
@@ -416,4 +425,45 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
   let assemblyFormat = "$input attr-dict `:` type($output)";
 }
 
+def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
+  let summary = "Computes the point-value tensor representation of a polynomial.";
+  let description = [{
+    `polynomial.ntt` computes the forward integer Number Theoretic Transform
+    (NTT) on the input polynomial. It returns a tensor containing a point-value
+    representation of the input polynomial. The output tensor has shape `[n]`,
+    where `n` is the degree of the ring's `polynomialModulus`. The
+    polynomial's RingAttr is embedded as the encoding attribute of the output
+    tensor.
+
+    Given an input polynomial `F(x)` over a ring whose `polynomialModulus` has
+    degree `n`, and a primitive `n`-th root of unity `omega_n`, the output is
+    the list of `n` evaluations
+
+      `f[k] = F(omega_n^k) ; k = {0, ..., n-1}`
+
+    The root `omega_n` is the ring's `primitiveRoot` parameter, which must be
+    present; the verifier checks that it is a primitive `n`-th root of unity
+    modulo the ring's `coefficientModulus`.
+  }];
+  let arguments = (ins Polynomial_PolynomialType:$input);
+  let results = (outs RankedTensorOf<[AnyInteger]>:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
+def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
+  let summary = "Computes the inverse integer Number Theoretic Transform (INTT).";
+  let description = [{
+    `polynomial.intt` computes the inverse integer Number Theoretic Transform
+    (INTT) on the input tensor. This is the inverse operation of the
+    `polynomial.ntt` operation.
+
+    The input tensor is interpreted as a point-value representation of the
+    output polynomial at powers of a primitive `n`-th root of unity (see
+    `polynomial.ntt`). The ring of the polynomial is taken from the required
+    encoding attribute of the tensor, which must be identical to the output
+    polynomial's ring and must provide a `primitiveRoot`.
+  }];
+  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let results = (outs Polynomial_PolynomialType:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
 #endif // POLYNOMIAL_OPS

diff  --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index f1ec2be72a33ab..236bb789663529 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -202,11 +202,27 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
     polyAttr = attr;
   }
 
+  Polynomial poly = polyAttr.getPolynomial();
+  APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
+  IntegerAttr rootAttr = nullptr;
+  if (succeeded(parser.parseOptionalComma())) {
+    if (failed(parser.parseKeyword("primitiveRoot")) ||
+        failed(parser.parseEqual()))
+      return {};
+
+    ParseResult result = parser.parseInteger(root);
+    if (failed(result)) {
+      parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
+      return {};
+    }
+    rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+  }
+
   if (failed(parser.parseGreater()))
     return {};
 
   return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
-                       polyAttr);
+                       polyAttr, rootAttr);
 }
 
 } // namespace polynomial

diff  --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 8e2bb5f27dc6cc..12010de3482377 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -104,3 +104,82 @@ LogicalResult MulScalarOp::verify() {
 
   return success();
 }
+
+/// Test if a value is a primitive nth root of unity modulo cmod.
+///
+/// Returns true iff root^n == 1 (mod cmod) and root^k != 1 (mod cmod) for
+/// every 0 < k < n. `root` and `cmod` may have different bit widths, and
+/// `root` is reduced modulo `cmod` before testing. Returns false when `n` or
+/// `cmod` is zero: there is no "0th" root, and arithmetic modulo zero is
+/// undefined.
+bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+                               const APInt &cmod) {
+  if (n == 0 || cmod.isZero())
+    return false;
+
+  // The product of two residues needs up to twice the modulus bit width.
+  // Multiplying at the original width would silently wrap modulo 2^width and
+  // make the following urem compute a wrong residue, so widen first.
+  unsigned wideWidth = 2 * (root.getBitWidth() > cmod.getBitWidth()
+                                ? root.getBitWidth()
+                                : cmod.getBitWidth());
+  APInt modulus = cmod.zext(wideWidth);
+  APInt r = root.zext(wideWidth).urem(modulus);
+
+  // `a` holds r^k mod cmod; a primitive root must not reach 1 before k == n.
+  APInt a = r;
+  for (unsigned k = 1; k < n; ++k) {
+    if (a.isOne())
+      return false;
+    a = (a * r).urem(modulus);
+  }
+  return a.isOne();
+}
+
+/// Verify that the types involved in an NTT or INTT operation are
+/// compatible.
+///
+/// `ring` is the ring of the polynomial operand/result and `tensorType` is the
+/// type of the point-value tensor. The tensor must carry `ring` as its
+/// encoding and have shape [d], where d is the degree of the ring's
+/// polynomialModulus. The ring must also provide a coefficientModulus and a
+/// primitiveRoot that is a primitive d-th root of unity modulo it.
+static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
+                                 RankedTensorType tensorType) {
+  Attribute encoding = tensorType.getEncoding();
+  if (!encoding) {
+    return op->emitOpError()
+           << "expects a ring encoding to be provided to the tensor";
+  }
+  auto encodedRing = dyn_cast<RingAttr>(encoding);
+  if (!encodedRing) {
+    return op->emitOpError()
+           << "the provided tensor encoding is not a ring attribute";
+  }
+
+  if (encodedRing != ring) {
+    return op->emitOpError()
+           << "encoded ring type " << encodedRing
+           << " is not equivalent to the polynomial ring " << ring;
+  }
+
+  // polynomialModulus is an optional ring parameter, but the NTT length is
+  // its degree. Diagnose its absence instead of dereferencing a null attr.
+  PolynomialAttr polyModulus = ring.getPolynomialModulus();
+  if (!polyModulus) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a polynomialModulus, "
+           << "which is required to express an NTT";
+  }
+
+  unsigned polyDegree = polyModulus.getPolynomial().getDegree();
+  ArrayRef<int64_t> tensorShape = tensorType.getShape();
+  bool compatible = tensorShape.size() == 1 &&
+                    tensorShape[0] == static_cast<int64_t>(polyDegree);
+  if (!compatible) {
+    InFlightDiagnostic diag = op->emitOpError()
+                              << "tensor type " << tensorType
+                              << " does not match output type " << ring;
+    diag.attachNote() << "the tensor must have shape [d] where d "
+                         "is exactly the degree of the polynomialModulus of "
+                         "the polynomial type's ring attribute";
+    return diag;
+  }
+
+  if (!ring.getPrimitiveRoot()) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a primitive root "
+           << "of unity, which is required to express an NTT";
+  }
+
+  // coefficientModulus is also optional, but primitivity is only defined
+  // modulo a concrete modulus.
+  IntegerAttr coeffModulus = ring.getCoefficientModulus();
+  if (!coeffModulus) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a coefficientModulus, "
+           << "which is required to express an NTT";
+  }
+
+  if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
+                                 coeffModulus.getValue())) {
+    return op->emitOpError()
+           << "ring type " << ring << " has a primitiveRoot attribute '"
+           << ring.getPrimitiveRoot()
+           << "' that is not a primitive root of the coefficient ring";
+  }
+
+  return success();
+}
+
+LogicalResult NTTOp::verify() {
+  // The forward transform consumes the polynomial and produces the
+  // point-value tensor. The ring therefore comes from the operand and the
+  // tensor type from the result.
+  RankedTensorType evaluationsType = getOutput().getType();
+  return verifyNTTOp(getOperation(), getInput().getType().getRing(),
+                     evaluationsType);
+}
+
+LogicalResult INTTOp::verify() {
+  // The inverse transform consumes the point-value tensor and produces the
+  // polynomial. The ring therefore comes from the result and the tensor type
+  // from the operand.
+  RingAttr resultRing = getOutput().getType().getRing();
+  return verifyNTTOp(getOperation(), resultRing, getInput().getType());
+}

diff  --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ea1b279fa1ff96..a29cfc2e9cc549 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -10,9 +10,13 @@
 #one_plus_x_squared = #polynomial.polynomial<1 + x**2>
 
 #ideal = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
+// Note: 193 = 31^2 has multiplicative order 4 modulo 256, so it is not a
+// primitive 1024th root of unity. polynomial.ntt/intt on #ring would be
+// rejected by the verifier; the NTT tests below use #ntt_ring instead.
 !poly_ty = !polynomial.polynomial<#ring>
 
+#ntt_poly = #polynomial.polynomial<-1 + x**8>
+// 31 has multiplicative order exactly 8 modulo 256 (31^2 = 193, 31^4 = 129,
+// 31^8 = 1), so it is a valid primitive 8th root of unity for this ring.
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+
 module {
   func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
     %c0 = arith.constant 0 : index
@@ -79,4 +83,14 @@ module {
     %1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
     return
   }
+
+  // Round-trips a forward NTT whose result tensor carries the ring encoding.
+  func.func @test_ntt(%poly : !ntt_poly_ty) {
+    %evals = polynomial.ntt %poly : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    return
+  }
+
+  // Round-trips an inverse NTT from a ring-encoded point-value tensor.
+  func.func @test_intt(%evals : tensor<8xi32, #ntt_ring>) {
+    %poly = polynomial.intt %evals : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    return
+  }
 }

diff  --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index c34a7de30e5fe5..2c20e7bcbf1d69 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -51,3 +51,90 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
   %poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
   return %poly : !ty
 }
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@below {{expects a ring encoding to be provided to the tensor}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@below {{tensor encoding is not a ring attribute}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
+  // expected-error@below {{not equivalent to the polynomial ring}}
+  %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
+  // expected-error@below {{does not match output type}}
+  // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
+  %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@below {{does not provide a primitive root of unity, which is required to express an NTT}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**8>
+// A valid root is 31
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
+  // expected-error@below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
+  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+  return
+}


        


More information about the Mlir-commits mailing list