[Mlir-commits] [mlir] 624c9fc - Upstream polynomial.ntt and polynomial.intt (#90992)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 5 09:44:40 PDT 2024
Author: Jeremy Kun
Date: 2024-05-05T09:44:35-07:00
New Revision: 624c9fc87fb23c6eb89a28e95fa0cf72f5c39d94
URL: https://github.com/llvm/llvm-project/commit/624c9fc87fb23c6eb89a28e95fa0cf72f5c39d94
DIFF: https://github.com/llvm/llvm-project/commit/624c9fc87fb23c6eb89a28e95fa0cf72f5c39d94.diff
LOG: Upstream polynomial.ntt and polynomial.intt (#90992)
These two ops represent a number-theoretic transform (and its inverse)
between a polynomial and a tensor of evaluations of the polynomial at the
powers of a primitive n-th root of unity in its coefficient ring.
To support this, a new optional attribute is added to the ring attribute
to specify the primitive root of unity used for the NTT. Verifiers for
both ops are added to ensure the chosen root is a primitive n-th root of
unity modulo the coefficient modulus, where n is the degree of the
polynomial modulus.
---------
Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
mlir/test/Dialect/Polynomial/ops.mlir
mlir/test/Dialect/Polynomial/ops_errors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index d3e3ac55677f86..ed1f4ce8b7e599 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -79,7 +79,7 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
#poly = #polynomial.polynomial<x**1024 + 1>
```
}];
- let parameters = (ins "Polynomial":$polynomial);
+ let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
let hasCustomAssemblyFormat = 1;
}
@@ -122,10 +122,19 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
- OptionalParameter<"IntegerAttr">: $coefficientModulus,
- OptionalParameter<"PolynomialAttr">: $polynomialModulus
+ OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
+ OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
+ OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
);
+ let builders = [
+ AttrBuilder<
+ (ins "::mlir::Type":$coefficientTy,
+ "::mlir::IntegerAttr":$coefficientModulusAttr,
+ "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
+ return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
+ }]>
+ ];
let hasCustomAssemblyFormat = 1;
}
@@ -416,4 +425,45 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let assemblyFormat = "$input attr-dict `:` type($output)";
}
+def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
+  let summary = "Computes point-value tensor representation of a polynomial.";
+  let description = [{
+    `polynomial.ntt` computes the forward integer Number Theoretic Transform
+    (NTT) on the input polynomial. It returns a tensor containing a point-value
+    representation of the input polynomial. The output tensor has shape `[n]`,
+    where `n` is the degree of the ring's `polynomialModulus`. The polynomial's
+    RingAttr is embedded as the encoding attribute of the output tensor.
+
+    Given an input polynomial `F(x)` over a ring whose `polynomialModulus` has
+    degree `n`, and a primitive `n`-th root of unity `omega_n`, the output is
+    the list of `n` evaluations
+
+      `f[k] = F(omega_n^k) ; k = {0, ..., n-1}`
+
+    The root `omega_n` is taken from the ring's `primitiveRoot` parameter. It
+    must be present, and it must be a primitive `n`-th root of unity modulo
+    the ring's `coefficientModulus`.
+  }];
+  let arguments = (ins Polynomial_PolynomialType:$input);
+  let results = (outs RankedTensorOf<[AnyInteger]>:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
+def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
+  let summary = "Computes the inverse integer Number Theoretic Transform (INTT).";
+  let description = [{
+    `polynomial.intt` computes the inverse integer Number Theoretic Transform
+    (INTT) on the input tensor. This is the inverse operation of the
+    `polynomial.ntt` operation.
+
+    The input tensor is interpreted as a point-value representation of the
+    output polynomial at powers of a primitive `n`-th root of unity (see
+    `polynomial.ntt`). The ring of the polynomial is taken from the required
+    encoding attribute of the tensor, which must equal the ring of the output
+    polynomial type. The tensor must have shape `[n]`, where `n` is the degree
+    of the ring's `polynomialModulus`. The ring's `primitiveRoot` must be a
+    primitive `n`-th root of unity modulo its `coefficientModulus`.
+  }];
+  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let results = (outs Polynomial_PolynomialType:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
#endif // POLYNOMIAL_OPS
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index f1ec2be72a33ab..236bb789663529 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -202,11 +202,27 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
polyAttr = attr;
}
+ Polynomial poly = polyAttr.getPolynomial();
+ APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
+ IntegerAttr rootAttr = nullptr;
+ if (succeeded(parser.parseOptionalComma())) {
+ if (failed(parser.parseKeyword("primitiveRoot")) ||
+ failed(parser.parseEqual()))
+ return {};
+
+ ParseResult result = parser.parseInteger(root);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
+ return {};
+ }
+ rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+ }
+
if (failed(parser.parseGreater()))
return {};
return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
- polyAttr);
+ polyAttr, rootAttr);
}
} // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 8e2bb5f27dc6cc..12010de3482377 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -104,3 +104,82 @@ LogicalResult MulScalarOp::verify() {
return success();
}
+
+/// Test if a value is a primitive nth root of unity modulo cmod.
+///
+/// Returns true iff `root^n == 1 (mod cmod)` and `root^k != 1 (mod cmod)` for
+/// every `0 < k < n`. `root` and `cmod` may have different bit widths.
+/// Roots outside the canonical range `[0, cmod)` are rejected rather than
+/// reduced; this also covers `cmod == 0`, which would otherwise divide by
+/// zero.
+bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+                               const APInt &cmod) {
+  // APInt arithmetic wraps modulo 2^BitWidth, and the product of two residues
+  // modulo cmod can need twice as many bits as cmod. Work at twice the wider
+  // of the two input widths so every product is exact before it is reduced.
+  unsigned inputWidth = root.getBitWidth() > cmod.getBitWidth()
+                            ? root.getBitWidth()
+                            : cmod.getBitWidth();
+  unsigned workWidth = 2 * inputWidth;
+  APInt r = root.zext(workWidth);
+  APInt m = cmod.zext(workWidth);
+
+  // The root comes from user-written IR, so an out-of-range value must yield
+  // a verifier failure, not an assertion.
+  if (r.uge(m))
+    return false;
+
+  // At the top of iteration k, `a` holds r^k (mod m). A primitive root must
+  // not reach 1 before its n-th power.
+  APInt a = r;
+  for (unsigned k = 1; k < n; ++k) {
+    if (a.isOne())
+      return false;
+    a = (a * r).urem(m);
+  }
+  return a.isOne();
+}
+
+/// Verify that the types involved in an NTT or INTT operation are
+/// compatible.
+///
+/// `ring` is the ring of the op's polynomial operand/result, and `tensorType`
+/// is the type of its point-value tensor. The checks are:
+/// - the tensor carries `ring` as its encoding;
+/// - the tensor has shape [d], where d is the degree of the ring's
+///   polynomialModulus;
+/// - the ring provides a coefficientModulus and a primitiveRoot;
+/// - the primitiveRoot lies in [0, coefficientModulus) and is a primitive
+///   d-th root of unity modulo the coefficientModulus.
+static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
+                                 RankedTensorType tensorType) {
+  Attribute encoding = tensorType.getEncoding();
+  if (!encoding) {
+    return op->emitOpError()
+           << "expects a ring encoding to be provided to the tensor";
+  }
+  auto encodedRing = dyn_cast<RingAttr>(encoding);
+  if (!encodedRing) {
+    return op->emitOpError()
+           << "the provided tensor encoding is not a ring attribute";
+  }
+
+  if (encodedRing != ring) {
+    return op->emitOpError()
+           << "encoded ring type " << encodedRing
+           << " is not equivalent to the polynomial ring " << ring;
+  }
+
+  // polynomialModulus is an optional ring parameter, but the NTT length is
+  // its degree, so it must be present before it is dereferenced.
+  if (!ring.getPolynomialModulus()) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a polynomialModulus, "
+           << "which is required to express an NTT";
+  }
+
+  unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
+  ArrayRef<int64_t> tensorShape = tensorType.getShape();
+  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+  if (!compatible) {
+    InFlightDiagnostic diag = op->emitOpError()
+                              << "tensor type " << tensorType
+                              << " does not match output type " << ring;
+    diag.attachNote() << "the tensor must have shape [d] where d "
+                         "is exactly the degree of the polynomialModulus of "
+                         "the polynomial type's ring attribute";
+    return diag;
+  }
+
+  if (!ring.getPrimitiveRoot()) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a primitive root "
+           << "of unity, which is required to express an NTT";
+  }
+
+  // coefficientModulus is also an optional ring parameter. The root-of-unity
+  // check below is performed modulo it.
+  if (!ring.getCoefficientModulus()) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a coefficientModulus, "
+           << "which is required to express an NTT";
+  }
+
+  APInt root = ring.getPrimitiveRoot().getValue();
+  APInt cmod = ring.getCoefficientModulus().getValue();
+  // The root comes from user-written IR, so a value outside [0, cmod) is
+  // diagnosed here instead of being handed to the root-of-unity helper.
+  // Compare at a common bit width in case the two attributes' widths differ.
+  unsigned width = root.getBitWidth() > cmod.getBitWidth()
+                       ? root.getBitWidth()
+                       : cmod.getBitWidth();
+  bool rootInRange = root.zext(width).ult(cmod.zext(width));
+  if (!rootInRange || !isPrimitiveNthRootOfUnity(root, polyDegree, cmod)) {
+    return op->emitOpError()
+           << "ring type " << ring << " has a primitiveRoot attribute '"
+           << ring.getPrimitiveRoot()
+           << "' that is not a primitive root of the coefficient ring";
+  }
+
+  return success();
+}
+
+LogicalResult NTTOp::verify() {
+  // Forward transform: the polynomial is the operand and the point-value
+  // tensor is the result, so the ring comes from the input type.
+  RankedTensorType outputType = getOutput().getType();
+  return verifyNTTOp(getOperation(), getInput().getType().getRing(),
+                     outputType);
+}
+
+LogicalResult INTTOp::verify() {
+  // Inverse transform: the roles are swapped. The point-value tensor is the
+  // operand and the polynomial (which supplies the ring) is the result.
+  RingAttr resultRing = getOutput().getType().getRing();
+  return verifyNTTOp(getOperation(), resultRing, getInput().getType());
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ea1b279fa1ff96..a29cfc2e9cc549 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -10,9 +10,13 @@
#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
#ideal = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
+// NOTE(review): 193 has multiplicative order 4 modulo 256 (193^2 = 129,
+// 129^2 = 1), so it is not a primitive 1024th root of unity for #ideal. It is
+// presumably accepted only because primitiveRoot is checked by the
+// polynomial.ntt/intt verifiers. Confirm #ring is never used with those ops.
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
!poly_ty = !polynomial.polynomial<#ring>
+#ntt_poly = #polynomial.polynomial<-1 + x**8>
+// 31 has multiplicative order 8 modulo 256 (31^2 = 193, 31^4 = 129,
+// 31^8 = 1), so it is a primitive 8th root of unity matching #ntt_poly's
+// degree, as the NTT verifier requires.
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+
module {
func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
%c0 = arith.constant 0 : index
@@ -79,4 +83,14 @@ module {
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
return
}
+
+  func.func @test_ntt(%coeffs : !ntt_poly_ty) {
+    // Forward NTT into the point-value representation. The result carries
+    // #ntt_ring as its encoding and has length 8, the degree of #ntt_poly.
+    %evals = polynomial.ntt %coeffs : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+    return
+  }
+
+  func.func @test_intt(%evals : tensor<8xi32, #ntt_ring>) {
+    // Inverse NTT back to a polynomial over #ntt_ring.
+    %coeffs = polynomial.intt %evals : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+    return
+  }
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index c34a7de30e5fe5..2c20e7bcbf1d69 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -51,3 +51,90 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
%poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
return %poly : !ty
}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@below {{expects a ring encoding to be provided to the tensor}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@below {{tensor encoding is not a ring attribute}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
+  // expected-error@below {{not equivalent to the polynomial ring}}
+  %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
+  // expected-error@below {{does not match output type}}
+  // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
+  %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+  // expected-error@below {{does not provide a primitive root of unity, which is required to express an NTT}}
+  %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+  return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**8>
+// A valid root is 31
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
+  // expected-error@below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
+  %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+  return
+}
More information about the Mlir-commits
mailing list