[Mlir-commits] [mlir] [mlir][Python] downcast ir.Value to BlockArgument or OpResult (PR #175264)

Maksim Levental llvmlistbot at llvm.org
Fri Jan 9 20:01:42 PST 2026


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175264

>From 7c2c7f0ac6e721e42445d45f3ec3fff1e1368aa3 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 15:02:44 -0800
Subject: [PATCH 1/2] [mlir][Python] downcast ir.Value to BlockArgument or
 OpResult

---
 mlir/include/mlir/Bindings/Python/IRCore.h    | 13 ++--
 mlir/include/mlir/Bindings/Python/Nanobind.h  |  1 +
 .../mlir/Bindings/Python/NanobindUtils.h      |  2 +-
 mlir/lib/Bindings/Python/IRAffine.cpp         |  2 +-
 mlir/lib/Bindings/Python/IRCore.cpp           | 61 +++++++++----------
 mlir/test/python/dialects/python_test.py      |  2 +-
 mlir/test/python/ir/value.py                  | 22 +++----
 7 files changed, 50 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 330318683c15e..bfc35ca5b9d50 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -36,20 +36,21 @@
 namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
+class DefaultingPyLocation;
+class DefaultingPyMlirContext;
 class PyBlock;
+class PyBlockArgument;
 class PyDiagnostic;
 class PyDiagnosticHandler;
 class PyInsertionPoint;
 class PyLocation;
-class DefaultingPyLocation;
 class PyMlirContext;
-class DefaultingPyMlirContext;
 class PyModule;
+class PyOpResult;
 class PyOperation;
 class PyOperationBase;
-class PyType;
 class PySymbolTable;
+class PyType;
 class PyValue;
 
 /// Wrapper for the global LLVM debugging flag.
@@ -1188,7 +1189,9 @@ class MLIR_PYTHON_API_EXPORTED PyValue {
   /// Gets a capsule wrapping the void* within the MlirValue.
   nanobind::object getCapsule();
 
-  nanobind::typed<nanobind::object, PyValue> maybeDownCast();
+  nanobind::typed<nanobind::object,
+                  std::variant<PyBlockArgument, PyOpResult, PyValue>>
+  maybeDownCast();
 
   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
   /// the underlying MlirValue is still tied to the owning operation.
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 8dc8a0d063d70..f7cf8fd38981d 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -29,6 +29,7 @@
 #include <nanobind/stl/string.h>
 #include <nanobind/stl/string_view.h>
 #include <nanobind/stl/tuple.h>
+#include <nanobind/stl/variant.h>
 #include <nanobind/stl/vector.h>
 #include <nanobind/typing.h>
 #if defined(__clang__) || defined(__GNUC__)
diff --git a/mlir/include/mlir/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
index aea195fecae82..215daf245b902 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindUtils.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
@@ -277,7 +277,7 @@ class Sliceable {
   /// Returns the element at the given slice index. Supports negative indices
   /// by taking elements in inverse order. Returns a nullptr object if out
   /// of bounds.
-  nanobind::object getItem(intptr_t index) {
+  nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
     // Negative indices mean we count from the end.
     index = wrapIndex(index);
     if (index < 0) {
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 3c2da03181e6a..2e760e6e6f830 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -614,7 +614,7 @@ void populateIRAffine(nb::module_ &m) {
              return PyAffineExpr(self.getContext(),
                                  mlirAffineExprCompose(self, other));
            })
-      .def("maybe_downcast", &PyAffineExpr::maybeDownCast)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAffineExpr::maybeDownCast)
       .def(
           "shift_dims",
           [](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 19db41fae4fe2..646490bba366b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Bindings/Python/Globals.h"
 #include "mlir/Bindings/Python/IRCore.h"
 #include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 // clang-format on
 #include "mlir-c/BuiltinAttributes.h"
@@ -17,10 +18,6 @@
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
-#include "nanobind/nanobind.h"
-#include "nanobind/typing.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -2003,7 +2000,22 @@ nb::object PyValue::getCapsule() {
   return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
 }
 
-nb::typed<nb::object, PyValue> PyValue::maybeDownCast() {
+/// Returns the operation owning `value`; throws ValueError if there is none.
+static PyOperationRef getValueOwnerRef(MlirValue value) {
+  MlirOperation owner = {nullptr}; // Well-defined even under NDEBUG.
+  if (mlirValueIsAOpResult(value))
+    owner = mlirOpResultGetOwner(value);
+  else if (mlirValueIsABlockArgument(value))
+    owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
+  else
+    assert(false && "Value must be an block arg or op result.");
+  if (mlirOperationIsNull(owner)) // No Python error is pending here.
+    throw nb::value_error("Cannot determine the owning operation of a Value");
+  MlirContext ctx = mlirOperationGetContext(owner);
+  return PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+}
+nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
+PyValue::maybeDownCast() {
   MlirType type = mlirValueGetType(get());
   MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
   assert(!mlirTypeIDIsNull(mlirTypeID) &&
@@ -2012,26 +2024,23 @@ nb::typed<nb::object, PyValue> PyValue::maybeDownCast() {
       PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
   // nb::rv_policy::move means use std::move to move the return value
   // contents into a new instance that will be owned by Python.
-  nb::object thisObj = nb::cast(this, nb::rv_policy::move);
-  if (!valueCaster)
-    return thisObj;
-  return valueCaster.value()(thisObj);
+  nb::object thisObj;
+  if (mlirValueIsAOpResult(value))
+    thisObj = nb::cast<PyOpResult>(*this, nb::rv_policy::move);
+  else if (mlirValueIsABlockArgument(value))
+    thisObj = nb::cast<PyBlockArgument>(*this, nb::rv_policy::move);
+  else
+    assert(false && "Value must be an block arg or op result.");
+  if (valueCaster)
+    return valueCaster.value()(thisObj);
+  return thisObj;
 }
 
 PyValue PyValue::createFromCapsule(nb::object capsule) {
   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
   if (mlirValueIsNull(value))
     throw nb::python_error();
-  MlirOperation owner;
-  if (mlirValueIsAOpResult(value))
-    owner = mlirOpResultGetOwner(value);
-  if (mlirValueIsABlockArgument(value))
-    owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
-  if (mlirOperationIsNull(owner))
-    throw nb::python_error();
-  MlirContext ctx = mlirOperationGetContext(owner);
-  PyOperationRef ownerRef =
-      PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+  PyOperationRef ownerRef = getValueOwnerRef(value);
   return PyValue(ownerRef, value);
 }
 
@@ -2279,15 +2288,7 @@ intptr_t PyOpOperandList::getRawNumElements() {
 
 PyValue PyOpOperandList::getRawElement(intptr_t pos) {
   MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
-  MlirOperation owner;
-  if (mlirValueIsAOpResult(operand))
-    owner = mlirOpResultGetOwner(operand);
-  else if (mlirValueIsABlockArgument(operand))
-    owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
-  else
-    assert(false && "Value must be an block arg or op result.");
-  PyOperationRef pyOwner =
-      PyOperation::forOperation(operation->getContext(), owner);
+  PyOperationRef pyOwner = getValueOwnerRef(operand);
   return PyValue(pyOwner, operand);
 }
 
@@ -4679,9 +4680,7 @@ void populateIRCore(nb::module_ &m) {
           "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
       .def(
           MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
-          [](PyValue &self) -> nb::typed<nb::object, PyValue> {
-            return self.maybeDownCast();
-          },
+          [](PyValue &self) { return self.maybeDownCast(); },
           "Downcasts the `Value` to a more specific kind if possible.")
       .def_prop_ro(
           "location",
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 9c0966d2d8798..8814f4d66fdf9 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -862,7 +862,7 @@ def values(lst):
                 ]
                 is Value
             )
-            assert type(variadic_operands.non_variadic) is Value
+            assert type(variadic_operands.non_variadic) is OpResult
 
             # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
             print(values(variadic_operands.variadic1))
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 45efb880bab44..4cd1c2fa96c27 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -347,13 +347,6 @@ def __init__(self, v):
         def __str__(self):
             return super().__str__().replace(Value.__name__, NOPResult.__name__)
 
-    class NOPValue(Value):
-        def __init__(self, v):
-            super().__init__(v)
-
-        def __str__(self):
-            return super().__str__().replace(Value.__name__, NOPValue.__name__)
-
     class NOPBlockArg(BlockArgument):
         def __init__(self, v):
             super().__init__(v)
@@ -362,14 +355,12 @@ def __str__(self):
             return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
 
     @register_value_caster(IntegerType.static_typeid)
-    def cast_int(v) -> Value:
+    def cast_int(v) -> NOPResult | NOPBlockArg:
         print("in caster", v.__class__.__name__)
         if isinstance(v, OpResult):
             return NOPResult(v)
         if isinstance(v, BlockArgument):
             return NOPBlockArg(v)
-        elif isinstance(v, Value):
-            return NOPValue(v)
 
     ctx = Context()
     ctx.allow_unregistered_dialects = True
@@ -400,12 +391,15 @@ def cast_int(v) -> Value:
             # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
             print(op1)
 
-            # CHECK: in caster Value
-            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+            # CHECK: in caster OpResult
+            # CHECK: operand 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
             print("operand 0", op1.operands[0])
-            # CHECK: in caster Value
-            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+            assert isinstance(op1.operands[0], Value)
+            assert isinstance(op1.operands[0], OpResult)
+            # CHECK: in caster OpResult
+            # CHECK: operand 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
             print("operand 1", op1.operands[1])
+            assert isinstance(op1.operands[1], OpResult)
 
             # CHECK: in caster BlockArgument
             # CHECK: in caster BlockArgument

>From 0ce9e78d83660823ba074584c4bc36dbdbb9b932 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 20:01:29 -0800
Subject: [PATCH 2/2] [mlir][Python] normalize if-statement braces in IRCore.cpp

---
 mlir/lib/Bindings/Python/IRCore.cpp | 98 ++++++++++++-----------------
 1 file changed, 39 insertions(+), 59 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 646490bba366b..6994cb2905589 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -99,11 +99,12 @@ MlirBlock createBlock(const nb::sequence &pyArgTypes,
         python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation::resolve());
   }
 
-  if (argTypes.size() != argLocs.size())
+  if (argTypes.size() != argLocs.size()) {
     throw nb::value_error(("Expected " + Twine(argTypes.size()) +
                            " locations, got: " + Twine(argLocs.size()))
                               .str()
                               .c_str());
+  }
   return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
 }
 
@@ -193,9 +194,8 @@ nb::object PyBlock::getCapsule() {
 
 PyRegion PyRegionIterator::dunderNext() {
   operation->checkValid();
-  if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+  if (nextIndex >= mlirOperationGetNumRegions(operation->get()))
     throw nb::stop_iteration();
-  }
   MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
   return PyRegion(operation, region);
 }
@@ -243,9 +243,8 @@ PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
 
 PyBlock PyBlockIterator::dunderNext() {
   operation->checkValid();
-  if (mlirBlockIsNull(next)) {
+  if (mlirBlockIsNull(next))
     throw nb::stop_iteration();
-  }
 
   PyBlock returnBlock(operation, next);
   next = mlirBlockGetNextInRegion(next);
@@ -278,17 +277,14 @@ intptr_t PyBlockList::dunderLen() {
 
 PyBlock PyBlockList::dunderGetItem(intptr_t index) {
   operation->checkValid();
-  if (index < 0) {
+  if (index < 0)
     index += dunderLen();
-  }
-  if (index < 0) {
+  if (index < 0)
     throw nb::index_error("attempt to access out of bounds block");
-  }
   MlirBlock block = mlirRegionGetFirstBlock(region);
   while (!mlirBlockIsNull(block)) {
-    if (index == 0) {
+    if (index == 0)
       return PyBlock(operation, block);
-    }
     block = mlirBlockGetNextInRegion(block);
     index -= 1;
   }
@@ -323,9 +319,8 @@ void PyBlockList::bind(nb::module_ &m) {
 
 nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() {
   parentOperation->checkValid();
-  if (mlirOperationIsNull(next)) {
+  if (mlirOperationIsNull(next))
     throw nb::stop_iteration();
-  }
 
   PyOperationRef returnOperation =
       PyOperation::forOperation(parentOperation->getContext(), next);
@@ -360,12 +355,10 @@ intptr_t PyOperationList::dunderLen() {
 
 nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) {
   parentOperation->checkValid();
-  if (index < 0) {
+  if (index < 0)
     index += dunderLen();
-  }
-  if (index < 0) {
+  if (index < 0)
     throw nb::index_error("attempt to access out of bounds operation");
-  }
   MlirOperation childOp = mlirBlockGetFirstOperation(block);
   while (!mlirOperationIsNull(childOp)) {
     if (index == 0) {
@@ -572,9 +565,7 @@ MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
   // Check if the context requested we emit errors instead of capturing them.
   if (self->ctx->emitErrorDiagnostics)
     return mlirLogicalResultFailure();
-
-  if (mlirDiagnosticGetSeverity(diag) !=
-      MlirDiagnosticSeverity::MlirDiagnosticError)
+  if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
     return mlirLogicalResultFailure();
 
   self->errors.emplace_back(PyDiagnostic(diag).getInfo());
@@ -699,8 +690,9 @@ void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
     throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
   auto &tos = stack.back();
   if (tos.frameKind != FrameKind::InsertionPoint &&
-      tos.getInsertionPoint() != &insertionPoint)
+      tos.getInsertionPoint() != &insertionPoint) {
     throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
+  }
   stack.pop_back();
 }
 
@@ -967,9 +959,8 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
   PyOperationRef unownedOperation =
       makeObjectRef<PyOperation>(std::move(contextRef), operation);
   unownedOperation->handle = unownedOperation.getObject();
-  if (parentKeepAlive) {
+  if (parentKeepAlive)
     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
-  }
   return unownedOperation;
 }
 
@@ -1027,9 +1018,8 @@ void PyOperation::setDetached() {
 }
 
 void PyOperation::checkValid() const {
-  if (!valid) {
+  if (!valid)
     throw std::runtime_error("the operation has been invalidated");
-  }
 }
 
 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
@@ -1095,11 +1085,12 @@ void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
   MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
       operation, config, accum.getCallback(), accum.getUserData());
   mlirBytecodeWriterConfigDestroy(config);
-  if (mlirLogicalResultIsFailure(res))
+  if (mlirLogicalResultIsFailure(res)) {
     throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
                            Twine(*bytecodeVersion))
                               .str()
                               .c_str());
+  }
 }
 
 void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
@@ -1142,11 +1133,10 @@ nb::object PyOperationBase::getAsm(bool binary,
                                    bool useNameLocAsPrefix, bool assumeVerified,
                                    bool skipRegions) {
   nb::object fileObject;
-  if (binary) {
+  if (binary)
     fileObject = nb::module_::import_("io").attr("BytesIO")();
-  } else {
+  else
     fileObject = nb::module_::import_("io").attr("StringIO")();
-  }
   print(/*largeElementsLimit=*/largeElementsLimit,
         /*largeResourceLimit=*/largeResourceLimit,
         /*enableDebugInfo=*/enableDebugInfo,
@@ -1532,9 +1522,8 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
             auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
             for (nb::handle segmentItem : segment) {
               resultTypes.push_back(nb::cast<PyType *>(segmentItem));
-              if (!resultTypes.back()) {
+              if (!resultTypes.back())
                 throw nb::type_error("contained a None item");
-              }
             }
             resultSegmentLengths.push_back(nb::len(segment));
           }
@@ -1572,21 +1561,17 @@ MlirValue getUniqueResult(MlirOperation operation) {
 }
 
 static MlirValue getOpResultOrValue(nb::handle operand) {
-  if (operand.is_none()) {
+  if (operand.is_none())
     throw nb::value_error("contained a None item");
-  }
   PyOperationBase *op;
-  if (nb::try_cast<PyOperationBase *>(operand, op)) {
+  if (nb::try_cast<PyOperationBase *>(operand, op))
     return getUniqueResult(op->getOperation());
-  }
   PyOpResultList *opResultList;
-  if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
+  if (nb::try_cast<PyOpResultList *>(operand, opResultList))
     return getUniqueResult(opResultList->getOperation()->get());
-  }
   PyValue *value;
-  if (nb::try_cast<PyValue *>(operand, value)) {
+  if (nb::try_cast<PyValue *>(operand, value))
     return value->get();
-  }
   throw nb::value_error("is not a Value");
 }
 
@@ -1612,9 +1597,8 @@ nb::object PyOpView::buildGeneric(
   // Validate/determine region count.
   int opMinRegionCount = std::get<0>(opRegionSpec);
   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
-  if (!regions) {
+  if (!regions)
     regions = opMinRegionCount;
-  }
   if (*regions < opMinRegionCount) {
     throw nb::value_error(
         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
@@ -1676,7 +1660,6 @@ nb::object PyOpView::buildGeneric(
         auto &operand = std::get<0>(it.value());
         if (!operand.is_none()) {
           try {
-
             operands.push_back(getOpResultOrValue(operand));
           } catch (nb::builtin_exception &err) {
             throw nb::value_error((llvm::Twine("Operand ") +
@@ -1686,7 +1669,6 @@ nb::object PyOpView::buildGeneric(
                                       .str()
                                       .c_str());
           }
-
           operandSegmentLengths.push_back(1);
         } else if (segmentSpec == 0) {
           // Allowed to be optional.
@@ -1733,11 +1715,10 @@ nb::object PyOpView::buildGeneric(
   // Merge operand/result segment lengths into attributes if needed.
   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
     // Dup.
-    if (attributes) {
+    if (attributes)
       attributes = nb::dict(*attributes);
-    } else {
+    else
       attributes = nb::dict();
-    }
     if (attributes->contains("resultSegmentSizes") ||
         attributes->contains("operandSegmentSizes")) {
       throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
@@ -1825,9 +1806,10 @@ PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
 
 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
   PyOperation &operation = operationBase.getOperation();
-  if (operation.isAttached())
+  if (operation.isAttached()) {
     throw nb::value_error(
         "Attempt to insert operation that is already attached");
+  }
   block.getParentOperation()->checkValid();
   MlirOperation beforeOp = {nullptr};
   if (refOperation) {
@@ -2051,18 +2033,18 @@ PyValue PyValue::createFromCapsule(nb::object capsule) {
 PySymbolTable::PySymbolTable(PyOperationBase &operation)
     : operation(operation.getOperation().getRef()) {
   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
-  if (mlirSymbolTableIsNull(symbolTable)) {
+  if (mlirSymbolTableIsNull(symbolTable))
     throw nb::type_error("Operation is not a Symbol Table.");
-  }
 }
 
 nb::object PySymbolTable::dunderGetItem(const std::string &name) {
   operation->checkValid();
   MlirOperation symbol = mlirSymbolTableLookup(
       symbolTable, mlirStringRefCreate(name.data(), name.length()));
-  if (mlirOperationIsNull(symbol))
+  if (mlirOperationIsNull(symbol)) {
     throw nb::key_error(
         ("Symbol '" + name + "' not in the symbol table.").c_str());
+  }
 
   return PyOperation::forOperation(operation->getContext(), symbol,
                                    operation.getObject())
@@ -2138,9 +2120,10 @@ PyStringAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
 void PySymbolTable::setVisibility(PyOperationBase &symbol,
                                   const std::string &visibility) {
   if (visibility != "public" && visibility != "private" &&
-      visibility != "nested")
+      visibility != "nested") {
     throw nb::value_error(
         "Expected visibility to be 'public', 'private' or 'nested'");
+  }
   PyOperation &operation = symbol.getOperation();
   operation.checkValid();
   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
@@ -2160,9 +2143,9 @@ void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
   fromOperation.checkValid();
   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
-          from.getOperation())))
-
+          from.getOperation()))) {
     throw nb::value_error("Symbol rename failed");
+  }
 }
 
 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
@@ -2383,19 +2366,16 @@ nb::typed<nb::object, PyAttribute>
 PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
   MlirAttribute attr =
       mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name));
-  if (mlirAttributeIsNull(attr)) {
+  if (mlirAttributeIsNull(attr))
     throw nb::key_error("attempt to access a non-existent attribute");
-  }
   return PyAttribute(operation->getContext(), attr).maybeDownCast();
 }
 
 PyNamedAttribute PyOpAttributeMap::dunderGetItemIndexed(intptr_t index) {
-  if (index < 0) {
+  if (index < 0)
     index += dunderLen();
-  }
-  if (index < 0 || index >= dunderLen()) {
+  if (index < 0 || index >= dunderLen())
     throw nb::index_error("attempt to access out of bounds attribute");
-  }
   MlirNamedAttribute namedAttr =
       mlirOperationGetAttribute(operation->get(), index);
   return PyNamedAttribute(



More information about the Mlir-commits mailing list