[Mlir-commits] [mlir] Upstream polynomial.ntt and polynomial.intt (PR #90992)
Jeremy Kun
llvmlistbot at llvm.org
Sat May 4 18:11:07 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/90992
>From e6df81ebba5daff2d8ae3437c3a02cb6ef9051ee Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 2 May 2024 16:41:44 -0700
Subject: [PATCH 01/11] Upstream polynomial.ntt and polynomial.intt
---
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 52 ++++++++++-
.../Polynomial/IR/PolynomialAttributes.cpp | 21 ++++-
.../Dialect/Polynomial/IR/PolynomialOps.cpp | 77 ++++++++++++++++
mlir/test/Dialect/Polynomial/ops.mlir | 16 +++-
mlir/test/Dialect/Polynomial/ops_errors.mlir | 87 +++++++++++++++++++
5 files changed, 250 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index d3e3ac55677f86..93cf363c316c6f 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -123,9 +123,18 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"IntegerAttr">: $coefficientModulus,
- OptionalParameter<"PolynomialAttr">: $polynomialModulus
+ OptionalParameter<"PolynomialAttr">: $polynomialModulus,
+ OptionalParameter<"IntegerAttr">: $primitiveRoot
);
+ let builders = [
+ AttrBuilder<
+ (ins "::mlir::Type":$coefficientTy,
+ "IntegerAttr":$coefficientModulusAttr,
+ "PolynomialAttr":$polynomialModulusAttr), [{
+ return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
+ }]>
+ ];
let hasCustomAssemblyFormat = 1;
}
@@ -416,4 +425,45 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let assemblyFormat = "$input attr-dict `:` type($output)";
}
+def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
+  let summary = "Computes point-value tensor representation of a polynomial.";
+  let description = [{
+    `polynomial.ntt` computes the forward integer Number Theoretic Transform
+    (NTT) on the input polynomial. It returns a tensor containing a point-value
+    representation of the input polynomial. The output tensor has shape `[n]`,
+    where `n` is the degree of the ring's `polynomialModulus`. The polynomial's
+    RingAttr is embedded as the encoding attribute of the output tensor.
+
+    Given an input polynomial `F(x)` over a ring whose `polynomialModulus` has
+    degree `n`, and a primitive `n`-th root of unity `omega_n`, the output is
+    the list of `n` evaluations
+
+      `f[k] = F(omega_n^k) ; k = {0, ..., n-1}`
+
+    The primitive root `omega_n` is the ring's `primitiveRoot` parameter. It is
+    required: the verifier rejects rings that do not provide one, and checks
+    that it is a primitive `n`-th root of unity modulo the ring's
+    `coefficientModulus`.
+  }];
+  let arguments = (ins Polynomial_PolynomialType:$input);
+  let results = (outs RankedTensorOf<[AnyInteger]>:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
+def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
+  let summary = "Computes the inverse integer Number Theoretic Transform (INTT).";
+  let description = [{
+    `polynomial.intt` computes the inverse integer Number Theoretic Transform
+    (INTT) on the input tensor. This is the inverse operation of the
+    `polynomial.ntt` operation.
+
+    The input tensor is interpreted as a point-value representation of the
+    output polynomial at powers of a primitive `n`-th root of unity (see
+    `polynomial.ntt`). The ring of the polynomial is taken from the required
+    encoding attribute of the tensor.
+
+    The verifier applies the same checks as for `polynomial.ntt`:
+    - the tensor's encoding must equal the output polynomial's ring;
+    - the tensor must have shape `[n]`, where `n` is the degree of that ring's
+      `polynomialModulus`;
+    - the ring must provide a `primitiveRoot` that is a primitive `n`-th root
+      of unity modulo its `coefficientModulus`.
+  }];
+  let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+  let results = (outs Polynomial_PolynomialType:$output);
+  let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+  let hasVerifier = 1;
+}
+
#endif // POLYNOMIAL_OPS
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index f1ec2be72a33ab..2c414acc85ec0c 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -202,11 +202,30 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
polyAttr = attr;
}
+ Polynomial poly = polyAttr.getPolynomial();
+ APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
+ bool hasRoot = succeeded(parser.parseOptionalComma());
+ IntegerAttr rootAttr = nullptr;
+ if (hasRoot) {
+ if (failed(parser.parseKeyword("primitiveRoot")))
+ return {};
+
+ if (failed(parser.parseEqual()))
+ return {};
+
+ auto result = parser.parseInteger(root);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
+ return {};
+ }
+ rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+ }
+
if (failed(parser.parseGreater()))
return {};
return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
- polyAttr);
+ polyAttr, rootAttr);
}
} // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 8e2bb5f27dc6cc..1d5c7be4b6752a 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -104,3 +104,80 @@ LogicalResult MulScalarOp::verify() {
return success();
}
+
+// Test if a value is a primitive nth root of unity modulo cmod
+bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+                               const APInt &cmod) {
+  // root bitwidth may be 1 less then cmod
+  APInt r = APInt(root).zext(cmod.getBitWidth());
+  assert(r.ule(cmod) && "root must be less than cmod");
+
+  // `root` is a primitive n-th root of unity iff r^n == 1 (mod cmod) and
+  // r^k != 1 (mod cmod) for every 0 < k < n. The loop keeps a == r^k
+  // (mod cmod) and bails out as soon as a smaller power reaches 1.
+  //
+  // Both factors of `a * r` are residues below cmod, so their product needs up
+  // to twice cmod's bit width. Multiplying at cmod's own width would silently
+  // wrap modulo 2^width before the urem and yield a wrong residue (e.g. for a
+  // 32-bit modulus close to 2^32), so the arithmetic is done in double width.
+  unsigned wideWidth = 2 * cmod.getBitWidth();
+  APInt wideRoot = r.zext(wideWidth);
+  APInt wideMod = cmod.zext(wideWidth);
+  APInt a = wideRoot;
+  for (size_t k = 1; k < n; k++) {
+    if (a.isOne())
+      return false;
+    a = (a * wideRoot).urem(wideMod);
+  }
+  return a.isOne();
+}
+
+static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
+                                 RankedTensorType tensorType) {
+  auto encoding = tensorType.getEncoding();
+  if (!encoding) {
+    return op->emitOpError()
+           << "a ring encoding was not provided to the tensor";
+  }
+  auto encodedRing = dyn_cast<RingAttr>(encoding);
+  if (!encodedRing) {
+    return op->emitOpError()
+           << "the provided tensor encoding is not a ring attribute";
+  }
+
+  // polynomialModulus is an optional ring parameter, but the transform length
+  // is derived from its degree below; dereferencing a null attribute there
+  // would crash instead of producing a diagnostic.
+  if (!ring.getPolynomialModulus()) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a polynomialModulus, "
+           << "which is required to express an NTT";
+  }
+
+  if (encodedRing != ring) {
+    return op->emitOpError()
+           << "encoded ring type " << encodedRing
+           << " is not equivalent to the polynomial ring " << ring;
+  }
+
+  auto polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
+  auto tensorShape = tensorType.getShape();
+  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+  if (!compatible) {
+    InFlightDiagnostic diag = op->emitOpError()
+                              << "tensor type " << tensorType
+                              << " does not match output type " << ring;
+    diag.attachNote() << "the tensor must have shape [d] where d "
+                         "is exactly the degree of the polynomialModulus of "
+                         "the polynomial type's ring attribute";
+    return diag;
+  }
+
+  if (!ring.getPrimitiveRoot()) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a primitive root "
+           << "of unity, which is required to express an NTT";
+  }
+
+  // coefficientModulus is optional as well, and the primitiveRoot can only be
+  // validated modulo it.
+  if (!ring.getCoefficientModulus()) {
+    return op->emitOpError()
+           << "ring type " << ring << " does not provide a coefficientModulus, "
+           << "which is required to express an NTT";
+  }
+
+  // Only reduced residues (root < coefficientModulus) are accepted. Checking
+  // this here turns an out-of-range user-provided root into a diagnostic
+  // instead of tripping the assertion inside isPrimitiveNthRootOfUnity.
+  APInt root = ring.getPrimitiveRoot().getValue();
+  APInt cmod = ring.getCoefficientModulus().getValue();
+  if (root.zext(cmod.getBitWidth()).uge(cmod) ||
+      !isPrimitiveNthRootOfUnity(root, polyDegree, cmod)) {
+    return op->emitOpError()
+           << "ring type " << ring << " has a primitiveRoot attribute '"
+           << ring.getPrimitiveRoot()
+           << "' that is not a primitive root of the coefficient ring";
+  }
+
+  return success();
+}
+
+LogicalResult NTTOp::verify() {
+  // Forward transform: the ring comes from the polynomial operand and the
+  // tensor being checked is the op's result.
+  return verifyNTTOp(getOperation(), getInput().getType().getRing(),
+                     getOutput().getType());
+}
+
+LogicalResult INTTOp::verify() {
+  // Inverse transform: the tensor being checked is the operand and the ring
+  // comes from the polynomial result.
+  return verifyNTTOp(getOperation(), getOutput().getType().getRing(),
+                     getInput().getType());
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ea1b279fa1ff96..a29cfc2e9cc549 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -10,9 +10,13 @@
#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
#ideal = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
+// NOTE: primitiveRoot=193 only exercises parsing of the parameter. It is not a
+// primitive 1024-th root of unity mod 256: the units mod 256 form a group of
+// order 128, and 193 = 31^2 has multiplicative order 4. In this patch the
+// primitiveRoot is only validated by the polynomial.ntt / polynomial.intt
+// verifiers, so #ring must not be used with those ops.
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
!poly_ty = !polynomial.polynomial<#ring>
+#ntt_poly = #polynomial.polynomial<-1 + x**8>
+// 31 has multiplicative order exactly 8 mod 256 (31^2 = 193, 31^4 = 129,
+// 31^8 = 1), so it is a primitive 8-th root of unity for the degree-8 modulus.
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+
module {
func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
%c0 = arith.constant 0 : index
@@ -79,4 +83,14 @@ module {
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
return
}
+
+ // polynomial.ntt over #ntt_ring (degree-8 polynomialModulus) yields a rank-1
+ // tensor of 8 elements whose encoding is that same ring.
+ func.func @test_ntt(%0 : !ntt_poly_ty) {
+ %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+ return
+ }
+
+ // polynomial.intt is the inverse direction: an 8-element tensor encoded with
+ // #ntt_ring maps back to a polynomial in that ring.
+ func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
+ %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+ return
+ }
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index c34a7de30e5fe5..9029017256be3a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -51,3 +51,90 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
%poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
return %poly : !ty
}
+
+// -----
+
+// polynomial.ntt must reject a result tensor that carries no encoding at all.
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error at +1 {{a ring encoding was not provided to the tensor}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+ return
+}
+
+// -----
+
+// polynomial.ntt must reject a tensor encoding that is not a #polynomial.ring
+// (here a #polynomial.polynomial attribute is used instead).
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error at +1 {{tensor encoding is not a ring attribute}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+ return
+}
+
+// -----
+
+// polynomial.intt must reject a tensor whose ring encoding differs from the
+// result polynomial's ring: #ring1 uses coefficientModulus=257 (vs 256) and
+// also carries a primitiveRoot.
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
+ // expected-error at +1 {{not equivalent to the polynomial ring}}
+ %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+ return
+}
+
+// -----
+
+// polynomial.intt must reject a tensor whose shape is not [d], where d is the
+// degree of the ring's polynomialModulus (1025 != 1024 here).
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
+ // expected-error at below {{does not match output type}}
+ // expected-note at below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
+ %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
+ return
+}
+
+// -----
+
+// polynomial.ntt must reject a ring that provides no primitiveRoot.
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error at +1 {{does not provide a primitive root of unity, which is required to express an NTT}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+ return
+}
+
+// -----
+
+// polynomial.intt must reject a primitiveRoot that is not a primitive 8-th
+// root of unity modulo the coefficientModulus.
+#my_poly = #polynomial.polynomial<-1 + x**8>
+// A valid root is 31: it has multiplicative order exactly 8 mod 256.
+// 32 is even, so it is not a unit mod 256 and none of its powers equals 1.
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
+ // expected-error at below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
+ %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+ return
+}
>From 2d0fad3b945da7d6b5f1226e7e0a7b635d1f6a4f Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 17:48:49 -0700
Subject: [PATCH 02/11] use ::mlir:: prefix
---
.../include/mlir/Dialect/Polynomial/IR/Polynomial.td | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index 93cf363c316c6f..ed1f4ce8b7e599 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -79,7 +79,7 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
#poly = #polynomial.polynomial<x**1024 + 1>
```
}];
- let parameters = (ins "Polynomial":$polynomial);
+ let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
let hasCustomAssemblyFormat = 1;
}
@@ -122,16 +122,16 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
- OptionalParameter<"IntegerAttr">: $coefficientModulus,
- OptionalParameter<"PolynomialAttr">: $polynomialModulus,
- OptionalParameter<"IntegerAttr">: $primitiveRoot
+ OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
+ OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
+ OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
);
let builders = [
AttrBuilder<
(ins "::mlir::Type":$coefficientTy,
- "IntegerAttr":$coefficientModulusAttr,
- "PolynomialAttr":$polynomialModulusAttr), [{
+ "::mlir::IntegerAttr":$coefficientModulusAttr,
+ "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
}]>
];
>From afedacc74786525dc47324af4bcd6941b1ca67bd Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 18:01:28 -0700
Subject: [PATCH 03/11] fold condition into if
---
mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 2c414acc85ec0c..488fd04fbd6cfd 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -204,9 +204,8 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
Polynomial poly = polyAttr.getPolynomial();
APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
- bool hasRoot = succeeded(parser.parseOptionalComma());
IntegerAttr rootAttr = nullptr;
- if (hasRoot) {
+ if (succeeded(parser.parseOptionalComma())) {
if (failed(parser.parseKeyword("primitiveRoot")))
return {};
>From b64988a2f431b62fa9456105a39aa5358f5d9af5 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 18:02:29 -0700
Subject: [PATCH 04/11] expand auto
---
mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 488fd04fbd6cfd..45263a5e97e72d 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -212,7 +212,7 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseEqual()))
return {};
- auto result = parser.parseInteger(root);
+ ParseResult result = parser.parseInteger(root);
if (failed(result)) {
parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
return {};
>From d3fe6b82b3579cbe51889782444a58669756e910 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <kun.jeremy at gmail.com>
Date: Sat, 4 May 2024 18:03:51 -0700
Subject: [PATCH 05/11] Combine failed checks
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index 45263a5e97e72d..236bb789663529 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -206,10 +206,8 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
IntegerAttr rootAttr = nullptr;
if (succeeded(parser.parseOptionalComma())) {
- if (failed(parser.parseKeyword("primitiveRoot")))
- return {};
-
- if (failed(parser.parseEqual()))
+ if (failed(parser.parseKeyword("primitiveRoot")) ||
+ failed(parser.parseEqual()))
return {};
ParseResult result = parser.parseInteger(root);
>From 3ea6743cc2986597f33be580a174ae80cf976be7 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <kun.jeremy at gmail.com>
Date: Sat, 4 May 2024 18:04:03 -0700
Subject: [PATCH 06/11] Fix docstring
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 1d5c7be4b6752a..baa6615298cd64 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -105,7 +105,7 @@ LogicalResult MulScalarOp::verify() {
return success();
}
-// Test if a value is a primitive nth root of unity modulo cmod
+/// Test if a value is a primitive nth root of unity modulo cmod.
bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
const APInt &cmod) {
// root bitwidth may be 1 less then cmod
>From f42cae677f89c493dad7d8f9c9a18bc22aaf9146 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <kun.jeremy at gmail.com>
Date: Sat, 4 May 2024 18:04:17 -0700
Subject: [PATCH 07/11] Fix docstring
Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index baa6615298cd64..eaf9fd32b38ec7 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -108,7 +108,7 @@ LogicalResult MulScalarOp::verify() {
/// Test if a value is a primitive nth root of unity modulo cmod.
bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
const APInt &cmod) {
- // root bitwidth may be 1 less then cmod
+ // Root bitwidth may be 1 less then cmod.
APInt r = APInt(root).zext(cmod.getBitWidth());
assert(r.ule(cmod) && "root must be less than cmod");
>From 445d948f7839edb0f8f4ed4716a08c40f39f4c65 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 18:05:58 -0700
Subject: [PATCH 08/11] document top level fn
---
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index eaf9fd32b38ec7..0626d39b90d830 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -121,6 +121,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
return a.isOne();
}
+/// Verify that the types involved in an NTT or INTT operation are
+/// compatible.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
RankedTensorType tensorType) {
auto encoding = tensorType.getEncoding();
>From e0a06ba2aefbcd328b4040b184e6a9463aebb130 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 18:06:59 -0700
Subject: [PATCH 09/11] Expand more auto
---
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 0626d39b90d830..cce3fdb704ace6 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -125,7 +125,7 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
/// compatible.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
RankedTensorType tensorType) {
- auto encoding = tensorType.getEncoding();
+ Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
<< "a ring encoding was not provided to the tensor";
@@ -142,8 +142,8 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
<< " is not equivalent to the polynomial ring " << ring;
}
- auto polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
- auto tensorShape = tensorType.getShape();
+ unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
+ ArrayRef<int64_t> tensorShape = tensorType.getShape();
bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
if (!compatible) {
InFlightDiagnostic diag = op->emitOpError()
>From 394d604eb1c2fb03f04c95be6f239198b258ab33 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 18:08:42 -0700
Subject: [PATCH 10/11] correct error message
---
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 2 +-
mlir/test/Dialect/Polynomial/ops_errors.mlir | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index cce3fdb704ace6..12010de3482377 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -128,7 +128,7 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
- << "a ring encoding was not provided to the tensor";
+ << "expects a ring encoding to be provided to the tensor";
}
auto encodedRing = dyn_cast<RingAttr>(encoding);
if (!encodedRing) {
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index 9029017256be3a..f20595e2d538a8 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -61,7 +61,7 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
- // expected-error at +1 {{a ring encoding was not provided to the tensor}}
+ // expected-error at +1 {{expects a ring encoding to be provided to the tensor}}
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
return
}
>From 4120f2f69ec7a3dd782b9735a98b437279104862 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 4 May 2024 18:10:53 -0700
Subject: [PATCH 11/11] use below everywhere
---
mlir/test/Dialect/Polynomial/ops_errors.mlir | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index f20595e2d538a8..2c20e7bcbf1d69 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -61,7 +61,7 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
- // expected-error at +1 {{expects a ring encoding to be provided to the tensor}}
+ // expected-error at below {{expects a ring encoding to be provided to the tensor}}
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
return
}
@@ -75,7 +75,7 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
- // expected-error at +1 {{tensor encoding is not a ring attribute}}
+ // expected-error at below {{tensor encoding is not a ring attribute}}
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
return
}
@@ -90,7 +90,7 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
- // expected-error at +1 {{not equivalent to the polynomial ring}}
+ // expected-error at below {{not equivalent to the polynomial ring}}
%1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
return
}
@@ -119,7 +119,7 @@ func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
- // expected-error at +1 {{does not provide a primitive root of unity, which is required to express an NTT}}
+ // expected-error at below {{does not provide a primitive root of unity, which is required to express an NTT}}
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
return
}
More information about the Mlir-commits
mailing list