[Mlir-commits] [mlir] Upstream polynomial.ntt and polynomial.intt (PR #90992)
Jeremy Kun
llvmlistbot at llvm.org
Sat May 4 18:03:59 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/90992
>From e6df81ebba5daff2d8ae3437c3a02cb6ef9051ee Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 2 May 2024 16:41:44 -0700
Subject: [PATCH 1/5] Upstream polynomial.ntt and polynomial.intt
---
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 52 ++++++++++-
.../Polynomial/IR/PolynomialAttributes.cpp | 21 ++++-
.../Dialect/Polynomial/IR/PolynomialOps.cpp | 77 ++++++++++++++++
mlir/test/Dialect/Polynomial/ops.mlir | 16 +++-
mlir/test/Dialect/Polynomial/ops_errors.mlir | 87 +++++++++++++++++++
5 files changed, 250 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index d3e3ac55677f86..93cf363c316c6f 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -123,9 +123,18 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"IntegerAttr">: $coefficientModulus,
- OptionalParameter<"PolynomialAttr">: $polynomialModulus
+ OptionalParameter<"PolynomialAttr">: $polynomialModulus,
+ OptionalParameter<"IntegerAttr">: $primitiveRoot
);
+ let builders = [
+ AttrBuilder<
+ (ins "::mlir::Type":$coefficientTy,
+ "IntegerAttr":$coefficientModulusAttr,
+ "PolynomialAttr":$polynomialModulusAttr), [{
+ return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
+ }]>
+ ];
let hasCustomAssemblyFormat = 1;
}
@@ -416,4 +425,45 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let assemblyFormat = "$input attr-dict `:` type($output)";
}
+def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
+  let summary = "Computes point-value tensor representation of a polynomial.";
+  let description = [{
+    `polynomial.ntt` computes the forward integer Number Theoretic Transform
+    (NTT) on the input polynomial. It returns a tensor containing a point-value
+    representation of the input polynomial. The output tensor is 1-D, and its
+    size equals the degree of the ring's `polynomialModulus`. The polynomial's
+    RingAttr is embedded as the encoding attribute of the output tensor.
+
+    Given an input polynomial `F(x)` over a ring whose `polynomialModulus` has
+    degree `n`, and a primitive `n`-th root of unity `omega_n`, the output is
+    the list of `n` evaluations
+
+      `f[k] = F(omega_n^k) ; k = {0, ..., n-1}`
+
+    The root `omega_n` is the ring's `primitiveRoot` parameter. It must be
+    present and must be a primitive `n`-th root of unity modulo the ring's
+    `coefficientModulus`; the op verifier rejects rings that do not satisfy
+    this.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.polynomial<-1 + x**8>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#poly, primitiveRoot=31>
+    %1 = polynomial.ntt %0 : !polynomial.polynomial<#ring> -> tensor<8xi32, #ring>
+    ```
+  }];
+  let arguments = (ins Polynomial_PolynomialType:$input);
+  let results = (outs RankedTensorOf<[AnyInteger]>:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
+def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
+  let summary = "Computes the inverse integer Number Theoretic Transform (INTT).";
+  let description = [{
+    `polynomial.intt` computes the inverse integer Number Theoretic Transform
+    (INTT) on the input tensor. This is the inverse operation of the
+    `polynomial.ntt` operation.
+
+    The input tensor is interpreted as a point-value representation of the
+    output polynomial at powers of a primitive `n`-th root of unity (see
+    `polynomial.ntt`). The ring of the polynomial is taken from the required
+    encoding attribute of the tensor, which must be identical to the ring of
+    the result type. The tensor must be 1-D with size equal to the degree `n`
+    of the ring's `polynomialModulus`, and the ring must provide a
+    `primitiveRoot` that is a primitive `n`-th root of unity modulo its
+    `coefficientModulus`.
+
+    Example:
+
+    ```mlir
+    #poly = #polynomial.polynomial<-1 + x**8>
+    #ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#poly, primitiveRoot=31>
+    %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !polynomial.polynomial<#ring>
+    ```
+  }];
+  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let results = (outs Polynomial_PolynomialType:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
#endif // POLYNOMIAL_OPS
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index f1ec2be72a33ab..2c414acc85ec0c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -202,11 +202,30 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
polyAttr = attr;
}
+ Polynomial poly = polyAttr.getPolynomial();
+ APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
+ bool hasRoot = succeeded(parser.parseOptionalComma());
+ IntegerAttr rootAttr = nullptr;
+ if (hasRoot) {
+ if (failed(parser.parseKeyword("primitiveRoot")))
+ return {};
+
+ if (failed(parser.parseEqual()))
+ return {};
+
+ auto result = parser.parseInteger(root);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
+ return {};
+ }
+ rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+ }
+
if (failed(parser.parseGreater()))
return {};
return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
- polyAttr);
+ polyAttr, rootAttr);
}
} // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 8e2bb5f27dc6cc..1d5c7be4b6752a 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -104,3 +104,80 @@ LogicalResult MulScalarOp::verify() {
return success();
}
+
+/// Test if `root` is a primitive `n`-th root of unity modulo `cmod`, i.e.
+/// whether root^n == 1 (mod cmod) while root^k != 1 (mod cmod) for every
+/// 0 < k < n.
+///
+/// `root` comes from a user-written ring attribute, so a value that is not a
+/// canonical residue (root >= cmod), or a zero modulus, is reported as "not a
+/// primitive root" instead of tripping an assertion or dividing by zero.
+bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+                               const APInt &cmod) {
+  // root and cmod may have different bit widths (the root may be one bit
+  // narrower than cmod). Work at twice the wider of the two so that the
+  // product of two residues (each < cmod) can never wrap around the APInt
+  // width before the modular reduction below.
+  unsigned inputWidth = root.getBitWidth() > cmod.getBitWidth()
+                            ? root.getBitWidth()
+                            : cmod.getBitWidth();
+  unsigned width = 2 * inputWidth;
+  APInt r = root.zext(width);
+  APInt m = cmod.zext(width);
+  if (m.isZero() || r.uge(m))
+    return false;
+
+  // At the top of iteration k, `a` holds root^k mod cmod.
+  APInt a = r;
+  for (unsigned k = 1; k < n; ++k) {
+    if (a.isOne())
+      return false;
+    a = (a * r).urem(m);
+  }
+  return a.isOne();
+}
+
+/// Verifier shared by `polynomial.ntt` and `polynomial.intt`.
+///
+/// `ring` is the ring of the op's polynomial operand/result and `tensorType`
+/// is the type of its tensor operand/result. Verification succeeds only if:
+///   * the tensor's encoding is exactly `ring`;
+///   * the tensor is 1-D with as many elements as the degree of the ring's
+///     polynomialModulus;
+///   * the ring provides a primitiveRoot that is a primitive n-th root of
+///     unity modulo its coefficientModulus (n being that degree).
+/// polynomialModulus and coefficientModulus are optional ring parameters, so
+/// their absence is diagnosed instead of being dereferenced.
+static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
+                                 RankedTensorType tensorType) {
+  auto encoding = tensorType.getEncoding();
+  if (!encoding) {
+    return op->emitOpError()
+           << "a ring encoding was not provided to the tensor";
+  }
+  auto encodedRing = dyn_cast<RingAttr>(encoding);
+  if (!encodedRing) {
+    return op->emitOpError()
+           << "the provided tensor encoding is not a ring attribute";
+  }
+
+  if (encodedRing != ring) {
+    return op->emitOpError()
+           << "encoded ring type " << encodedRing
+           << " is not equivalent to the polynomial ring " << ring;
+  }
+
+  // The transform size is the degree of the polynomial modulus; without it
+  // there is nothing to check the tensor shape against.
+  auto polyModulus = ring.getPolynomialModulus();
+  if (!polyModulus) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a polynomialModulus, "
+           << "which is required to express an NTT";
+  }
+
+  auto polyDegree = polyModulus.getPolynomial().getDegree();
+  auto tensorShape = tensorType.getShape();
+  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+  if (!compatible) {
+    InFlightDiagnostic diag = op->emitOpError()
+                              << "tensor type " << tensorType
+                              << " does not match output type " << ring;
+    diag.attachNote() << "the tensor must have shape [d] where d "
+                         "is exactly the degree of the polynomialModulus of "
+                         "the polynomial type's ring attribute";
+    return diag;
+  }
+
+  auto primitiveRoot = ring.getPrimitiveRoot();
+  if (!primitiveRoot) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a primitive root "
+           << "of unity, which is required to express an NTT";
+  }
+
+  // The root is only meaningful relative to the coefficient modulus.
+  auto coefficientModulus = ring.getCoefficientModulus();
+  if (!coefficientModulus) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a coefficientModulus, "
+           << "which is required to check its primitive root of unity";
+  }
+
+  if (!isPrimitiveNthRootOfUnity(primitiveRoot.getValue(), polyDegree,
+                                 coefficientModulus.getValue())) {
+    return op->emitOpError()
+           << "ring type " << ring << " has a primitiveRoot attribute '"
+           << primitiveRoot
+           << "' that is not a primitive root of the coefficient ring";
+  }
+
+  return success();
+}
+
+/// Forward NTT: the polynomial is the operand and the tensor is the result,
+/// so the ring is read from the operand type.
+LogicalResult NTTOp::verify() {
+  return verifyNTTOp(getOperation(), getInput().getType().getRing(),
+                     getOutput().getType());
+}
+
+/// Inverse NTT: the tensor is the operand and the polynomial is the result,
+/// so the ring is read from the result type.
+LogicalResult INTTOp::verify() {
+  return verifyNTTOp(getOperation(), getOutput().getType().getRing(),
+                     getInput().getType());
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ea1b279fa1ff96..a29cfc2e9cc549 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -10,9 +10,13 @@
#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
#ideal = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
+// NOTE(review): 193 has multiplicative order 4 modulo 256 (193^2 = 129,
+// 193^4 = 1), so it is not a primitive 1024-th root of unity for #ideal. It is
+// accepted here only because the root is checked by the ntt/intt verifiers
+// and no such op uses #ring in this file -- confirm this is intended.
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
!poly_ty = !polynomial.polynomial<#ring>
+// 31 is a primitive 8-th root of unity modulo 256:
+// 31^2 = 193, 31^4 = 129, 31^8 = 1 (mod 256).
+#ntt_poly = #polynomial.polynomial<-1 + x**8>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+
module {
func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
%c0 = arith.constant 0 : index
@@ -79,4 +83,14 @@ module {
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
return
}
+
+ // Valid forward NTT: #ntt_poly has degree 8, so the result is a 1-D tensor
+ // of 8 elements whose encoding is the operand's ring, #ntt_ring.
+ func.func @test_ntt(%0 : !ntt_poly_ty) {
+ %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+ return
+ }
+
+ // Valid inverse NTT: the operand tensor carries #ntt_ring as its encoding,
+ // which matches the ring of the result polynomial type.
+ func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
+ %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+ return
+ }
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index c34a7de30e5fe5..9029017256be3a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -51,3 +51,90 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
%poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
return %poly : !ty
}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+// The result tensor has no encoding at all, so it does not carry the ring.
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error at +1 {{a ring encoding was not provided to the tensor}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+// The result tensor is encoded with a polynomial attribute (#my_poly) rather
+// than a ring attribute.
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error at +1 {{tensor encoding is not a ring attribute}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+// The operand tensor is encoded with #ring1, which differs from the result's
+// #ring (coefficientModulus 257 vs 256, and only #ring1 has a primitiveRoot).
+func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
+ // expected-error at +1 {{not equivalent to the polynomial ring}}
+ %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+// The operand tensor has 1025 elements, but #my_poly has degree 1024.
+func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
+ // expected-error at below {{does not match output type}}
+ // expected-note at below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
+ %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+// Encoding and shape are valid, but #ring has no primitiveRoot parameter.
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error at +1 {{does not provide a primitive root of unity, which is required to express an NTT}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**8>
+// A valid root is 31
+// 32 is not: 32^2 = 1024 = 0 (mod 256), so no power of 32 is ever 1 mod 256.
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
+ // expected-error at below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
+ %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+ return
+}
>From 2d0fad3b945da7d6b5f1226e7e0a7b635d1f6a4f Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 17:48:49 -0700
Subject: [PATCH 2/5] use ::mlir:: prefix
---
.../include/mlir/Dialect/Polynomial/IR/Polynomial.td | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 93cf363c316c6f..ed1f4ce8b7e599 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -79,7 +79,7 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
#poly = #polynomial.polynomial<x**1024 + 1>
```
}];
- let parameters = (ins "Polynomial":$polynomial);
+ let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
let hasCustomAssemblyFormat = 1;
}
@@ -122,16 +122,16 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
- OptionalParameter<"IntegerAttr">: $coefficientModulus,
- OptionalParameter<"PolynomialAttr">: $polynomialModulus,
- OptionalParameter<"IntegerAttr">: $primitiveRoot
+ OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
+ OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
+ OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
);
let builders = [
AttrBuilder<
(ins "::mlir::Type":$coefficientTy,
- "IntegerAttr":$coefficientModulusAttr,
- "PolynomialAttr":$polynomialModulusAttr), [{
+ "::mlir::IntegerAttr":$coefficientModulusAttr,
+ "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
}]>
];
>From afedacc74786525dc47324af4bcd6941b1ca67bd Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 18:01:28 -0700
Subject: [PATCH 3/5] fold condition into if
---
mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 2c414acc85ec0c..488fd04fbd6cfd 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -204,9 +204,8 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
Polynomial poly = polyAttr.getPolynomial();
APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
- bool hasRoot = succeeded(parser.parseOptionalComma());
IntegerAttr rootAttr = nullptr;
- if (hasRoot) {
+ if (succeeded(parser.parseOptionalComma())) {
if (failed(parser.parseKeyword("primitiveRoot")))
return {};
>From b64988a2f431b62fa9456105a39aa5358f5d9af5 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 18:02:29 -0700
Subject: [PATCH 4/5] expand auto
---
mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 488fd04fbd6cfd..45263a5e97e72d 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -212,7 +212,7 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseEqual()))
return {};
- auto result = parser.parseInteger(root);
+ ParseResult result = parser.parseInteger(root);
if (failed(result)) {
parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
return {};
>From d3fe6b82b3579cbe51889782444a58669756e910 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <kun.jeremy at gmail.com>
Date: Sat, 4 May 2024 18:03:51 -0700
Subject: [PATCH 5/5] Combine failed checks
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 45263a5e97e72d..236bb789663529 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -206,10 +206,8 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
IntegerAttr rootAttr = nullptr;
if (succeeded(parser.parseOptionalComma())) {
- if (failed(parser.parseKeyword("primitiveRoot")))
- return {};
-
- if (failed(parser.parseEqual()))
+ if (failed(parser.parseKeyword("primitiveRoot")) ||
+ failed(parser.parseEqual()))
return {};
ParseResult result = parser.parseInteger(root);
More information about the Mlir-commits
mailing list