[Mlir-commits] [mlir] [mlir][polynomial] verify from_tensor coeff type (PR #93243)

Jeremy Kun llvmlistbot at llvm.org
Thu May 30 11:02:13 PDT 2024


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/93243

>From 10615775033226a738ff2f2eec0ae1a77e8df1e6 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Thu, 23 May 2024 14:27:33 -0700
Subject: [PATCH 1/2] [mlir][polynomial] verify from_tensor coeff type

Use the ring's coefficient type to verify that a tensor's elements fit in
a polynomial ring. Downstream, we originally did not specify the
coefficient type and instead relied on the bit width implied by the
coefficient modulus. Now that the type is specified explicitly, the check
is simpler.
---
 mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | 5 +----
 mlir/test/Dialect/Polynomial/ops_errors.mlir     | 2 +-
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 76e837404c6b5..3719979177215 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -47,11 +47,8 @@ LogicalResult FromTensorOp::verify() {
     return diag;
   }
 
-  APInt coefficientModulus = ring.getCoefficientModulus().getValue();
-  unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
-
-  if (inputBitWidth > cmodBitWidth) {
+  if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
     InFlightDiagnostic diag = emitOpError()
                               << "input tensor element type "
                               << getInput().getType().getElementType()
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index f22b14897e98a..4937e17027afa 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s
 
 #my_poly = #polynomial.int_polynomial<1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i32, polynomialModulus=#my_poly>
+#ring = #polynomial.ring<coefficientType=i16>
 !ty = !polynomial.polynomial<ring=#ring>
 
 func.func @test_from_tensor_too_large_coeffs() {

>From 24d2b298032fa4cb0e8e846e717fe740884091bc Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Thu, 30 May 2024 10:57:44 -0700
Subject: [PATCH 2/2] make verifiers resilient to missing optional attrs

---
 .../Dialect/Polynomial/IR/PolynomialOps.cpp   | 61 +++++++++++--------
 1 file changed, 34 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 3719979177215..3117721a94152 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -34,17 +34,21 @@ void FromTensorOp::build(OpBuilder &builder, OperationState &result,
 LogicalResult FromTensorOp::verify() {
   ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
   RingAttr ring = getOutput().getType().getRing();
-  unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
-  bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
-  if (!compatible) {
-    InFlightDiagnostic diag = emitOpError()
-                              << "input type " << getInput().getType()
-                              << " does not match output type "
-                              << getOutput().getType();
-    diag.attachNote() << "the input type must be a tensor of shape [d] where d "
-                         "is at most the degree of the polynomialModulus of "
-                         "the output type's ring attribute";
-    return diag;
+  IntPolynomialAttr polyMod = ring.getPolynomialModulus();
+  if (polyMod) {
+    unsigned polyDegree = polyMod.getPolynomial().getDegree();
+    bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
+    if (!compatible) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "input type " << getInput().getType()
+                                << " does not match output type "
+                                << getOutput().getType();
+      diag.attachNote()
+          << "the input type must be a tensor of shape [d] where d "
+             "is at most the degree of the polynomialModulus of "
+             "the output type's ring attribute";
+      return diag;
+    }
   }
 
   unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
@@ -64,24 +68,27 @@ LogicalResult FromTensorOp::verify() {
 
 LogicalResult ToTensorOp::verify() {
   ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
-  unsigned polyDegree = getInput()
-                            .getType()
-                            .getRing()
-                            .getPolynomialModulus()
-                            .getPolynomial()
-                            .getDegree();
-  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+  IntPolynomialAttr polyMod =
+      getInput().getType().getRing().getPolynomialModulus();
+
+  // The output must always be a 1-D tensor. The polynomialModulus is optional
+  // in the ring; when present, the tensor length must also equal its degree.
+  bool compatible = tensorShape.size() == 1;
+  if (compatible && polyMod) {
+    unsigned polyDegree = polyMod.getPolynomial().getDegree();
+    compatible = tensorShape[0] == polyDegree;
+  }
 
   if (compatible)
     return success();
 
   InFlightDiagnostic diag =
       emitOpError() << "input type " << getInput().getType()
                     << " does not match output type " << getOutput().getType();
   diag.attachNote() << "the output type must be a tensor of shape [d] where d "
                        "is at most the degree of the polynomialModulus of "
                        "the input type's ring attribute";
   return diag;
 }
 
 LogicalResult MulScalarOp::verify() {



More information about the Mlir-commits mailing list