[Mlir-commits] [mlir] [polynomial] Move primitive root attribute to ntt/intt ops. (PR #93227)
Jeremy Kun
llvmlistbot at llvm.org
Wed May 29 20:17:50 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/93227
>From 9978efaa5bd986ecda4d5083dfe630c773c11525 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:20:49 -0700
Subject: [PATCH 1/4] move primitive root attr to ntt/intt ops
---
.../mlir/Dialect/Polynomial/IR/Polynomial.td | 15 +++++--
.../Polynomial/IR/PolynomialAttributes.td | 33 ++++++++++++---
.../IR/PolynomialCanonicalization.td | 42 ++++++++++---------
.../Dialect/Polynomial/IR/PolynomialOps.cpp | 40 +++++++++---------
4 files changed, 79 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index f99cbccd243ec..755396c8b9023 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -277,7 +277,6 @@ def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
Polynomial_TypedIntPolynomialAttr
]>;
-// Not deriving from Polynomial_Op due to need for custom assembly format
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
[Pure, InferTypeOpAdaptor]> {
let summary = "Define a constant polynomial via an attribute.";
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
- The choice of primitive root is determined by subsequent lowerings.
+ The primitive root may optionally be specified via the `root` attribute.
}];
- let arguments = (ins Polynomial_PolynomialType:$input);
+ let arguments = (ins
+ Polynomial_PolynomialType:$input,
+ OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+ );
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
output polynomial at powers of a primitive `n`-th root of unity (see
`polynomial.ntt`). The ring of the polynomial is taken from the required
encoding attribute of the tensor.
+
+ The primitive root may optionally be specified via the `root` attribute.
}];
- let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+ let arguments = (
+ ins RankedTensorOf<[AnyInteger]>:$input,
+ OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+ );
let results = (outs Polynomial_PolynomialType:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 655020adf808b..2d3ed60a35fd9 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -166,24 +166,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
- OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
- OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
+ OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
);
let assemblyFormat = "`<` struct(params) `>`";
let builders = [
AttrBuilderWithInferredContext<
(ins "::mlir::Type":$coefficientTy,
CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
- CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
- CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+ CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
return $_get(
coefficientTy.getContext(),
coefficientTy,
coefficientModulusAttr,
- polynomialModulusAttr,
- primitiveRootAttr);
+ polynomialModulusAttr);
}]>,
];
}
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+ let summary = "an attribute containing an integer and its degree as a root of unity";
+ let description = [{
+ A primitive root attribute stores an integer root `value` and an integer
+ `degree`, corresponding to a primitive root of unity of the given degree in
+ an unspecified ring.
+
+ This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
+ to specify the root of unity used in lowering the transform.
+
+ Example:
+
+ ```mlir
+ #root = #polynomial.primitive_root<value=123 : i32, degree=7 : index>
+ ```
+ }];
+ let parameters = (ins
+ "::mlir::IntegerAttr":$value,
+ "::mlir::IntegerAttr":$degree
+ );
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
#endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index e37bcf76a20f2..93ea6e4e43698 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -17,6 +17,8 @@ include "mlir/IR/PatternBase.td"
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+def Equal : Constraint<CPred<"$0 == $1">>;
+
// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
@@ -31,15 +33,15 @@ def SubAsAdd : Pat<
(Arith_ConstantOp (getMinusOne $g))))>;
def INTTAfterNTT : Pat<
- (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+ (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
(replaceWithValue $poly),
- []
+ [(Equal $r1, $r2)]
>;
def NTTAfterINTT : Pat<
- (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+ (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
(replaceWithValue $tensor),
- []
+ [(Equal $r1, $r2)]
>;
// NTTs are expensive, and addition in coefficient or NTT domain should be
@@ -47,35 +49,35 @@ def NTTAfterINTT : Pat<
// ntt(a) + ntt(b) -> ntt(a + b)
def NTTOfAdd : Pat<
(Arith_AddIOp
- (Polynomial_NTTOp $p1),
- (Polynomial_NTTOp $p2),
+ (Polynomial_NTTOp $p1, $r1),
+ (Polynomial_NTTOp $p2, $r2),
$overflow),
- (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
- []
+ (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
+ [(Equal $r1, $r2)]
>;
// intt(a) + intt(b) -> intt(a + b)
def INTTOfAdd : Pat<
(Polynomial_AddOp
- (Polynomial_INTTOp $t1),
- (Polynomial_INTTOp $t2)),
- (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
- []
+ (Polynomial_INTTOp $t1, $r1),
+ (Polynomial_INTTOp $t2, $r2)),
+ (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
+ [(Equal $r1, $r2)]
>;
// repeated for sub
def NTTOfSub : Pat<
(Arith_SubIOp
- (Polynomial_NTTOp $p1),
- (Polynomial_NTTOp $p2),
+ (Polynomial_NTTOp $p1, $r1),
+ (Polynomial_NTTOp $p2, $r2),
$overflow),
- (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
- []
+ (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
+ [(Equal $r1, $r2)]
>;
def INTTOfSub : Pat<
(Polynomial_SubOp
- (Polynomial_INTTOp $t1),
- (Polynomial_INTTOp $t2)),
- (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
- []
+ (Polynomial_INTTOp $t1, $r1),
+ (Polynomial_INTTOp $t2, $r2)),
+ (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
+ [(Equal $r1, $r2)]
>;
#endif // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3d302797ce513..a39c3872ad6d5 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() {
}
/// Test if a value is a primitive nth root of unity modulo cmod.
-bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
const APInt &cmod) {
// Root bitwidth may be 1 less then cmod.
APInt r = APInt(root).zext(cmod.getBitWidth());
assert(r.ule(cmod) && "root must be less than cmod");
+ unsigned upperBound = n.getZExtValue();
APInt a = r;
- for (size_t k = 1; k < n; k++) {
+ for (size_t k = 1; k < upperBound; k++) {
if (a.isOne())
return false;
a = (a * r).urem(cmod);
@@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
/// Verify that the types involved in an NTT or INTT operation are
/// compatible.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
- RankedTensorType tensorType) {
+ RankedTensorType tensorType,
+ std::optional<PrimitiveRootAttr> root) {
Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
@@ -157,33 +159,29 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
return diag;
}
- if (!ring.getPrimitiveRoot()) {
- return op->emitOpError()
- << "ring type " << ring << " does not provide a primitive root "
- << "of unity, which is required to express an NTT";
- }
-
- if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
- ring.getCoefficientModulus().getValue())) {
- return op->emitOpError()
- << "ring type " << ring << " has a primitiveRoot attribute '"
- << ring.getPrimitiveRoot()
- << "' that is not a primitive root of the coefficient ring";
+ if (root.has_value()) {
+ APInt rootValue = root.value().getValue().getValue();
+ APInt rootDegree = root.value().getDegree().getValue();
+ APInt cmod = ring.getCoefficientModulus().getValue();
+ if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
+ return op->emitOpError()
+ << "provided root " << rootValue.getZExtValue() << " is not a primitive root "
+ << "of unity mod " << cmod.getZExtValue() << ", with the specified degree "
+ << rootDegree.getZExtValue();
+ }
}
return success();
}
LogicalResult NTTOp::verify() {
- auto ring = getInput().getType().getRing();
- auto tensorType = getOutput().getType();
- return verifyNTTOp(this->getOperation(), ring, tensorType);
+ return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
+ getOutput().getType(), getRoot());
}
LogicalResult INTTOp::verify() {
- auto tensorType = getInput().getType();
- auto ring = getOutput().getType().getRing();
- return verifyNTTOp(this->getOperation(), ring, tensorType);
+ return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
+ getInput().getType(), getRoot());
}
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
>From a46ef26656ae5e57382eb91eeb6d2843c5770a20 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:29:56 -0700
Subject: [PATCH 2/4] update tests
---
.../Dialect/Polynomial/canonicalization.mlir | 35 +++++++++---------
mlir/test/Dialect/Polynomial/ops.mlir | 8 ++---
mlir/test/Dialect/Polynomial/ops_errors.mlir | 36 ++++++-------------
3 files changed, 33 insertions(+), 46 deletions(-)
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 489d9ec2720d6..354b76e3d9669 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt -canonicalize %s | FileCheck %s
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
!tensor_ty = tensor<8xi32, #ntt_ring>
@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
// CHECK-NOT: polynomial.ntt
// CHECK-NOT: polynomial.intt
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
- %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
- %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+ %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
// CHECK: return %[[RESULT]] : [[T]]
return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
// CHECK-NOT: polynomial.intt
// CHECK-NOT: polynomial.ntt
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
- %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
- %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+ %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%t2 = arith.addi %t1, %t1 : !tensor_ty
// CHECK: return %[[RESULT]] : [[T]]
return %t2 : !tensor_ty
@@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
func.func @test_canonicalize_fold_add_through_ntt(
%poly0 : !ntt_poly_ty,
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
- %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
- %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
%a_plus_b = arith.addi %0, %1 : !tensor_ty
- %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+ %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
return %out : !ntt_poly_ty
}
@@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
func.func @test_canonicalize_fold_add_through_intt(
%tensor0 : !tensor_ty,
%tensor1 : !tensor_ty) -> !tensor_ty {
- %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
- %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
%a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
- %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+ %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
return %out : !tensor_ty
}
@@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
func.func @test_canonicalize_fold_sub_through_ntt(
%poly0 : !ntt_poly_ty,
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
- %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
- %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
%a_plus_b = arith.subi %0, %1 : !tensor_ty
- %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+ %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
return %out : !ntt_poly_ty
}
@@ -94,9 +95,9 @@ func.func @test_canonicalize_fold_sub_through_ntt(
func.func @test_canonicalize_fold_sub_through_intt(
%tensor0 : !tensor_ty,
%tensor1 : !tensor_ty) -> !tensor_ty {
- %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
- %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
%a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
- %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+ %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
return %out : !tensor_ty
}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 4716e37ff8852..8c134ab789d60 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -11,11 +11,11 @@
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
#ideal = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
!poly_ty = !polynomial.polynomial<ring=#ring>
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
module {
@@ -91,12 +91,12 @@ module {
}
func.func @test_ntt(%0 : !ntt_poly_ty) {
- %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+ %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
return
}
func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
- %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+ %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
return
}
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index af8e4aa5da862..f22b14897e98a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -55,28 +55,28 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
// -----
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
  // expected-error@below {{expects a ring encoding to be provided to the tensor}}
- %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+ %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
return
}
// -----
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
  // expected-error@below {{tensor encoding is not a ring attribute}}
- %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+ %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
return
}
@@ -84,21 +84,21 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
  // expected-error@below {{not equivalent to the polynomial ring}}
- %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+ %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1024xi32, #ring1> -> !poly_ty
return
}
// -----
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_intt
@@ -106,21 +106,7 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
  // expected-error@below {{does not match output type}}
  // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
- %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
- return
-}
-
-// -----
-
-#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
-!poly_ty = !polynomial.polynomial<ring=#ring>
-
-// CHECK-NOT: @test_invalid_ntt
-// CHECK-NOT: polynomial.ntt
-func.func @test_invalid_ntt(%0 : !poly_ty) {
- // expected-error@below {{does not provide a primitive root of unity, which is required to express an NTT}}
- %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+ %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
return
}
@@ -128,13 +114,13 @@ func.func @test_invalid_ntt(%0 : !poly_ty) {
#my_poly = #polynomial.int_polynomial<-1 + x**8>
// A valid root is 31
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=32:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_intt
// CHECK-NOT: polynomial.intt
func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
- // expected-error@below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
- %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+ // expected-error@below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}}
+ %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=32:i16, degree=8:index>} : tensor<8xi32, #ring> -> !poly_ty
return
}
>From 77c140a83e8232009c37a6aa3c255218607e056c Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 11:42:39 -0700
Subject: [PATCH 3/4] add test for root equality
---
mlir/test/Dialect/Polynomial/canonicalization.mlir | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 354b76e3d9669..b79938627e415 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -101,3 +101,17 @@ func.func @test_canonicalize_fold_sub_through_intt(
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
return %out : !tensor_ty
}
+
+
+// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
+// CHECK: arith.addi
+func.func @test_canonicalize_do_not_fold_different_roots(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+ %a_plus_b = arith.addi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
>From 5265598a3be501dbe122646935f3a13fba37d880 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 12:33:55 -0700
Subject: [PATCH 4/4] clang format
---
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index a39c3872ad6d5..76e837404c6b5 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -165,9 +165,10 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
APInt cmod = ring.getCoefficientModulus().getValue();
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
return op->emitOpError()
- << "provided root " << rootValue.getZExtValue() << " is not a primitive root "
- << "of unity mod " << cmod.getZExtValue() << ", with the specified degree "
- << rootDegree.getZExtValue();
+ << "provided root " << rootValue.getZExtValue()
+ << " is not a primitive root "
+ << "of unity mod " << cmod.getZExtValue()
+ << ", with the specified degree " << rootDegree.getZExtValue();
}
}
More information about the Mlir-commits
mailing list