[Mlir-commits] [mlir] [mlir][tosa] Use `LogicalResult` in validation functions (PR #160052)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon Sep 22 01:34:04 PDT 2025
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
This commit changes the return type of validation functions that previously returned `bool` to indicate success or failure so that they return `LogicalResult` instead.
Note: this PR also contains the contents of https://github.com/llvm/llvm-project/pull/159754, so it shouldn't be merged before https://github.com/llvm/llvm-project/pull/159754 has landed.
---
Patch is 51.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160052.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+335-388) 
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (-33) 
- (added) mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir (+34) 
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 790bbf77877bc..6ea4e7736f78c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -205,148 +205,142 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     constCheckers.emplace_back(checkConstantOperandNegate);
   }
 
-  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_KERNEL) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckKernel(Operation *op, int32_t v,
+                                 const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_KERNEL)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_STRIDE) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckStride(Operation *op, int32_t v,
+                                 const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_STRIDE)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_SCALE) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckScale(Operation *op, int32_t v,
+                                const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_SCALE)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
-      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
-                        << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckListSize(Operation *op, int32_t v,
+                                   const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+      return op->emitOpError()
+             << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
+    return success();
   }
 
   // Perform the Level Rank check on the tensor type.
-  bool levelCheckRank(Operation *op, const Type typeToCheck,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
+                               const StringRef operandOrResult,
+                               int32_t highest_rank) {
     if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
-      if (!type.hasRank()) {
-        op->emitOpError() << "failed level check: unranked tensor";
-        return false;
-      }
-      if (type.getRank() > highest_rank) {
-        op->emitOpError() << "failed level check: " << operandOrResult
-                          << " rank(shape) <= MAX_RANK";
-        return false;
-      }
+      if (!type.hasRank())
+        return op->emitOpError() << "failed level check: unranked tensor";
+      if (type.getRank() > highest_rank)
+        return op->emitOpError() << "failed level check: " << operandOrResult
+                                 << " rank(shape) <= MAX_RANK";
     }
-    return true;
+    return success();
   }
 
   // Perform the Level Rank check on the tensor value.
-  bool levelCheckRank(Operation *op, const Value &v,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  LogicalResult levelCheckRank(Operation *op, const Value &v,
+                               const StringRef operandOrResult,
+                               int32_t highest_rank) {
     return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
   }
 
   // Perform the Level tensor size check on the tensor type.
-  bool levelCheckSize(Operation *op, const Type &typeToCheck,
-                      const StringRef operandOrResult);
+  LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
+                               const StringRef operandOrResult);
 
   // Perform the Level tensor size check on the tensor value.
-  bool levelCheckSize(Operation *op, const Value &v,
-                      const StringRef operandOrResult) {
+  LogicalResult levelCheckSize(Operation *op, const Value &v,
+                               const StringRef operandOrResult) {
     return levelCheckSize(op, v.getType(), operandOrResult);
   }
 
   // Level check sizes of all operands and results of the operation.
   template <typename T>
-  bool levelCheckSizes(T tosaOp) {
+  LogicalResult levelCheckSizes(T tosaOp) {
     auto op = tosaOp.getOperation();
     for (auto v : op->getOperands()) {
-      if (!levelCheckSize(op, v, "operand"))
-        return false;
+      if (failed(levelCheckSize(op, v, "operand")))
+        return failure();
     }
 
     for (auto v : op->getResults()) {
-      if (!levelCheckSize(op, v, "result"))
-        return false;
+      if (failed(levelCheckSize(op, v, "result")))
+        return failure();
     }
-    return true;
+    return success();
   }
 
   // Level check ranks of all operands, attribute and results of the operation.
   template <typename T>
-  bool levelCheckRanks(T tosaOp) {
+  LogicalResult levelCheckRanks(T tosaOp) {
     auto op = tosaOp.getOperation();
     for (auto v : op->getOperands()) {
-      if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
-        return false;
+      if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
+        return failure();
     }
 
     for (auto v : op->getResults()) {
-      if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
-        return false;
+      if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
+        return failure();
     }
-    return true;
+    return success();
   }
 
   // Level check ranks and sizes.
-  bool levelCheckRanksAndSizes(Operation *op);
+  LogicalResult levelCheckRanksAndSizes(Operation *op);
 
   // Pool Op: level check kernel/stride/pad values
   template <typename T>
-  bool levelCheckPool(Operation *op) {
+  LogicalResult levelCheckPool(Operation *op) {
     if (auto poolOp = dyn_cast<T>(op)) {
       for (auto k : poolOp.getKernel()) {
-        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : poolOp.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
       for (auto p : poolOp.getPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // Conv Op: level check dilation/stride/pad values
   template <typename T>
-  bool levelCheckConv(Operation *op) {
+  LogicalResult levelCheckConv(Operation *op) {
     if (auto convOp = dyn_cast<T>(op)) {
 
       for (auto k : convOp.getDilation()) {
-        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto p : convOp.getPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : convOp.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
       auto dilation = convOp.getDilation();
@@ -356,100 +350,100 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         if (isa<tosa::Conv2DOp>(op)) {
           assert(shape.size() == 4);
           assert(dilation.size() == 2);
-          if (!levelCheckKernel(op, dilation[0] * shape[1],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[2],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[2],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         } else if (isa<tosa::Conv3DOp>(op)) {
           assert(shape.size() == 5);
           assert(dilation.size() == 3);
-          if (!levelCheckKernel(op, dilation[0] * shape[1],
-                                "dilation_d * KD <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[2],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[2] * shape[3],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+                                      "dilation_d * KD <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[2],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[2] * shape[3],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
           assert(shape.size() == 4);
           assert(dilation.size() == 2);
-          if (!levelCheckKernel(op, dilation[0] * shape[0],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[1],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[0],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[1],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // FFT op: level check H, W in input shape [N,H,W]
   template <typename T>
-  bool levelCheckFFT(Operation *op) {
+  LogicalResult levelCheckFFT(Operation *op) {
     if (isa<T>(op)) {
       for (auto v : op->getOperands()) {
         if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
           auto shape = type.getShape();
           assert(shape.size() == 3);
-          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
-              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
-            return false;
+          if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
+              failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
+            return failure();
           }
         }
       }
     }
-    return true;
+    return success();
   }
 
   // TransposeConv2d op: level check kH/kW, outpad, and stride
-  bool levelCheckTransposeConv2d(Operation *op) {
+  LogicalResult levelCheckTransposeConv2d(Operation *op) {
     if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
       if (ShapedType filterType =
               dyn_cast<ShapedType>(transpose.getWeight().getType())) {
         auto shape = filterType.getShape();
         assert(shape.size() == 4);
         // level check kernel sizes for kH and KW
-        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
-            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
+            failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto p : transpose.getOutPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : transpose.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // Resize op: level check max scales
-  bool levelCheckResize(Operation *op) {
+  LogicalResult levelCheckResize(Operation *op) {
     if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
       SmallVector<int64_t> scale;
       if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
                                      scale)) {
-        return false;
+        return failure();
       }
       const int64_t scaleYN = scale[0];
       const int64_t scaleYD = scale[1];
       const int64_t scaleXN = scale[2];
       const int64_t scaleXD = scale[3];
-      if (!levelCheckScale(op, scaleYN / scaleYD,
-                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
-          !levelCheckScale(op, scaleXN / scaleXD,
-                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
-        return false;
+      if (failed(levelCheckScale(op, scaleYN / scaleYD,
+                                 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
+          failed(levelCheckScale(op, scaleXN / scaleXD,
+                                 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
   // Recursively perform a bottom-up search to determine the maximum nesting
@@ -468,62 +462,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     getMaxNestedDepth(op, depth);
   }
 
-  bool levelCheckMaxNesting(Operation *op) {
+  LogicalResult levelCheckMaxNesting(Operation *op) {
     int32_t maxNestedDepth = 0;
     getMaxNestedDepth(op, maxNestedDepth);
 
     if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
       op->emitOpError() << "failed level check: " << maxNestedDepth
                         << " >= MAX_NESTING";
-      return false;
+      return failure();
     }
-    return true;
+    return success();
   }
 
-  bool levelCheckListSize(Operation *op) {
+  LogicalResult levelCheckListSize(Operation *op) {
     if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
       return levelCheckListSize(op, concat.getInput1().size(), "input1");
     }
     if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
-      if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
-          !levelCheckListSize(op, custom.getOutputList().size(),
-                              "output_list")) {
-        return false;
+      if (failed(levelCheckListSize(op, custom.getInputList().size(),
+                                    "input_list")) ||
+          failed(levelCheckListSize(op, custom.getOutputList().size(),
+                                    "output_list"))) {
+        return failure();
       }
     }
     if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
-      if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
-          !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
-        return false;
+      if (failed(
+              levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
+          failed(levelCheckListSize(op, condIf.getOutputList().size(),
+                                    "outputs"))) {
+        return failure();
       }
     }
     if (auto w = dyn_cast<tosa::WhileOp>(op)) {
-      if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
-          !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
-        return false;
+      if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
+          failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
-  bool attributeCheckRescale(Operation *op) {
+  LogicalResult attributeCheckRescale(Operation *op) {
     if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
       if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
           !targetEnv.allows(Extension::doubleround)) {
         op->emitOpError()
             << "failed attribute check: rounding_mode = DOUBLE_ROUND "
             << "requires extension [doubleround]";
-        return false;
+        return failure();
       }
       if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
           !targetEnv.allows(Extension::inexactround)) {
         op->emitOpError()
             << "failed attribute check: rounding_mode = INEXACT_ROUND "
             << "requires extension [inexactround]";
-        return false;
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
   // configure profile and level values from pass options profileName and
@@ -563,8 +560,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     }
   }
 
-  bool CheckVariable(Operation *op);
-  bool CheckVariableReadOrWrite(Operation *op);
+  LogicalResult CheckVariable(Operation *op);
+  LogicalResult CheckVariableReadOrWrite(Operation *op);
   bool isValidElementType(Type type, const bool allowUnsigned = false);
 
   SmallVector<
@@ -577,62 +574,66 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 };
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
   auto op = tosaOp.getOperation();
-  if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(
+          levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+    return failure();
 
   // rank(output) = rank(input) - 1
-  if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
+                            tosaLevel.MAX_RANK - 1)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
   auto op = tosaOp.getOperation();
 
   // Only the condition input has rank limitation.
-  if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
   auto op = tosaOp.getOperation();
   auto variableType = getVariableType(tosaOp);
-  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, variableType, "variable type",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
   auto op = tosaOp.getOperation();
   auto variableType = getVariableType(tosaOp);
-  if (!levelCheckSize(op, variableType, "variable type"))
-    return false;
+  if (failed(levelCheckSize(op, variableType, "variable type")))
+    ret...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/160052
    
    
More information about the Mlir-commits
mailing list