[Mlir-commits] [mlir] [mlir][Python] downcast ir.Value to BlockArgument or OpResult (PR #175264)
Maksim Levental
llvmlistbot at llvm.org
Fri Jan 9 20:01:42 PST 2026
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/175264
>From 7c2c7f0ac6e721e42445d45f3ec3fff1e1368aa3 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 15:02:44 -0800
Subject: [PATCH 1/2] [mlir][Python] downcast ir.Value to BlockArgument or
OpResult
---
mlir/include/mlir/Bindings/Python/IRCore.h | 13 ++--
mlir/include/mlir/Bindings/Python/Nanobind.h | 1 +
.../mlir/Bindings/Python/NanobindUtils.h | 2 +-
mlir/lib/Bindings/Python/IRAffine.cpp | 2 +-
mlir/lib/Bindings/Python/IRCore.cpp | 61 +++++++++----------
mlir/test/python/dialects/python_test.py | 2 +-
mlir/test/python/ir/value.py | 22 +++----
7 files changed, 50 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 330318683c15e..bfc35ca5b9d50 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -36,20 +36,21 @@
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
+class DefaultingPyLocation;
+class DefaultingPyMlirContext;
class PyBlock;
+class PyBlockArgument;
class PyDiagnostic;
class PyDiagnosticHandler;
class PyInsertionPoint;
class PyLocation;
-class DefaultingPyLocation;
class PyMlirContext;
-class DefaultingPyMlirContext;
class PyModule;
+class PyOpResult;
class PyOperation;
class PyOperationBase;
-class PyType;
class PySymbolTable;
+class PyType;
class PyValue;
/// Wrapper for the global LLVM debugging flag.
@@ -1188,7 +1189,9 @@ class MLIR_PYTHON_API_EXPORTED PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
nanobind::object getCapsule();
- nanobind::typed<nanobind::object, PyValue> maybeDownCast();
+ nanobind::typed<nanobind::object,
+ std::variant<PyBlockArgument, PyOpResult, PyValue>>
+ maybeDownCast();
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h
index 8dc8a0d063d70..f7cf8fd38981d 100644
--- a/mlir/include/mlir/Bindings/Python/Nanobind.h
+++ b/mlir/include/mlir/Bindings/Python/Nanobind.h
@@ -29,6 +29,7 @@
#include <nanobind/stl/string.h>
#include <nanobind/stl/string_view.h>
#include <nanobind/stl/tuple.h>
+#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <nanobind/typing.h>
#if defined(__clang__) || defined(__GNUC__)
diff --git a/mlir/include/mlir/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
index aea195fecae82..215daf245b902 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindUtils.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindUtils.h
@@ -277,7 +277,7 @@ class Sliceable {
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
- nanobind::object getItem(intptr_t index) {
+ nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 3c2da03181e6a..2e760e6e6f830 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -614,7 +614,7 @@ void populateIRAffine(nb::module_ &m) {
return PyAffineExpr(self.getContext(),
mlirAffineExprCompose(self, other));
})
- .def("maybe_downcast", &PyAffineExpr::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAffineExpr::maybeDownCast)
.def(
"shift_dims",
[](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 19db41fae4fe2..646490bba366b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -10,6 +10,7 @@
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir-c/BuiltinAttributes.h"
@@ -17,10 +18,6 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
-#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
-#include "nanobind/nanobind.h"
-#include "nanobind/typing.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -2003,7 +2000,22 @@ nb::object PyValue::getCapsule() {
return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
}
-nb::typed<nb::object, PyValue> PyValue::maybeDownCast() {
+static PyOperationRef getValueOwnerRef(MlirValue value) {
+ MlirOperation owner;
+ if (mlirValueIsAOpResult(value))
+ owner = mlirOpResultGetOwner(value);
+ else if (mlirValueIsABlockArgument(value))
+ owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
+ else
+ assert(false && "Value must be an block arg or op result.");
+ if (mlirOperationIsNull(owner))
+ throw nb::python_error();
+ MlirContext ctx = mlirOperationGetContext(owner);
+ return PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+}
+
+nb::typed<nb::object, std::variant<PyBlockArgument, PyOpResult, PyValue>>
+PyValue::maybeDownCast() {
MlirType type = mlirValueGetType(get());
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
@@ -2012,26 +2024,23 @@ nb::typed<nb::object, PyValue> PyValue::maybeDownCast() {
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
// nb::rv_policy::move means use std::move to move the return value
// contents into a new instance that will be owned by Python.
- nb::object thisObj = nb::cast(this, nb::rv_policy::move);
- if (!valueCaster)
- return thisObj;
- return valueCaster.value()(thisObj);
+ nb::object thisObj;
+ if (mlirValueIsAOpResult(value))
+ thisObj = nb::cast<PyOpResult>(*this, nb::rv_policy::move);
+ else if (mlirValueIsABlockArgument(value))
+ thisObj = nb::cast<PyBlockArgument>(*this, nb::rv_policy::move);
+ else
+ assert(false && "Value must be an block arg or op result.");
+ if (valueCaster)
+ return valueCaster.value()(thisObj);
+ return thisObj;
}
PyValue PyValue::createFromCapsule(nb::object capsule) {
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
if (mlirValueIsNull(value))
throw nb::python_error();
- MlirOperation owner;
- if (mlirValueIsAOpResult(value))
- owner = mlirOpResultGetOwner(value);
- if (mlirValueIsABlockArgument(value))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
- if (mlirOperationIsNull(owner))
- throw nb::python_error();
- MlirContext ctx = mlirOperationGetContext(owner);
- PyOperationRef ownerRef =
- PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+ PyOperationRef ownerRef = getValueOwnerRef(value);
return PyValue(ownerRef, value);
}
@@ -2279,15 +2288,7 @@ intptr_t PyOpOperandList::getRawNumElements() {
PyValue PyOpOperandList::getRawElement(intptr_t pos) {
MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
- MlirOperation owner;
- if (mlirValueIsAOpResult(operand))
- owner = mlirOpResultGetOwner(operand);
- else if (mlirValueIsABlockArgument(operand))
- owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
- else
- assert(false && "Value must be an block arg or op result.");
- PyOperationRef pyOwner =
- PyOperation::forOperation(operation->getContext(), owner);
+ PyOperationRef pyOwner = getValueOwnerRef(operand);
return PyValue(pyOwner, operand);
}
@@ -4679,9 +4680,7 @@ void populateIRCore(nb::module_ &m) {
"with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring)
.def(
MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) -> nb::typed<nb::object, PyValue> {
- return self.maybeDownCast();
- },
+ [](PyValue &self) { return self.maybeDownCast(); },
"Downcasts the `Value` to a more specific kind if possible.")
.def_prop_ro(
"location",
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 9c0966d2d8798..8814f4d66fdf9 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -862,7 +862,7 @@ def values(lst):
]
is Value
)
- assert type(variadic_operands.non_variadic) is Value
+ assert type(variadic_operands.non_variadic) is OpResult
# CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
print(values(variadic_operands.variadic1))
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 45efb880bab44..4cd1c2fa96c27 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -347,13 +347,6 @@ def __init__(self, v):
def __str__(self):
return super().__str__().replace(Value.__name__, NOPResult.__name__)
- class NOPValue(Value):
- def __init__(self, v):
- super().__init__(v)
-
- def __str__(self):
- return super().__str__().replace(Value.__name__, NOPValue.__name__)
-
class NOPBlockArg(BlockArgument):
def __init__(self, v):
super().__init__(v)
@@ -362,14 +355,12 @@ def __str__(self):
return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
@register_value_caster(IntegerType.static_typeid)
- def cast_int(v) -> Value:
+ def cast_int(v) -> NOPResult | NOPBlockArg:
print("in caster", v.__class__.__name__)
if isinstance(v, OpResult):
return NOPResult(v)
if isinstance(v, BlockArgument):
return NOPBlockArg(v)
- elif isinstance(v, Value):
- return NOPValue(v)
ctx = Context()
ctx.allow_unregistered_dialects = True
@@ -400,12 +391,15 @@ def cast_int(v) -> Value:
# CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
print(op1)
- # CHECK: in caster Value
- # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ # CHECK: in caster OpResult
+ # CHECK: operand 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
print("operand 0", op1.operands[0])
- # CHECK: in caster Value
- # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
+ assert isinstance(op1.operands[0], Value)
+ assert isinstance(op1.operands[0], OpResult)
+ # CHECK: in caster OpResult
+ # CHECK: operand 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
print("operand 1", op1.operands[1])
+ assert isinstance(op1.operands[1], OpResult)
# CHECK: in caster BlockArgument
# CHECK: in caster BlockArgument
>From 0ce9e78d83660823ba074584c4bc36dbdbb9b932 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Fri, 9 Jan 2026 20:01:29 -0800
Subject: [PATCH 2/2] [mlir][Python] apply LLVM brace style to ifs in IRCore.cpp
---
mlir/lib/Bindings/Python/IRCore.cpp | 98 ++++++++++++-----------------
1 file changed, 39 insertions(+), 59 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 646490bba366b..6994cb2905589 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -99,11 +99,12 @@ MlirBlock createBlock(const nb::sequence &pyArgTypes,
python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation::resolve());
}
- if (argTypes.size() != argLocs.size())
+ if (argTypes.size() != argLocs.size()) {
throw nb::value_error(("Expected " + Twine(argTypes.size()) +
" locations, got: " + Twine(argLocs.size()))
.str()
.c_str());
+ }
return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
}
@@ -193,9 +194,8 @@ nb::object PyBlock::getCapsule() {
PyRegion PyRegionIterator::dunderNext() {
operation->checkValid();
- if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
+ if (nextIndex >= mlirOperationGetNumRegions(operation->get()))
throw nb::stop_iteration();
- }
MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
return PyRegion(operation, region);
}
@@ -243,9 +243,8 @@ PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length,
PyBlock PyBlockIterator::dunderNext() {
operation->checkValid();
- if (mlirBlockIsNull(next)) {
+ if (mlirBlockIsNull(next))
throw nb::stop_iteration();
- }
PyBlock returnBlock(operation, next);
next = mlirBlockGetNextInRegion(next);
@@ -278,17 +277,14 @@ intptr_t PyBlockList::dunderLen() {
PyBlock PyBlockList::dunderGetItem(intptr_t index) {
operation->checkValid();
- if (index < 0) {
+ if (index < 0)
index += dunderLen();
- }
- if (index < 0) {
+ if (index < 0)
throw nb::index_error("attempt to access out of bounds block");
- }
MlirBlock block = mlirRegionGetFirstBlock(region);
while (!mlirBlockIsNull(block)) {
- if (index == 0) {
+ if (index == 0)
return PyBlock(operation, block);
- }
block = mlirBlockGetNextInRegion(block);
index -= 1;
}
@@ -323,9 +319,8 @@ void PyBlockList::bind(nb::module_ &m) {
nb::typed<nb::object, PyOpView> PyOperationIterator::dunderNext() {
parentOperation->checkValid();
- if (mlirOperationIsNull(next)) {
+ if (mlirOperationIsNull(next))
throw nb::stop_iteration();
- }
PyOperationRef returnOperation =
PyOperation::forOperation(parentOperation->getContext(), next);
@@ -360,12 +355,10 @@ intptr_t PyOperationList::dunderLen() {
nb::typed<nb::object, PyOpView> PyOperationList::dunderGetItem(intptr_t index) {
parentOperation->checkValid();
- if (index < 0) {
+ if (index < 0)
index += dunderLen();
- }
- if (index < 0) {
+ if (index < 0)
throw nb::index_error("attempt to access out of bounds operation");
- }
MlirOperation childOp = mlirBlockGetFirstOperation(block);
while (!mlirOperationIsNull(childOp)) {
if (index == 0) {
@@ -572,9 +565,7 @@ MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
// Check if the context requested we emit errors instead of capturing them.
if (self->ctx->emitErrorDiagnostics)
return mlirLogicalResultFailure();
-
- if (mlirDiagnosticGetSeverity(diag) !=
- MlirDiagnosticSeverity::MlirDiagnosticError)
+ if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
return mlirLogicalResultFailure();
self->errors.emplace_back(PyDiagnostic(diag).getInfo());
@@ -699,8 +690,9 @@ void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
auto &tos = stack.back();
if (tos.frameKind != FrameKind::InsertionPoint &&
- tos.getInsertionPoint() != &insertionPoint)
+ tos.getInsertionPoint() != &insertionPoint) {
throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
+ }
stack.pop_back();
}
@@ -967,9 +959,8 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
PyOperationRef unownedOperation =
makeObjectRef<PyOperation>(std::move(contextRef), operation);
unownedOperation->handle = unownedOperation.getObject();
- if (parentKeepAlive) {
+ if (parentKeepAlive)
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
- }
return unownedOperation;
}
@@ -1027,9 +1018,8 @@ void PyOperation::setDetached() {
}
void PyOperation::checkValid() const {
- if (!valid) {
+ if (!valid)
throw std::runtime_error("the operation has been invalidated");
- }
}
void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
@@ -1095,11 +1085,12 @@ void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
operation, config, accum.getCallback(), accum.getUserData());
mlirBytecodeWriterConfigDestroy(config);
- if (mlirLogicalResultIsFailure(res))
+ if (mlirLogicalResultIsFailure(res)) {
throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
Twine(*bytecodeVersion))
.str()
.c_str());
+ }
}
void PyOperationBase::walk(std::function<PyWalkResult(MlirOperation)> callback,
@@ -1142,11 +1133,10 @@ nb::object PyOperationBase::getAsm(bool binary,
bool useNameLocAsPrefix, bool assumeVerified,
bool skipRegions) {
nb::object fileObject;
- if (binary) {
+ if (binary)
fileObject = nb::module_::import_("io").attr("BytesIO")();
- } else {
+ else
fileObject = nb::module_::import_("io").attr("StringIO")();
- }
print(/*largeElementsLimit=*/largeElementsLimit,
/*largeResourceLimit=*/largeResourceLimit,
/*enableDebugInfo=*/enableDebugInfo,
@@ -1532,9 +1522,8 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
for (nb::handle segmentItem : segment) {
resultTypes.push_back(nb::cast<PyType *>(segmentItem));
- if (!resultTypes.back()) {
+ if (!resultTypes.back())
throw nb::type_error("contained a None item");
- }
}
resultSegmentLengths.push_back(nb::len(segment));
}
@@ -1572,21 +1561,17 @@ MlirValue getUniqueResult(MlirOperation operation) {
}
static MlirValue getOpResultOrValue(nb::handle operand) {
- if (operand.is_none()) {
+ if (operand.is_none())
throw nb::value_error("contained a None item");
- }
PyOperationBase *op;
- if (nb::try_cast<PyOperationBase *>(operand, op)) {
+ if (nb::try_cast<PyOperationBase *>(operand, op))
return getUniqueResult(op->getOperation());
- }
PyOpResultList *opResultList;
- if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
+ if (nb::try_cast<PyOpResultList *>(operand, opResultList))
return getUniqueResult(opResultList->getOperation()->get());
- }
PyValue *value;
- if (nb::try_cast<PyValue *>(operand, value)) {
+ if (nb::try_cast<PyValue *>(operand, value))
return value->get();
- }
throw nb::value_error("is not a Value");
}
@@ -1612,9 +1597,8 @@ nb::object PyOpView::buildGeneric(
// Validate/determine region count.
int opMinRegionCount = std::get<0>(opRegionSpec);
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
- if (!regions) {
+ if (!regions)
regions = opMinRegionCount;
- }
if (*regions < opMinRegionCount) {
throw nb::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
@@ -1676,7 +1660,6 @@ nb::object PyOpView::buildGeneric(
auto &operand = std::get<0>(it.value());
if (!operand.is_none()) {
try {
-
operands.push_back(getOpResultOrValue(operand));
} catch (nb::builtin_exception &err) {
throw nb::value_error((llvm::Twine("Operand ") +
@@ -1686,7 +1669,6 @@ nb::object PyOpView::buildGeneric(
.str()
.c_str());
}
-
operandSegmentLengths.push_back(1);
} else if (segmentSpec == 0) {
// Allowed to be optional.
@@ -1733,11 +1715,10 @@ nb::object PyOpView::buildGeneric(
// Merge operand/result segment lengths into attributes if needed.
if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
// Dup.
- if (attributes) {
+ if (attributes)
attributes = nb::dict(*attributes);
- } else {
+ else
attributes = nb::dict();
- }
if (attributes->contains("resultSegmentSizes") ||
attributes->contains("operandSegmentSizes")) {
throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
@@ -1825,9 +1806,10 @@ PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef)
void PyInsertionPoint::insert(PyOperationBase &operationBase) {
PyOperation &operation = operationBase.getOperation();
- if (operation.isAttached())
+ if (operation.isAttached()) {
throw nb::value_error(
"Attempt to insert operation that is already attached");
+ }
block.getParentOperation()->checkValid();
MlirOperation beforeOp = {nullptr};
if (refOperation) {
@@ -2051,18 +2033,18 @@ PyValue PyValue::createFromCapsule(nb::object capsule) {
PySymbolTable::PySymbolTable(PyOperationBase &operation)
: operation(operation.getOperation().getRef()) {
symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
- if (mlirSymbolTableIsNull(symbolTable)) {
+ if (mlirSymbolTableIsNull(symbolTable))
throw nb::type_error("Operation is not a Symbol Table.");
- }
}
nb::object PySymbolTable::dunderGetItem(const std::string &name) {
operation->checkValid();
MlirOperation symbol = mlirSymbolTableLookup(
symbolTable, mlirStringRefCreate(name.data(), name.length()));
- if (mlirOperationIsNull(symbol))
+ if (mlirOperationIsNull(symbol)) {
throw nb::key_error(
("Symbol '" + name + "' not in the symbol table.").c_str());
+ }
return PyOperation::forOperation(operation->getContext(), symbol,
operation.getObject())
@@ -2138,9 +2120,10 @@ PyStringAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
void PySymbolTable::setVisibility(PyOperationBase &symbol,
const std::string &visibility) {
if (visibility != "public" && visibility != "private" &&
- visibility != "nested")
+ visibility != "nested") {
throw nb::value_error(
"Expected visibility to be 'public', 'private' or 'nested'");
+ }
PyOperation &operation = symbol.getOperation();
operation.checkValid();
MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
@@ -2160,9 +2143,9 @@ void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
fromOperation.checkValid();
if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
- from.getOperation())))
-
+ from.getOperation()))) {
throw nb::value_error("Symbol rename failed");
+ }
}
void PySymbolTable::walkSymbolTables(PyOperationBase &from,
@@ -2383,19 +2366,16 @@ nb::typed<nb::object, PyAttribute>
PyOpAttributeMap::dunderGetItemNamed(const std::string &name) {
MlirAttribute attr =
mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name));
- if (mlirAttributeIsNull(attr)) {
+ if (mlirAttributeIsNull(attr))
throw nb::key_error("attempt to access a non-existent attribute");
- }
return PyAttribute(operation->getContext(), attr).maybeDownCast();
}
PyNamedAttribute PyOpAttributeMap::dunderGetItemIndexed(intptr_t index) {
- if (index < 0) {
+ if (index < 0)
index += dunderLen();
- }
- if (index < 0 || index >= dunderLen()) {
+ if (index < 0 || index >= dunderLen())
throw nb::index_error("attempt to access out of bounds attribute");
- }
MlirNamedAttribute namedAttr =
mlirOperationGetAttribute(operation->get(), index);
return PyNamedAttribute(
More information about the Mlir-commits mailing list