[Mlir-commits] [mlir] [mlir][polynomial] verify from_tensor coeff type (PR #93243)
Jeremy Kun
llvmlistbot at llvm.org
Thu May 30 10:57:57 PDT 2024
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/93243
>From 10615775033226a738ff2f2eec0ae1a77e8df1e6 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 14:27:33 -0700
Subject: [PATCH 1/2] [mlir][polynomial] verify from_tensor coeff type
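Use the ring's coefficient type to verify whether the input tensor's
element type fits in the polynomial ring. Downstream, we originally did
not specify the coefficient type and instead used the bit width implied
by the coefficient modulus. Now that the ring specifies the coefficient
type explicitly, the check is simpler: the input element bit width must
not exceed the coefficient type's bit width.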
---
mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 5 +----
mlir/test/Dialect/Polynomial/ops_errors.mlir | 2 +-
2 files changed, 2 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 76e837404c6b5..3719979177215 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -47,11 +47,8 @@ LogicalResult FromTensorOp::verify() {
return diag;
}
- APInt coefficientModulus = ring.getCoefficientModulus().getValue();
- unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
-
- if (inputBitWidth > cmodBitWidth) {
+ if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
InFlightDiagnostic diag = emitOpError()
<< "input tensor element type "
<< getInput().getType().getElementType()
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index f22b14897e98a..4937e17027afa 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt --split-input-file --verify-diagnostics %s
#my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+#ring = #polynomial.ring<coefficientType=i16>
!ty = !polynomial.polynomial<ring=#ring>
func.func @test_from_tensor_too_large_coeffs() {
>From 3cf04e4cc22debffa4e83e0e7427dd46c64ff54c Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 30 May 2024 10:57:44 -0700
Subject: [PATCH 2/2] make from_tensor/to_tensor verifiers resilient to a missing optional polynomialModulus attr
---
.../Dialect/Polynomial/IR/PolynomialOps.cpp | 67 ++++++++++---------
1 file changed, 37 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3719979177215..ea83db4fdde9b 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -34,17 +34,21 @@ void FromTensorOp::build(OpBuilder &builder, OperationState &result,
LogicalResult FromTensorOp::verify() {
ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
RingAttr ring = getOutput().getType().getRing();
- unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
- bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
- if (!compatible) {
- InFlightDiagnostic diag = emitOpError()
- << "input type " << getInput().getType()
- << " does not match output type "
- << getOutput().getType();
- diag.attachNote() << "the input type must be a tensor of shape [d] where d "
- "is at most the degree of the polynomialModulus of "
- "the output type's ring attribute";
- return diag;
+ IntPolynomialAttr polyMod = ring.getPolynomialModulus();
+ if (polyMod) {
+ unsigned polyDegree = polyMod.getPolynomial().getDegree();
+ bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
+ if (!compatible) {
+ InFlightDiagnostic diag = emitOpError()
+ << "input type " << getInput().getType()
+ << " does not match output type "
+ << getOutput().getType();
+ diag.attachNote()
+ << "the input type must be a tensor of shape [d] where d "
+ "is at most the degree of the polynomialModulus of "
+ "the output type's ring attribute";
+ return diag;
+ }
}
unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
@@ -64,24 +68,27 @@ LogicalResult FromTensorOp::verify() {
LogicalResult ToTensorOp::verify() {
ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
- unsigned polyDegree = getInput()
- .getType()
- .getRing()
- .getPolynomialModulus()
- .getPolynomial()
- .getDegree();
- bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+ IntPolynomialAttr polyMod =
+ getInput().getType().getRing().getPolynomialModulus();
+ if (polyMod) {
+ unsigned polyDegree = polyMod.getPolynomial().getDegree();
+ bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
- if (compatible)
- return success();
+ if (compatible)
+ return success();
- InFlightDiagnostic diag =
- emitOpError() << "input type " << getInput().getType()
- << " does not match output type " << getOutput().getType();
- diag.attachNote() << "the output type must be a tensor of shape [d] where d "
- "is at most the degree of the polynomialModulus of "
- "the input type's ring attribute";
- return diag;
+ InFlightDiagnostic diag = emitOpError()
+ << "input type " << getInput().getType()
+ << " does not match output type "
+ << getOutput().getType();
+ diag.attachNote()
+ << "the output type must be a tensor of shape [d] where d "
+ "is at most the degree of the polynomialModulus of "
+ "the input type's ring attribute";
+ return diag;
+ }
+
+ return success();
}
LogicalResult MulScalarOp::verify() {
@@ -163,9 +170,9 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
return op->emitOpError()
<< "provided root " << rootValue.getZExtValue()
- << " is not a primitive root "
- << "of unity mod " << cmod.getZExtValue()
- << ", with the specified degree " << rootDegree.getZExtValue();
+ << " is not a primitive root " << "of unity mod "
+ << cmod.getZExtValue() << ", with the specified degree "
+ << rootDegree.getZExtValue();
}
}
More information about the Mlir-commits
mailing list