[Mlir-commits] [mlir] 692ae54 - [mlir][polynomial] verify from_tensor coeff type (#93243)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 30 11:03:40 PDT 2024


Author: Jeremy Kun
Date: 2024-05-30T11:03:36-07:00
New Revision: 692ae5443b1778e138527ef55d799a4b535a36f9

URL: https://github.com/llvm/llvm-project/commit/692ae5443b1778e138527ef55d799a4b535a36f9
DIFF: https://github.com/llvm/llvm-project/commit/692ae5443b1778e138527ef55d799a4b535a36f9.diff

LOG: [mlir][polynomial] verify from_tensor coeff type  (#93243)

Rebased over https://github.com/llvm/llvm-project/pull/93227

---------

Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
    mlir/test/Dialect/Polynomial/ops_errors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 76e837404c6b5..3117721a94152 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -34,24 +34,25 @@ void FromTensorOp::build(OpBuilder &builder, OperationState &result,
 LogicalResult FromTensorOp::verify() {
   ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
   RingAttr ring = getOutput().getType().getRing();
-  unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
-  bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
-  if (!compatible) {
-    InFlightDiagnostic diag = emitOpError()
-                              << "input type " << getInput().getType()
-                              << " does not match output type "
-                              << getOutput().getType();
-    diag.attachNote() << "the input type must be a tensor of shape [d] where d "
-                         "is at most the degree of the polynomialModulus of "
-                         "the output type's ring attribute";
-    return diag;
+  IntPolynomialAttr polyMod = ring.getPolynomialModulus();
+  if (polyMod) {
+    unsigned polyDegree = polyMod.getPolynomial().getDegree();
+    bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
+    if (!compatible) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "input type " << getInput().getType()
+                                << " does not match output type "
+                                << getOutput().getType();
+      diag.attachNote()
+          << "the input type must be a tensor of shape [d] where d "
+             "is at most the degree of the polynomialModulus of "
+             "the output type's ring attribute";
+      return diag;
+    }
   }
 
-  APInt coefficientModulus = ring.getCoefficientModulus().getValue();
-  unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
-
-  if (inputBitWidth > cmodBitWidth) {
+  if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
     InFlightDiagnostic diag = emitOpError()
                               << "input tensor element type "
                               << getInput().getType().getElementType()
@@ -67,24 +68,27 @@ LogicalResult FromTensorOp::verify() {
 
 LogicalResult ToTensorOp::verify() {
   ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
-  unsigned polyDegree = getInput()
-                            .getType()
-                            .getRing()
-                            .getPolynomialModulus()
-                            .getPolynomial()
-                            .getDegree();
-  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+  IntPolynomialAttr polyMod =
+      getInput().getType().getRing().getPolynomialModulus();
+  if (polyMod) {
+    unsigned polyDegree = polyMod.getPolynomial().getDegree();
+    bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
 
-  if (compatible)
-    return success();
+    if (compatible)
+      return success();
+
+    InFlightDiagnostic diag = emitOpError()
+                              << "input type " << getInput().getType()
+                              << " does not match output type "
+                              << getOutput().getType();
+    diag.attachNote()
+        << "the output type must be a tensor of shape [d] where d "
+           "is at most the degree of the polynomialModulus of "
+           "the input type's ring attribute";
+    return diag;
+  }
 
-  InFlightDiagnostic diag =
-      emitOpError() << "input type " << getInput().getType()
-                    << " does not match output type " << getOutput().getType();
-  diag.attachNote() << "the output type must be a tensor of shape [d] where d "
-                       "is at most the degree of the polynomialModulus of "
-                       "the input type's ring attribute";
-  return diag;
+  return success();
 }
 
 LogicalResult MulScalarOp::verify() {

diff  --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index f22b14897e98a..4937e17027afa 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s
 
 #my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+#ring = #polynomial.ring<coefficientType=i16>
 !ty = !polynomial.polynomial<ring=#ring>
 
 func.func @test_from_tensor_too_large_coeffs() {


        


More information about the Mlir-commits mailing list