[Mlir-commits] [mlir] [mlir][polynomial] ensure primitive root calculation doesn't overflow (PR #93368)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 24 21:35:21 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jeremy Kun (j2kun)
<details>
<summary>Changes</summary>
Rebased over https://github.com/llvm/llvm-project/pull/93243
---
Patch is 24.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93368.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td (+11-4)
- (modified) mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td (+27-6)
- (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td (+22-20)
- (modified) mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp (+30-28)
- (modified) mlir/test/Dialect/Polynomial/canonicalization.mlir (+32-17)
- (modified) mlir/test/Dialect/Polynomial/ops.mlir (+14-4)
- (modified) mlir/test/Dialect/Polynomial/ops_errors.mlir (+12-26)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index f99cbccd243ec..755396c8b9023 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -277,7 +277,6 @@ def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
Polynomial_TypedIntPolynomialAttr
]>;
-// Not deriving from Polynomial_Op due to need for custom assembly format
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
[Pure, InferTypeOpAdaptor]> {
let summary = "Define a constant polynomial via an attribute.";
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
- The choice of primitive root is determined by subsequent lowerings.
+ The choice of primitive root may be optionally specified.
}];
- let arguments = (ins Polynomial_PolynomialType:$input);
+ let arguments = (ins
+ Polynomial_PolynomialType:$input,
+ OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+ );
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
output polynomial at powers of a primitive `n`-th root of unity (see
`polynomial.ntt`). The ring of the polynomial is taken from the required
encoding attribute of the tensor.
+
+ The choice of primitive root may be optionally specified.
}];
- let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+ let arguments = (
+ ins RankedTensorOf<[AnyInteger]>:$input,
+ OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
+ );
let results = (outs Polynomial_PolynomialType:$output);
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
index 655020adf808b..2d3ed60a35fd9 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td
@@ -166,24 +166,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
- OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
- OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
+ OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
);
let assemblyFormat = "`<` struct(params) `>`";
let builders = [
AttrBuilderWithInferredContext<
(ins "::mlir::Type":$coefficientTy,
CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
- CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
- CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
+ CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
return $_get(
coefficientTy.getContext(),
coefficientTy,
coefficientModulusAttr,
- polynomialModulusAttr,
- primitiveRootAttr);
+ polynomialModulusAttr);
}]>,
];
}
+def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
+ let summary = "an attribute containing an integer and its degree as a root of unity";
+ let description = [{
+ A primitive root attribute stores an integer root `value` and an integer
+ `degree`, corresponding to a primitive root of unity of the given degree in
+ an unspecified ring.
+
+ This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
+ to specify the root of unity used in lowering the transform.
+
+ Example:
+
+ ```mlir
+ #poly = #polynomial.primitive_root<value=123 : i32, degree=7 : index>
+ ```
+ }];
+ let parameters = (ins
+ "::mlir::IntegerAttr":$value,
+ "::mlir::IntegerAttr":$degree
+ );
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
#endif // POLYNOMIAL_ATTRIBUTES
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
index e37bcf76a20f2..93ea6e4e43698 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td
@@ -17,6 +17,8 @@ include "mlir/IR/PatternBase.td"
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+def Equal : Constraint<CPred<"$0 == $1">>;
+
// Get a -1 integer attribute of the same type as the polynomial SSA value's
// ring coefficient type.
def getMinusOne
@@ -31,15 +33,15 @@ def SubAsAdd : Pat<
(Arith_ConstantOp (getMinusOne $g))))>;
def INTTAfterNTT : Pat<
- (Polynomial_INTTOp (Polynomial_NTTOp $poly)),
+ (Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
(replaceWithValue $poly),
- []
+ [(Equal $r1, $r2)]
>;
def NTTAfterINTT : Pat<
- (Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
+ (Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
(replaceWithValue $tensor),
- []
+ [(Equal $r1, $r2)]
>;
// NTTs are expensive, and addition in coefficient or NTT domain should be
@@ -47,35 +49,35 @@ def NTTAfterINTT : Pat<
// ntt(a) + ntt(b) -> ntt(a + b)
def NTTOfAdd : Pat<
(Arith_AddIOp
- (Polynomial_NTTOp $p1),
- (Polynomial_NTTOp $p2),
+ (Polynomial_NTTOp $p1, $r1),
+ (Polynomial_NTTOp $p2, $r2),
$overflow),
- (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
- []
+ (Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
+ [(Equal $r1, $r2)]
>;
// intt(a) + intt(b) -> intt(a + b)
def INTTOfAdd : Pat<
(Polynomial_AddOp
- (Polynomial_INTTOp $t1),
- (Polynomial_INTTOp $t2)),
- (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
- []
+ (Polynomial_INTTOp $t1, $r1),
+ (Polynomial_INTTOp $t2, $r2)),
+ (Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
+ [(Equal $r1, $r2)]
>;
// repeated for sub
def NTTOfSub : Pat<
(Arith_SubIOp
- (Polynomial_NTTOp $p1),
- (Polynomial_NTTOp $p2),
+ (Polynomial_NTTOp $p1, $r1),
+ (Polynomial_NTTOp $p2, $r2),
$overflow),
- (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
- []
+ (Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
+ [(Equal $r1, $r2)]
>;
def INTTOfSub : Pat<
(Polynomial_SubOp
- (Polynomial_INTTOp $t1),
- (Polynomial_INTTOp $t2)),
- (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
- []
+ (Polynomial_INTTOp $t1, $r1),
+ (Polynomial_INTTOp $t2, $r2)),
+ (Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
+ [(Equal $r1, $r2)]
>;
#endif // POLYNOMIAL_CANONICALIZATION
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3d302797ce513..b3bb15b6a50ee 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APInt.h"
+#include <iostream>
using namespace mlir;
using namespace mlir::polynomial;
@@ -47,11 +48,8 @@ LogicalResult FromTensorOp::verify() {
return diag;
}
- APInt coefficientModulus = ring.getCoefficientModulus().getValue();
- unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
-
- if (inputBitWidth > cmodBitWidth) {
+ if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
InFlightDiagnostic diag = emitOpError()
<< "input tensor element type "
<< getInput().getType().getElementType()
@@ -108,17 +106,23 @@ LogicalResult MulScalarOp::verify() {
}
/// Test if a value is a primitive nth root of unity modulo cmod.
-bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
const APInt &cmod) {
+ // The first or subsequent multiplications may overflow the input bit width,
+ // so scale them up to ensure they do not overflow.
+ unsigned requiredBitWidth =
+ std::max(root.getActiveBits() * 2, cmod.getBitWidth() * 2);
// Root bitwidth may be 1 less then cmod.
- APInt r = APInt(root).zext(cmod.getBitWidth());
- assert(r.ule(cmod) && "root must be less than cmod");
+ APInt r = APInt(root).zext(requiredBitWidth);
+ APInt cmodExt = APInt(cmod).zext(requiredBitWidth);
+ assert(r.ule(cmodExt) && "root must be less than cmod");
+ uint64_t upperBound = n.getZExtValue();
APInt a = r;
- for (size_t k = 1; k < n; k++) {
+ for (size_t k = 1; k < upperBound; k++) {
if (a.isOne())
return false;
- a = (a * r).urem(cmod);
+ a = (a * r).urem(cmodExt);
}
return a.isOne();
}
@@ -126,7 +130,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
/// Verify that the types involved in an NTT or INTT operation are
/// compatible.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
- RankedTensorType tensorType) {
+ RankedTensorType tensorType,
+ std::optional<PrimitiveRootAttr> root) {
Attribute encoding = tensorType.getEncoding();
if (!encoding) {
return op->emitOpError()
@@ -157,33 +162,30 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
return diag;
}
- if (!ring.getPrimitiveRoot()) {
- return op->emitOpError()
- << "ring type " << ring << " does not provide a primitive root "
- << "of unity, which is required to express an NTT";
- }
-
- if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
- ring.getCoefficientModulus().getValue())) {
- return op->emitOpError()
- << "ring type " << ring << " has a primitiveRoot attribute '"
- << ring.getPrimitiveRoot()
- << "' that is not a primitive root of the coefficient ring";
+ if (root.has_value()) {
+ APInt rootValue = root.value().getValue().getValue();
+ APInt rootDegree = root.value().getDegree().getValue();
+ APInt cmod = ring.getCoefficientModulus().getValue();
+ if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
+ return op->emitOpError()
+ << "provided root " << rootValue.getZExtValue()
+ << " is not a primitive root "
+ << "of unity mod " << cmod.getZExtValue()
+ << ", with the specified degree " << rootDegree.getZExtValue();
+ }
}
return success();
}
LogicalResult NTTOp::verify() {
- auto ring = getInput().getType().getRing();
- auto tensorType = getOutput().getType();
- return verifyNTTOp(this->getOperation(), ring, tensorType);
+ return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
+ getOutput().getType(), getRoot());
}
LogicalResult INTTOp::verify() {
- auto tensorType = getInput().getType();
- auto ring = getOutput().getType().getRing();
- return verifyNTTOp(this->getOperation(), ring, tensorType);
+ return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
+ getInput().getType(), getRoot());
}
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
diff --git a/mlir/test/Dialect/Polynomial/canonicalization.mlir b/mlir/test/Dialect/Polynomial/canonicalization.mlir
index 489d9ec2720d6..b79938627e415 100644
--- a/mlir/test/Dialect/Polynomial/canonicalization.mlir
+++ b/mlir/test/Dialect/Polynomial/canonicalization.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt -canonicalize %s | FileCheck %s
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
+#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
!tensor_ty = tensor<8xi32, #ntt_ring>
@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
// CHECK-NOT: polynomial.ntt
// CHECK-NOT: polynomial.intt
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
- %t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
- %p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
+ %t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
// CHECK: return %[[RESULT]] : [[T]]
return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
// CHECK-NOT: polynomial.intt
// CHECK-NOT: polynomial.ntt
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
- %p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
- %t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
+ %p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
%t2 = arith.addi %t1, %t1 : !tensor_ty
// CHECK: return %[[RESULT]] : [[T]]
return %t2 : !tensor_ty
@@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
func.func @test_canonicalize_fold_add_through_ntt(
%poly0 : !ntt_poly_ty,
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
- %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
- %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
%a_plus_b = arith.addi %0, %1 : !tensor_ty
- %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+ %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
return %out : !ntt_poly_ty
}
@@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
func.func @test_canonicalize_fold_add_through_intt(
%tensor0 : !tensor_ty,
%tensor1 : !tensor_ty) -> !tensor_ty {
- %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
- %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
%a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
- %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+ %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
return %out : !tensor_ty
}
@@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
func.func @test_canonicalize_fold_sub_through_ntt(
%poly0 : !ntt_poly_ty,
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
- %0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
- %1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
+ %0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
%a_plus_b = arith.subi %0, %1 : !tensor_ty
- %out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
+ %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
return %out : !ntt_poly_ty
}
@@ -94,9 +95,23 @@ func.func @test_canonicalize_fold_sub_through_ntt(
func.func @test_canonicalize_fold_sub_through_intt(
%tensor0 : !tensor_ty,
%tensor1 : !tensor_ty) -> !tensor_ty {
- %0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
- %1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
+ %0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
+ %1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
%a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
- %out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
+ %out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
return %out : !tensor_ty
}
+
+
+// CHECK-LABEL: test_canonicalize_do_not_fold_different_roots
+// CHECK: arith.addi
+func.func @test_canonicalize_do_not_fold_different_roots(
+ %poly0 : !ntt_poly_ty,
+ %poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
+ %0 = polynomial.ntt %poly0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+ %1 = polynomial.ntt %poly1 {root=#polynomial.primitive_root<value=33:i32, degree=8:index>} : !ntt_poly_ty -> !tensor_ty
+ %a_plus_b = arith.addi %0, %1 : !tensor_ty
+ %out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
+ return %out : !ntt_poly_ty
+}
+
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index 4716e37ff8852..282c631805a4e 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -11,13 +11,18 @@
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
#ideal = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
!poly_ty = !polynomial.polynomial<ring=#ring>
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
-#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
+#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
+#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
+#ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
+!ntt_poly_ty_2 = !polynomial.polynomial<ring=#ring>
+
module {
func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
%c0 = arith.constant 0 : index
@@ -91,12 +96,17 @@ module {
}
func.func @test_ntt(%0 : !ntt_poly_ty) {
- %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+ %1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+ return
+ }
+
+ func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
+ %1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<8xi32, #ntt_ring_2>
return
}
func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
- %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+ %1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
return
}
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index af8e4aa5da862..4937e17027afa 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt --split-input-file --verify-diagnostics %s
#my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+#ring = #polynomial.ring<coefficientType=i16>
!ty = !polynomial.polynomial<ring=#ring>
func.func @test_from_tensor_too_large_coeffs() {
@@ -55,28 +55,28 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
// -----
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
!poly_ty = !polynomial.polynomial<ring=#ring>
// CHECK-NOT: @test_invalid_ntt
// CHECK-NOT: polynomial.ntt
func.func @test_invalid_ntt(%0 : !poly_ty) {
// expected-error@below...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/93368
More information about the Mlir-commits
mailing list