[Mlir-commits] [mlir] [mlir][tosa] Use `LogicalResult` in validation functions (PR #160052)

Luke Hutton llvmlistbot at llvm.org
Wed Sep 24 01:15:37 PDT 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/160052

From 798879fc9ad3c5e662b367a50ed90481560844de Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Sat, 20 Sep 2025 10:51:47 +0000
Subject: [PATCH] [mlir][tosa] Use `LogicalResult` in validation functions

This commit replaces functions that previously returned `bool`
to indicate validation success or failure with `LogicalResult`.

Change-Id: Iec3b54e3cc5462e981e1e9eb8639608c62a128ed
---
 .../Tosa/Transforms/TosaValidation.cpp        | 718 ++++++++----------
 1 file changed, 332 insertions(+), 386 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index e9fdcbdc15837..b82cd59715614 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -205,148 +205,142 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     constCheckers.emplace_back(checkConstantOperandNegate);
   }
 
-  bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_KERNEL) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckKernel(Operation *op, int32_t v,
+                                 const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_KERNEL)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_STRIDE) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckStride(Operation *op, int32_t v,
+                                 const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_STRIDE)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_SCALE) {
-      op->emitOpError() << "failed level check: " << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckScale(Operation *op, int32_t v,
+                                const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_SCALE)
+      return op->emitOpError() << "failed level check: " << checkDesc;
+    return success();
   }
 
-  bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
-    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
-      op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
-                        << checkDesc;
-      return false;
-    }
-    return true;
+  LogicalResult levelCheckListSize(Operation *op, int32_t v,
+                                   const StringRef checkDesc) {
+    if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+      return op->emitOpError()
+             << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
+    return success();
   }
 
   // Perform the Level Rank check on the tensor type.
-  bool levelCheckRank(Operation *op, const Type typeToCheck,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  LogicalResult levelCheckRank(Operation *op, const Type typeToCheck,
+                               const StringRef operandOrResult,
+                               int32_t highest_rank) {
     if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
-      if (!type.hasRank()) {
-        op->emitOpError() << "failed level check: unranked tensor";
-        return false;
-      }
-      if (type.getRank() > highest_rank) {
-        op->emitOpError() << "failed level check: " << operandOrResult
-                          << " rank(shape) <= MAX_RANK";
-        return false;
-      }
+      if (!type.hasRank())
+        return op->emitOpError() << "failed level check: unranked tensor";
+      if (type.getRank() > highest_rank)
+        return op->emitOpError() << "failed level check: " << operandOrResult
+                                 << " rank(shape) <= MAX_RANK";
     }
-    return true;
+    return success();
   }
 
   // Perform the Level Rank check on the tensor value.
-  bool levelCheckRank(Operation *op, const Value &v,
-                      const StringRef operandOrResult, int32_t highest_rank) {
+  LogicalResult levelCheckRank(Operation *op, const Value &v,
+                               const StringRef operandOrResult,
+                               int32_t highest_rank) {
     return levelCheckRank(op, v.getType(), operandOrResult, highest_rank);
   }
 
   // Perform the Level tensor size check on the tensor type.
-  bool levelCheckSize(Operation *op, const Type &typeToCheck,
-                      const StringRef operandOrResult);
+  LogicalResult levelCheckSize(Operation *op, const Type &typeToCheck,
+                               const StringRef operandOrResult);
 
   // Perform the Level tensor size check on the tensor value.
-  bool levelCheckSize(Operation *op, const Value &v,
-                      const StringRef operandOrResult) {
+  LogicalResult levelCheckSize(Operation *op, const Value &v,
+                               const StringRef operandOrResult) {
     return levelCheckSize(op, v.getType(), operandOrResult);
   }
 
   // Level check sizes of all operands and results of the operation.
   template <typename T>
-  bool levelCheckSizes(T tosaOp) {
+  LogicalResult levelCheckSizes(T tosaOp) {
     auto op = tosaOp.getOperation();
     for (auto v : op->getOperands()) {
-      if (!levelCheckSize(op, v, "operand"))
-        return false;
+      if (failed(levelCheckSize(op, v, "operand")))
+        return failure();
     }
 
     for (auto v : op->getResults()) {
-      if (!levelCheckSize(op, v, "result"))
-        return false;
+      if (failed(levelCheckSize(op, v, "result")))
+        return failure();
     }
-    return true;
+    return success();
   }
 
   // Level check ranks of all operands, attribute and results of the operation.
   template <typename T>
-  bool levelCheckRanks(T tosaOp) {
+  LogicalResult levelCheckRanks(T tosaOp) {
     auto op = tosaOp.getOperation();
     for (auto v : op->getOperands()) {
-      if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
-        return false;
+      if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
+        return failure();
     }
 
     for (auto v : op->getResults()) {
-      if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
-        return false;
+      if (failed(levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)))
+        return failure();
     }
-    return true;
+    return success();
   }
 
   // Level check ranks and sizes.
-  bool levelCheckRanksAndSizes(Operation *op);
+  LogicalResult levelCheckRanksAndSizes(Operation *op);
 
   // Pool Op: level check kernel/stride/pad values
   template <typename T>
-  bool levelCheckPool(Operation *op) {
+  LogicalResult levelCheckPool(Operation *op) {
     if (auto poolOp = dyn_cast<T>(op)) {
       for (auto k : poolOp.getKernel()) {
-        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : poolOp.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
       for (auto p : poolOp.getPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // Conv Op: level check dilation/stride/pad values
   template <typename T>
-  bool levelCheckConv(Operation *op) {
+  LogicalResult levelCheckConv(Operation *op) {
     if (auto convOp = dyn_cast<T>(op)) {
 
       for (auto k : convOp.getDilation()) {
-        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto p : convOp.getPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : convOp.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
       auto dilation = convOp.getDilation();
@@ -356,100 +350,100 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         if (isa<tosa::Conv2DOp>(op)) {
           assert(shape.size() == 4);
           assert(dilation.size() == 2);
-          if (!levelCheckKernel(op, dilation[0] * shape[1],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[2],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[2],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         } else if (isa<tosa::Conv3DOp>(op)) {
           assert(shape.size() == 5);
           assert(dilation.size() == 3);
-          if (!levelCheckKernel(op, dilation[0] * shape[1],
-                                "dilation_d * KD <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[2],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[2] * shape[3],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[1],
+                                      "dilation_d * KD <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[2],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[2] * shape[3],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
           assert(shape.size() == 4);
           assert(dilation.size() == 2);
-          if (!levelCheckKernel(op, dilation[0] * shape[0],
-                                "dilation_y * KH <= MAX_KERNEL)") ||
-              !levelCheckKernel(op, dilation[1] * shape[1],
-                                "dilation_x * KW <= MAX_KERNEL)"))
-            return false;
+          if (failed(levelCheckKernel(op, dilation[0] * shape[0],
+                                      "dilation_y * KH <= MAX_KERNEL)")) ||
+              failed(levelCheckKernel(op, dilation[1] * shape[1],
+                                      "dilation_x * KW <= MAX_KERNEL)")))
+            return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // FFT op: level check H, W in input shape [N,H,W]
   template <typename T>
-  bool levelCheckFFT(Operation *op) {
+  LogicalResult levelCheckFFT(Operation *op) {
     if (isa<T>(op)) {
       for (auto v : op->getOperands()) {
         if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
           auto shape = type.getShape();
           assert(shape.size() == 3);
-          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
-              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
-            return false;
+          if (failed(levelCheckKernel(op, shape[1], "H <= MAX_KERNEL")) ||
+              failed(levelCheckKernel(op, shape[2], "W <= MAX_KERNEL"))) {
+            return failure();
           }
         }
       }
     }
-    return true;
+    return success();
   }
 
   // TransposeConv2d op: level check kH/kW, outpad, and stride
-  bool levelCheckTransposeConv2d(Operation *op) {
+  LogicalResult levelCheckTransposeConv2d(Operation *op) {
     if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
       if (ShapedType filterType =
               dyn_cast<ShapedType>(transpose.getWeight().getType())) {
         auto shape = filterType.getShape();
         assert(shape.size() == 4);
         // level check kernel sizes for kH and KW
-        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
-            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL")) ||
+            failed(levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto p : transpose.getOutPad()) {
-        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
-          return false;
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL"))) {
+          return failure();
         }
       }
       for (auto s : transpose.getStride()) {
-        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
-          return false;
+        if (failed(levelCheckStride(op, s, "stride <= MAX_STRIDE"))) {
+          return failure();
         }
       }
     }
-    return true;
+    return success();
   }
 
   // Resize op: level check max scales
-  bool levelCheckResize(Operation *op) {
+  LogicalResult levelCheckResize(Operation *op) {
     if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
       SmallVector<int64_t> scale;
       if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
                                      scale)) {
-        return false;
+        return failure();
       }
       const int64_t scaleYN = scale[0];
       const int64_t scaleYD = scale[1];
       const int64_t scaleXN = scale[2];
       const int64_t scaleXD = scale[3];
-      if (!levelCheckScale(op, scaleYN / scaleYD,
-                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
-          !levelCheckScale(op, scaleXN / scaleXD,
-                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
-        return false;
+      if (failed(levelCheckScale(op, scaleYN / scaleYD,
+                                 "scale_y_n/scale_y_d <= MAX_SCALE")) ||
+          failed(levelCheckScale(op, scaleXN / scaleXD,
+                                 "scale_x_n/scale_x_d <= MAX_SCALE"))) {
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
   // Recursively perform a bottom-up search to determine the maximum nesting
@@ -468,62 +462,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     getMaxNestedDepth(op, depth);
   }
 
-  bool levelCheckMaxNesting(Operation *op) {
+  LogicalResult levelCheckMaxNesting(Operation *op) {
     int32_t maxNestedDepth = 0;
     getMaxNestedDepth(op, maxNestedDepth);
 
     if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
       op->emitOpError() << "failed level check: " << maxNestedDepth
                         << " >= MAX_NESTING";
-      return false;
+      return failure();
     }
-    return true;
+    return success();
   }
 
-  bool levelCheckListSize(Operation *op) {
+  LogicalResult levelCheckListSize(Operation *op) {
     if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
       return levelCheckListSize(op, concat.getInput1().size(), "input1");
     }
     if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
-      if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
-          !levelCheckListSize(op, custom.getOutputList().size(),
-                              "output_list")) {
-        return false;
+      if (failed(levelCheckListSize(op, custom.getInputList().size(),
+                                    "input_list")) ||
+          failed(levelCheckListSize(op, custom.getOutputList().size(),
+                                    "output_list"))) {
+        return failure();
       }
     }
     if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
-      if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
-          !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
-        return false;
+      if (failed(
+              levelCheckListSize(op, condIf.getInputList().size(), "inputs")) ||
+          failed(levelCheckListSize(op, condIf.getOutputList().size(),
+                                    "outputs"))) {
+        return failure();
       }
     }
     if (auto w = dyn_cast<tosa::WhileOp>(op)) {
-      if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
-          !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
-        return false;
+      if (failed(levelCheckListSize(op, w.getInputList().size(), "inputs")) ||
+          failed(levelCheckListSize(op, w.getOutputList().size(), "outputs"))) {
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
-  bool attributeCheckRescale(Operation *op) {
+  LogicalResult attributeCheckRescale(Operation *op) {
     if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
       if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
           !targetEnv.allows(Extension::doubleround)) {
         op->emitOpError()
             << "failed attribute check: rounding_mode = DOUBLE_ROUND "
             << "requires extension [doubleround]";
-        return false;
+        return failure();
       }
       if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND &&
           !targetEnv.allows(Extension::inexactround)) {
         op->emitOpError()
             << "failed attribute check: rounding_mode = INEXACT_ROUND "
             << "requires extension [inexactround]";
-        return false;
+        return failure();
       }
     }
-    return true;
+    return success();
   }
 
   // configure profile and level values from pass options profileName and
@@ -563,8 +560,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     }
   }
 
-  bool CheckVariable(Operation *op);
-  bool CheckVariableReadOrWrite(Operation *op);
+  LogicalResult CheckVariable(Operation *op);
+  LogicalResult CheckVariableReadOrWrite(Operation *op);
   bool isValidElementType(Type type, const bool allowUnsigned = false);
 
   SmallVector<
@@ -577,62 +574,65 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 };
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
   auto *op = tosaOp.getOperation();
-  if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
   // rank(output) = rank(input) - 1
-  if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
+                            tosaLevel.MAX_RANK - 1)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
   auto *op = tosaOp.getOperation();
 
   // Only the condition input has rank limitation.
-  if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
   auto *op = tosaOp.getOperation();
   auto variableType = getVariableType(tosaOp);
-  if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK))
-    return false;
+  if (failed(levelCheckRank(op, variableType, "variable type",
+                            tosaLevel.MAX_RANK)))
+    return failure();
 
-  return true;
+  return success();
 }
 
 template <>
-bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
+LogicalResult TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) {
   auto *op = tosaOp.getOperation();
   auto variableType = getVariableType(tosaOp);
-  if (!levelCheckSize(op, variableType, "variable type"))
-    return false;
+  if (failed(levelCheckSize(op, variableType, "variable type")))
+    return failure();
 
-  return true;
+  return success();
 }
 
-bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
+LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
 #define CHECK_RANKS_AND_SIZES(tosaOp)                                          \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op)))                          \
-      return false;                                                            \
-    if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
-      return false;                                                            \
+    if (failed(levelCheckRanks(cast<tosa::tosaOp##Op>(op))))                   \
+      return failure();                                                        \
+    if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op))))                   \
+      return failure();                                                        \
   }
 
 #define CHECK_SIZES(tosaOp)                                                    \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op)))                          \
-      return false;                                                            \
+    if (failed(levelCheckSizes(cast<tosa::tosaOp##Op>(op))))                   \
+      return failure();                                                        \
   }
 
   // Tensor Operators
@@ -735,24 +735,21 @@ bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
 
 #undef CHECK_RANKS_AND_SIZES
 #undef CHECK_SIZES
-  return true;
+  return success();
 }
 
 // Perform the Level tensor size check on the tensor type.
-bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
-                                    const StringRef operandOrResult) {
+LogicalResult TosaValidation::levelCheckSize(Operation *op,
+                                             const Type &typeToCheck,
+                                             const StringRef operandOrResult) {
   if (ShapedType type = dyn_cast<ShapedType>(typeToCheck)) {
-    if (!type.hasRank()) {
-      op->emitOpError() << "failed level check: unranked tensor";
-      return false;
-    }
+    if (!type.hasRank())
+      return op->emitOpError() << "failed level check: unranked tensor";
     auto shape = type.getShape();
     for (auto dim : shape) {
-      if (mlir::ShapedType::isDynamic(dim)) {
-        op->emitOpError() << "failed level check: " << operandOrResult
-                          << " shape dimension cannot be dynamic";
-        return false;
-      }
+      if (mlir::ShapedType::isDynamic(dim))
+        return op->emitOpError() << "failed level check: " << operandOrResult
+                                 << " shape dimension cannot be dynamic";
     }
 
     int64_t element_bits = type.getElementTypeBitWidth();
@@ -765,14 +762,12 @@ bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck,
     // For each tensor, the number of tensor elements multiplied by the
     // element size in bytes must be representable as a tensor_size_t.
     const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
-    if (size > max_size) {
-      op->emitOpError()
-          << "failed level check: " << operandOrResult
-          << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
-      return false;
-    }
+    if (size > max_size)
+      return op->emitOpError()
+             << "failed level check: " << operandOrResult
+             << " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)";
   }
-  return true;
+  return success();
 }
 
 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -782,28 +777,28 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
   }
 
   // check rank and sizes early so later checks can assume shaped operands
-  if (!levelCheckRanksAndSizes(op))
+  if (failed(levelCheckRanksAndSizes(op)))
     return failure();
 
   // additional level checks from spec 0.70
-  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
-      !levelCheckConv<tosa::Conv2DOp>(op) ||
-      !levelCheckConv<tosa::Conv3DOp>(op) ||
-      !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
-      !levelCheckFFT<tosa::FFT2dOp>(op) ||
-      !levelCheckPool<tosa::MaxPool2dOp>(op) ||
-      !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
-      !levelCheckResize(op)) {
+  if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
+      failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
+      failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
+      failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
+      failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
+      failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
+      failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
+      failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
     return failure();
   }
 
   // level check MAX_TENSOR_LIST_SIZE
-  if (!levelCheckListSize(op)) {
+  if (failed(levelCheckListSize(op))) {
     return failure();
   }
 
   if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
-    if (!levelCheckMaxNesting(op)) {
+    if (failed(levelCheckMaxNesting(op))) {
       return failure();
     }
   }
@@ -812,7 +807,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
 }
 
 LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
-  if (!attributeCheckRescale(op))
+  if (failed(attributeCheckRescale(op)))
     return failure();
   return success();
 }
@@ -823,14 +818,12 @@ inline bool CompatibleTypes(const mlir::Type &type,
   return type == declaredType;
 }
 
-bool TosaValidation::CheckVariable(Operation *op) {
+LogicalResult TosaValidation::CheckVariable(Operation *op) {
   if (auto variableOp = dyn_cast<mlir::tosa::VariableOp>(op)) {
     mlir::StringAttr nameAttr = variableOp.getNameAttr();
 
-    if (variablesMap.count(nameAttr)) {
-      op->emitOpError() << "name has already been declared";
-      return false;
-    }
+    if (variablesMap.count(nameAttr))
+      return op->emitOpError() << "name has already been declared";
 
     auto elementType = variableOp.getType();
     DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
@@ -841,51 +834,44 @@ bool TosaValidation::CheckVariable(Operation *op) {
     variablesMap[nameAttr] = variableType;
   }
 
-  return true;
+  return success();
 }
 
-bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
+LogicalResult TosaValidation::CheckVariableReadOrWrite(Operation *op) {
   if (isa<mlir::tosa::VariableReadOp>(op) ||
       isa<mlir::tosa::VariableWriteOp>(op)) {
     mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
-    if (!variablesMap.count(nameAttr)) {
-      op->emitOpError() << "name has not been declared";
-      return false;
-    }
+    if (!variablesMap.count(nameAttr))
+      return op->emitOpError() << "name has not been declared";
 
     auto varType = variablesMap[nameAttr];
 
     for (auto v : op->getOperands()) {
       auto type = v.getType();
-      if (!CompatibleTypes(type, varType)) {
-        op->emitOpError() << "operand type does not equal variable type";
-        return false;
-      }
+      if (!CompatibleTypes(type, varType))
+        return op->emitOpError() << "operand type does not equal variable type";
     }
 
     for (auto v : op->getResults()) {
       auto type = v.getType();
-      if (!CompatibleTypes(type, varType)) {
-        op->emitOpError() << "result type does not equal variable type";
-        return false;
-      }
+      if (!CompatibleTypes(type, varType))
+        return op->emitOpError() << "result type does not equal variable type";
     }
   }
 
-  return true;
+  return success();
 }
 
 LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
-  if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
+  if (failed(CheckVariable(op)) || failed(CheckVariableReadOrWrite(op)))
     return failure();
-  }
   return success();
 }
 
-bool checkErrorIfResize(Operation *op) {
+LogicalResult checkErrorIfResize(Operation *op) {
   auto resize = dyn_cast<tosa::ResizeOp>(op);
   if (!resize)
-    return true;
+    return success();
 
   const Value input = resize.getInput();
   const Value output = resize.getOutput();
@@ -894,10 +880,8 @@ bool checkErrorIfResize(Operation *op) {
   const RankedTensorType outputType =
       llvm::dyn_cast<RankedTensorType>(output.getType());
 
-  if (!inputType || !outputType) {
-    op->emitOpError("expect ranked input/output tensor");
-    return false;
-  }
+  if (!inputType || !outputType)
+    return op->emitOpError("expect ranked input/output tensor");
 
   // Ensure the image size is supported by GPU APIs and that for integer
   // implementations, position * stride does not overflow int32_t.
@@ -906,17 +890,15 @@ bool checkErrorIfResize(Operation *op) {
         outputType.getDimSize(1), outputType.getDimSize(2),
         inputType.getDimSize(1), inputType.getDimSize(2)};
     const int64_t *maxDim = llvm::max_element(sizes);
-    if (maxDim != sizes.end() && *maxDim >= 16384) {
-      op->emitOpError("expect input/output height/width dims to be < 16384, ")
-          << "got [OH, OW, IH, IW] = " << sizes;
-      return false;
-    }
+    if (maxDim != sizes.end() && *maxDim >= 16384)
+      return op->emitOpError(
+                 "expect input/output height/width dims to be < 16384, ")
+             << "got [OH, OW, IH, IW] = " << sizes;
   }
 
   SmallVector<int64_t> scale;
-  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
-    return false;
-  }
+  if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale))
+    return failure();
 
   const int64_t scaleYN = scale[0];
   const int64_t scaleYD = scale[1];
@@ -924,57 +906,45 @@ bool checkErrorIfResize(Operation *op) {
   const int64_t scaleXD = scale[3];
 
   // Ensure scale values don't overflow int32 accumulator
-  if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
-    op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
-                    "got scale_y_n=")
-        << scaleYN << ", scale_x_n=" << scaleXN;
-    return false;
-  }
+  if (scaleYN > (1 << 11) || scaleXN > (1 << 11))
+    return op->emitOpError(
+               "expect all scale numerator values to be <= (1 << 11), "
+               "got scale_y_n=")
+           << scaleYN << ", scale_x_n=" << scaleXN;
 
-  if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
-    op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
-        << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
-    return false;
-  }
+  if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN)
+    return op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
+           << scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
 
   SmallVector<int64_t> offset;
   SmallVector<int64_t> border;
   if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
-      !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
-    return false;
-  }
+      !tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border))
+    return failure();
 
   const int64_t offsetY = offset[0];
   const int64_t offsetX = offset[1];
   // Set a consistent lower limit of 1/16 downscale to simplify
   // implementations
-  if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
-    op->emitOpError(
-        "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
-        << offsetY << "/" << scaleYN;
-    return false;
-  }
-  if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
-    op->emitOpError(
-        "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
-        << offsetX << "/" << scaleXN;
-    return false;
-  }
+  if (offsetY < -scaleYN || offsetY >= 16 * scaleYN)
+    return op->emitOpError(
+               "expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
+           << offsetY << "/" << scaleYN;
+  if (offsetX < -scaleXN || offsetX >= 16 * scaleXN)
+    return op->emitOpError(
+               "expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
+           << offsetX << "/" << scaleXN;
 
   const int64_t borderY = border[0];
   const int64_t borderX = border[1];
-  if (borderY < -16 * scaleYN || borderY >= scaleYN) {
-    op->emitOpError(
-        "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
-        << borderY << "/" << scaleYN;
-    return false;
-  }
-  if (borderX < -16 * scaleXN || borderX >= scaleXN) {
-    op->emitOpError(
-        "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
-        << borderX << "/" << scaleXN;
-    return false;
-  }
+  if (borderY < -16 * scaleYN || borderY >= scaleYN)
+    return op->emitOpError(
+               "expect borderY / scaleYNumerator to be in range [-16, 1), got ")
+           << borderY << "/" << scaleYN;
+  if (borderX < -16 * scaleXN || borderX >= scaleXN)
+    return op->emitOpError(
+               "expect borderX / scaleXNumerator to be in range [-16, 1), got ")
+           << borderX << "/" << scaleXN;
 
   // The following section of code is mostly duplicated with ResizeOp::verify().
   //
@@ -1001,81 +971,72 @@ bool checkErrorIfResize(Operation *op) {
   if (ih != ShapedType::kDynamic) {
     const std::optional<int64_t> calculatedOutHeightMinusOne =
         idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
-    if (!calculatedOutHeightMinusOne.has_value()) {
-      op->emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
-                      "border_y ")
-          << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * "
-          << scaleYN << " - " << offsetY << " + " << borderY << ") / "
-          << scaleYD;
-      return false;
-    }
+    if (!calculatedOutHeightMinusOne.has_value())
+      return op->emitOpError(
+                 "expected (input_height - 1) * scale_y_n - offset_y + "
+                 "border_y ")
+             << "to be wholly divisible by scale_y_d, got ((" << ih
+             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
+             << ") / " << scaleYD;
     const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
-    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
-      op->emitOpError("calculated output height did not match expected: ")
-          << "calculated=" << calculatedOutHeight << ", expected=" << oh;
-      return false;
-    }
+    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
+      return op->emitOpError(
+                 "calculated output height did not match expected: ")
+             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
   }
 
   if (iw != ShapedType::kDynamic) {
     const std::optional<int64_t> calculatedOutWidthMinusOne =
         idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
-    if (!calculatedOutWidthMinusOne.has_value()) {
-      op->emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
-                      "border_x ")
-          << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * "
-          << scaleXN << " - " << offsetX << " + " << borderX << ") / "
-          << scaleXD;
-      return false;
-    }
+    if (!calculatedOutWidthMinusOne.has_value())
+      return op->emitOpError(
+                 "expected (input_width - 1) * scale_x_n - offset_x + "
+                 "border_x ")
+             << "to be wholly divisible by scale_x_d, got ((" << iw
+             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
+             << ") / " << scaleXD;
     const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
-    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
-      op->emitOpError("calculated output width did not match expected: ")
-          << "calculated=" << calculatedOutWidth << ", expected=" << ow;
-      return false;
-    }
+    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
+      return op->emitOpError("calculated output width did not match expected: ")
+             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
   }
 
-  return true;
+  return success();
 }
 
-bool checkErrorIfMul(Operation *op) {
+LogicalResult checkErrorIfMul(Operation *op) {
   auto mul = dyn_cast<tosa::MulOp>(op);
   if (!mul)
-    return true;
+    return success();
 
   // REQUIRE(0 <= shift && shift <= 63);
   // REQUIRE(is_same<in_t,int32_t>() || shift == 0);
   ElementsAttr shift_elem;
-  if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
-    return true;
-  }
+  if (!matchPattern(mul.getShift(), m_Constant(&shift_elem)))
+    return success();
   int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
   auto inputElemType = getElementTypeOrSelf(mul.getInput1());
   if (inputElemType.isInteger(32)) {
     // 0 <= shift <= 63 for int32_t type
-    if (shift < 0 || shift > 63) {
-      op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
-                        << shift;
-      return false;
-    }
+    if (shift < 0 || shift > 63)
+      return op->emitOpError()
+             << "requires 0 <= shift && shift <= 63, but got: " << shift;
   } else {
     // shift must be 0 for all other types
-    if (shift != 0) {
-      op->emitOpError() << "requires shift = 0 for all input data types that "
-                           "are not int32_t, but got: "
-                        << shift;
-      return false;
-    }
+    if (shift != 0)
+      return op->emitOpError()
+             << "requires shift = 0 for all input data types that "
+                "are not int32_t, but got: "
+             << shift;
   }
 
-  return true;
+  return success();
 }
 
-bool checkErrorIfTable(Operation *op) {
+LogicalResult checkErrorIfTable(Operation *op) {
   auto table = dyn_cast<tosa::TableOp>(op);
   if (!table)
-    return true;
+    return success();
 
   // REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
   const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
@@ -1084,26 +1045,24 @@ bool checkErrorIfTable(Operation *op) {
   const ShapeAdaptor tableShape(table.getTable().getType());
   if (tableShape.hasStaticShape()) {
     const auto numElements = tableShape.getNumElements();
-    if (numElements != tableSize) {
-      op->emitOpError() << "requires table size of " << tableSize << ", got "
-                        << numElements;
-      return false;
-    }
+    if (numElements != tableSize)
+      return op->emitOpError() << "requires table size of " << tableSize
+                               << ", got " << numElements;
   }
 
-  return true;
+  return success();
 }
 
-bool checkErrorIfRescale(Operation *op) {
+LogicalResult checkErrorIfRescale(Operation *op) {
   auto rescale = dyn_cast<tosa::RescaleOp>(op);
   if (!rescale)
-    return true;
+    return success();
 
   auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
   auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
   if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
       !outputType.getElementType().isInteger())
-    return true;
+    return success();
 
   auto inElemType = inputType.getElementType();
   auto outElemType = outputType.getElementType();
@@ -1117,81 +1076,65 @@ bool checkErrorIfRescale(Operation *op) {
   auto roundingMode = rescale.getRoundingMode();
 
   // ERROR_IF(scale32 && is_same<in_t,i48_t>())
-  if (scale32 && inWidth == 48) {
-    op->emitOpError() << "scale32 is not allowed with 48-bit input.";
-    return false;
-  }
+  if (scale32 && inWidth == 48)
+    return op->emitOpError() << "scale32 is not allowed with 48-bit input.";
 
   // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
-  if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) {
-    op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
-    return false;
-  }
+  if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND)
+    return op->emitOpError()
+           << "DOUBLE_ROUND is only allowed with scale32=true.";
 
   // ERROR_IF(input_unsigned && output_unsigned)
-  if (inputUnsigned && outputUnsigned) {
-    op->emitOpError() << "input and output cannot be both unsigned.";
-    return false;
-  }
+  if (inputUnsigned && outputUnsigned)
+    return op->emitOpError() << "input and output cannot be both unsigned.";
 
   // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
-  if (outWidth == 32 && inputUnsigned) {
-    op->emitOpError() << "i32 output type is not allowed with unsigned input.";
-    return false;
-  }
+  if (outWidth == 32 && inputUnsigned)
+    return op->emitOpError()
+           << "i32 output type is not allowed with unsigned input.";
 
   // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
-  if (inWidth == 32 && outputUnsigned) {
-    op->emitOpError() << "i32 input type is not allowed with unsigned output.";
-    return false;
-  }
+  if (inWidth == 32 && outputUnsigned)
+    return op->emitOpError()
+           << "i32 input type is not allowed with unsigned output.";
 
   // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
-  if (inWidth == 48 && outputUnsigned) {
-    op->emitOpError() << "i48 input type is not allowed with unsigned output.";
-    return false;
-  }
+  if (inWidth == 48 && outputUnsigned)
+    return op->emitOpError()
+           << "i48 input type is not allowed with unsigned output.";
 
   // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
-  if (inWidth == 48 && inputUnsigned) {
-    op->emitOpError() << "i48 input type cannot be unsigned.";
-    return false;
-  }
+  if (inWidth == 48 && inputUnsigned)
+    return op->emitOpError() << "i48 input type cannot be unsigned.";
 
   // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
-  if (inWidth == 32 && inputUnsigned) {
-    op->emitOpError() << "i32 input type cannot be unsigned.";
-    return false;
-  }
+  if (inWidth == 32 && inputUnsigned)
+    return op->emitOpError() << "i32 input type cannot be unsigned.";
 
   // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
-  if (outWidth == 32 && outputUnsigned) {
-    op->emitOpError() << "i32 output type cannot be unsigned.";
-    return false;
-  }
+  if (outWidth == 32 && outputUnsigned)
+    return op->emitOpError() << "i32 output type cannot be unsigned.";
 
-  return true;
+  return success();
 }
 
-bool checkErrorIfPad(Operation *op) {
+LogicalResult checkErrorIfPad(Operation *op) {
   auto pad = dyn_cast<tosa::PadOp>(op);
   if (!pad)
-    return true;
+    return success();
 
   DenseIntElementsAttr paddingAttr;
   if (!matchPattern(pad.getPadding(), m_Constant(&paddingAttr)))
     // Pad verifier will catch this
-    return true;
+    return success();
 
   for (const APInt &val : paddingAttr.getValues<APInt>()) {
-    if (val.getSExtValue() < 0) {
-      op->emitOpError() << "padding value must all be non-negative, got "
-                        << val.getSExtValue();
-      return false;
-    }
+    if (val.getSExtValue() < 0)
+      return op->emitOpError() << "padding value must all be non-negative, got "
+                               << val.getSExtValue();
   }
 
-  return true;
+  return success();
 }
 
 static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
@@ -1201,7 +1144,7 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
   });
 }
 
-static bool isRegionIsolatedFromAbove(Region &regionToCheck) {
+static LogicalResult isRegionIsolatedFromAbove(Region &regionToCheck) {
   bool noLiveInValue = true;
   regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
     if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
@@ -1210,23 +1153,22 @@ static bool isRegionIsolatedFromAbove(Region &regionToCheck) {
     }
     return WalkResult::advance();
   });
-  return noLiveInValue;
+  return noLiveInValue ? success() : failure();
 }
 
 LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
                                   StringRef regionName) {
-  if (isRegionIsolatedFromAbove(regionToCheck))
+  if (succeeded(isRegionIsolatedFromAbove(regionToCheck)))
     return success();
-  op->emitOpError()
-      << "is not conformant to the TOSA specification. It requires the '"
-      << regionName << "' region is isolated from above.\n";
-  return failure();
+  return op->emitOpError()
+         << "is not conformant to the TOSA specification. It requires the '"
+         << regionName << "' region is isolated from above.\n";
 }
 
-bool checkErrorIfCondIf(Operation *op) {
+LogicalResult checkErrorIfCondIf(Operation *op) {
   auto ifOp = dyn_cast<tosa::IfOp>(op);
   if (!ifOp)
-    return true;
+    return success();
 
   // Currently the dialect supports declaring cond_if operations that
   // have then/else regions that reference values from outside these
@@ -1257,49 +1199,53 @@ bool checkErrorIfCondIf(Operation *op) {
   //   tosa.yield %arg4
   // }
 
-  return succeeded(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) &&
-         succeeded(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
+  if (failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
+      failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")))
+    return failure();
+  return success();
 }
 
-bool checkErrorIfWhileLoop(Operation *op) {
+LogicalResult checkErrorIfWhileLoop(Operation *op) {
   auto whileOp = dyn_cast<tosa::WhileOp>(op);
   if (!whileOp)
-    return true;
+    return success();
 
-  return succeeded(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) &&
-         succeeded(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
+  if (failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
+      failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body")))
+    return failure();
+  return success();
 }
 
-bool checkErrorIfScatter(Operation *op) {
+LogicalResult checkErrorIfScatter(Operation *op) {
   auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
   if (!scatterOp)
-    return true;
+    return success();
 
   // for constant indices, check that there are no duplicate values
   DenseIntElementsAttr indicesAttr;
   if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
-    return true;
+    return success();
 
   auto const indicesType =
       dyn_cast<ShapedType>(scatterOp.getIndices().getType());
   if (!indicesType || !indicesType.hasRank()) {
     op->emitOpError("expect ranked indices tensor");
-    return false;
+    return failure();
   }
 
   if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
     op->emitOpError("indices values contain duplicates");
-    return false;
+    return failure();
   }
 
-  return true;
+  return success();
 }
 
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
-  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
-      !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
-      !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
-      !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
+  if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
+      failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
+      failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) ||
+      failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
     return failure();
   return success();
 }



More information about the Mlir-commits mailing list