[Mlir-commits] [mlir] [MLIR][Python] Use ir.Value directly instead of _SubClassValueT (PR #82341)
Sergei Lebedev
llvmlistbot at llvm.org
Tue Feb 20 03:19:26 PST 2024
https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/82341
>From 44792946675940a0d240c12aa9bce81138042a63 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Tue, 20 Feb 2024 11:09:08 +0000
Subject: [PATCH] [MLIR][Python] Use ir.Value directly instead of
_SubClassValueT
_SubClassValueT is only useful when it has more than one usage in a signature.
This was not true for the signatures produced by tblgen.
For example:
def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT:
...
Here, a type checker does not have enough information to infer a type argument
for _SubClassValueT, and thus effectively treats it as Any.
---
mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi | 2 +-
mlir/python/mlir/dialects/_ods_common.py | 7 -------
mlir/python/mlir/dialects/arith.py | 3 +--
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 18 +++++++-----------
4 files changed, 9 insertions(+), 21 deletions(-)
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index 3ed1872f1cd5a2..93b978c75540f4 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -10,4 +10,4 @@ class _Globals:
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
def register_dialect(dialect_class: type) -> object: ...
-def register_operation(dialect_class: type) -> object: ...
+def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ...
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 3af3b5ce73bc60..1e7e8244ed4420 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -8,7 +8,6 @@
Sequence as _Sequence,
Tuple as _Tuple,
Type as _Type,
- TypeVar as _TypeVar,
Union as _Union,
)
@@ -143,12 +142,6 @@ def get_op_result_or_op_results(
else op
)
-
-# This is the standard way to indicate subclass/inheritance relationship
-# see the typing.Type doc string.
-_U = _TypeVar("_U", bound=_cext.ir.Value)
-SubClassValueT = _Type[_U]
-
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 663a53660a6474..61c6917393f1f9 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -12,7 +12,6 @@
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
get_op_result_or_op_results as _get_op_result_or_op_results,
- SubClassValueT as _SubClassValueT,
)
from typing import Any, List, Union
@@ -81,5 +80,5 @@ def literal_value(self) -> Union[int, float]:
def constant(
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
-) -> _SubClassValueT:
+) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..6c06b86fdf751f 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -31,7 +31,6 @@ constexpr const char *fileHeader = R"Py(
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
- SubClassValueT as _SubClassValueT,
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_result_or_op_results as _get_op_result_or_op_results,
@@ -52,8 +51,6 @@ constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
DIALECT_NAMESPACE = "{0}"
- pass
-
)Py";
constexpr const char *dialectExtensionTemplate = R"Py(
@@ -1007,14 +1004,13 @@ static void emitValueBuilder(const Operator &op,
});
std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1);
- os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
- op.getCppClassName(),
- llvm::join(valueBuilderParams, ", "),
- llvm::join(opBuilderArgs, ", "),
- (op.getNumResults() > 1
- ? "_Sequence[_SubClassValueT]"
- : (op.getNumResults() > 0 ? "_SubClassValueT"
- : "_ods_ir.Operation")));
+ os << llvm::formatv(
+ valueBuilderTemplate, sanitizeName(nameWithoutDialect),
+ op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
+ llvm::join(opBuilderArgs, ", "),
+ (op.getNumResults() > 1
+ ? "_Sequence[_ods_ir.Value]"
+ : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
}
/// Emits bindings for a specific Op to the given output stream.
More information about the Mlir-commits
mailing list