[Mlir-commits] [mlir] [python] fix enum collision (PR #117918)
Maksim Levental
llvmlistbot at llvm.org
Wed Nov 27 14:55:07 PST 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/117918
From 5d3785efd8e89a0f98150ca97a70b271d30522c7 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 27 Nov 2024 17:50:37 -0500
Subject: [PATCH] [python] fix enum ambiguity
---
mlir/cmake/modules/AddMLIRPython.cmake | 4 +-
mlir/python/CMakeLists.txt | 3 +-
mlir/python/mlir/dialects/_ods_common.py | 1 +
mlir/python/mlir/dialects/amdgpu.py | 16 +++
mlir/python/mlir/dialects/arith.py | 35 ++++++
mlir/python/mlir/dialects/bufferization.py | 6 +
mlir/python/mlir/dialects/gpu/__init__.py | 56 +++++++++
mlir/python/mlir/dialects/index.py | 6 +
mlir/python/mlir/dialects/linalg/__init__.py | 25 ++++
.../linalg/opdsl/ops/core_named_ops.py | 116 ++++++++++--------
mlir/python/mlir/dialects/llvm.py | 107 +++++++++++++++-
mlir/python/mlir/dialects/nvgpu.py | 26 ++++
mlir/python/mlir/dialects/nvvm.py | 76 ++++++++++++
mlir/python/mlir/dialects/sparse_tensor.py | 11 ++
.../mlir/dialects/transform/__init__.py | 10 ++
.../dialects/transform/extras/__init__.py | 1 +
.../mlir/dialects/transform/structured.py | 10 ++
mlir/python/mlir/dialects/transform/vector.py | 21 ++++
mlir/python/mlir/dialects/vector.py | 16 +++
mlir/python/mlir/ir.py | 108 ++++++++--------
.../test/mlir-tblgen/enums-python-bindings.td | 16 +--
mlir/test/mlir-tblgen/op-python-bindings.td | 4 +-
mlir/test/python/dialects/index_dialect.py | 2 +-
.../dialects/transform_structured_ext.py | 2 +-
.../mlir-tblgen/EnumPythonBindingGen.cpp | 23 ++--
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 9 +-
26 files changed, 573 insertions(+), 137 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 7b91f43e2d57fd8..d06fc927ea44d49 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
- mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+ mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
- mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+ mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 23187f256455bba..9949743b9bf09cd 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
- DIALECT_NAME affine
- GEN_ENUM_BINDINGS)
+ DIALECT_NAME affine)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index d40d936cdc83d67..22fc588e180b1d7 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -143,6 +143,7 @@ def get_op_result_or_op_results(
else op
)
+
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 43d905d0c481cce..9b8beaa5571a260 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -2,5 +2,21 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._amdgpu_ops_gen import *
from ._amdgpu_enum_gen import *
+
+
+ at register_attribute_builder("builtin.AMDGPU_DPPPerm")
+def _amdgpu_dppperm(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.AMDGPU_MFMAPermB")
+def _amdgpu_mfmapermb(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.AMDGPU_SchedBarrierOpOpt")
+def _amdgpu_schedbarrieropopt(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce6650..32ba832260d1296 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -108,3 +108,38 @@ def constant(
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
+
+
+ at register_attribute_builder("builtin.Arith_CmpFPredicateAttr")
+def _arith_cmpfpredicateattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.Arith_CmpIPredicateAttr")
+def _arith_cmpipredicateattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.Arith_DenormalMode")
+def _arith_denormalmode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.Arith_IntegerOverflowFlags")
+def _arith_integeroverflowflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.Arith_RoundingModeAttr")
+def _arith_roundingmodeattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.AtomicRMWKindAttr")
+def _atomicrmwkindattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.FastMathFlags")
+def _fastmathflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 759b6aa24a9ff73..6ad76c729ed2dcc 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -2,5 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._bufferization_ops_gen import *
from ._bufferization_enum_gen import *
+
+
+ at register_attribute_builder("builtin.LayoutMapOption")
+def _layoutmapoption(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa8b7ca85e..e0bb07c5dad8bec 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -2,6 +2,62 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ...ir import IntegerAttr, IntegerType, register_attribute_builder
from .._gpu_ops_gen import *
from .._gpu_enum_gen import *
from ..._mlir_libs._mlirDialectsGPU import *
+
+
+ at register_attribute_builder("builtin.GPU_AddressSpaceEnum")
+def _gpu_addressspaceenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.GPU_AllReduceOperation")
+def _gpu_allreduceoperation(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.GPU_CompilationTargetEnum")
+def _gpu_compilationtargetenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.GPU_Dimension")
+def _gpu_dimension(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.GPU_Prune2To4SpMatFlag")
+def _gpu_prune2to4spmatflag(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.GPU_ShuffleMode")
+def _gpu_shufflemode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.GPU_SpGEMMWorkEstimationOrComputeKind")
+def _gpu_spgemmworkestimationorcomputekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.GPU_TransposeMode")
+def _gpu_transposemode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MMAElementWise")
+def _mmaelementwise(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MappingIdEnum")
+def _mappingidenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.ProcessorEnum")
+def _processorenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/index.py b/mlir/python/mlir/dialects/index.py
index 73708c7d71a8c83..f00c397965c97cd 100644
--- a/mlir/python/mlir/dialects/index.py
+++ b/mlir/python/mlir/dialects/index.py
@@ -2,5 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._index_ops_gen import *
from ._index_enum_gen import *
+
+
+ at register_attribute_builder("builtin.IndexCmpPredicate")
+def _indexcmppredicate(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff50..4fe9cc40ee910a2 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -102,3 +102,28 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op
+
+
+ at register_attribute_builder("builtin.BinaryFn")
+def _binaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.IteratorType")
+def _iteratortype(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.TernaryFn")
+def _ternaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.TypeFn")
+def _typefn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.UnaryFn")
+def _unaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca3..72966f12c87ea67 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
- TypeFn.cast_signed(U, IZp)
) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
+
@linalg_structured_op
def conv_2d_nchw_fchw(
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
@@ -1082,16 +1083,19 @@ def conv_3d_ndhwc_dhwcf(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
- O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed(
- U,
- I[
- D.n,
- D.od * S.SD + D.kd * S.DD,
- D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW,
- D.c,
- ],
- ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
+ O[D.n, D.od, D.oh, D.ow, D.f] += (
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.c,
+ ],
+ )
+ * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f])
+ )
@linalg_structured_op
@@ -1159,16 +1163,19 @@ def conv_3d_ncdhw_fcdhw(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
- O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed(
- U,
- I[
- D.n,
- D.c,
- D.od * S.SD + D.kd * S.DD,
- D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW,
- ],
- ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
+ O[D.n, D.f, D.od, D.oh, D.ow] += (
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.c,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ ],
+ )
+ * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
+ )
@linalg_structured_op
@@ -1368,16 +1375,19 @@ def depthwise_conv_3d_ndhwc_dhwc(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
- O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed(
- U,
- I[
- D.n,
- D.od * S.SD + D.kd * S.DD,
- D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW,
- D.ic,
- ],
- ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
+ O[D.n, D.od, D.oh, D.ow, D.ic] += (
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.ic,
+ ],
+ )
+ * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic])
+ )
@linalg_structured_op
@@ -1403,16 +1413,19 @@ def depthwise_conv_3d_ncdhw_cdhw(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic)
- O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed(
- U,
- I[
- D.n,
- D.ic,
- D.od * S.SD + D.kd * S.DD,
- D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW,
- ],
- ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
+ O[D.n, D.ic, D.od, D.oh, D.ow] += (
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.ic,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ ],
+ )
+ * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw])
+ )
@linalg_structured_op
@@ -1437,16 +1450,19 @@ def depthwise_conv_3d_ndhwc_dhwcm(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic)
- O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed(
- U,
- I[
- D.n,
- D.od * S.SD + D.kd * S.DD,
- D.oh * S.SH + D.kh * S.DH,
- D.ow * S.SW + D.kw * S.DW,
- D.ic,
- ],
- ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
+ O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += (
+ TypeFn.cast_signed(
+ U,
+ I[
+ D.n,
+ D.od * S.SD + D.kd * S.DD,
+ D.oh * S.SH + D.kh * S.DH,
+ D.ow * S.SW + D.kw * S.DW,
+ D.ic,
+ ],
+ )
+ * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm])
+ )
@linalg_structured_op
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 941a584966dcde9..456df39ae650ffe 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -5,7 +5,7 @@
from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
from .._mlir_libs._mlirDialectsLLVM import *
-from ..ir import Value
+from ..ir import Value, IntegerAttr, IntegerType, register_attribute_builder
from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
@@ -13,3 +13,108 @@ def mlir_constant(value, *, loc=None, ip=None) -> Value:
return _get_op_result_or_op_results(
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
)
+
+
+ at register_attribute_builder("builtin.AsmATTOrIntel")
+def _asmattorintel(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.AtomicBinOp")
+def _atomicbinop(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.AtomicOrdering")
+def _atomicordering(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.CConvEnum")
+def _cconvenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.Comdat")
+def _comdat(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.DIFlags")
+def _diflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.DISubprogramFlags")
+def _disubprogramflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.FCmpPredicate")
+def _fcmppredicate(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.FPExceptionBehaviorAttr")
+def _fpexceptionbehaviorattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.FastmathFlags")
+def _fastmathflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.FramePointerKindEnum")
+def _framepointerkindenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.ICmpPredicate")
+def _icmppredicate(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.IntegerOverflowFlags")
+def _integeroverflowflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.LLVM_DIEmissionKind")
+def _llvm_diemissionkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.LLVM_DINameTableKind")
+def _llvm_dinametablekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.LinkageEnum")
+def _linkageenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.ModRefInfoEnum")
+def _modrefinfoenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.RoundingModeAttr")
+def _roundingmodeattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.TailCallKindEnum")
+def _tailcallkindenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.UnnamedAddr")
+def _unnamedaddr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.Visibility")
+def _visibility(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py
index d6a54f2772f40da..eea132adb0484e5 100644
--- a/mlir/python/mlir/dialects/nvgpu.py
+++ b/mlir/python/mlir/dialects/nvgpu.py
@@ -2,6 +2,32 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._nvgpu_ops_gen import *
from ._nvgpu_enum_gen import *
from .._mlir_libs._mlirDialectsNVGPU import *
+
+
+ at register_attribute_builder("builtin.RcpRoundingMode")
+def _rcproundingmode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.TensorMapInterleaveKind")
+def _tensormapinterleavekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.TensorMapL2PromoKind")
+def _tensormapl2promokind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.TensorMapOOBKind")
+def _tensormapoobkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.TensorMapSwizzleKind")
+def _tensormapswizzlekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py
index 9477de39c9ead73..21bf24cb73fdabf 100644
--- a/mlir/python/mlir/dialects/nvvm.py
+++ b/mlir/python/mlir/dialects/nvvm.py
@@ -2,5 +2,81 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._nvvm_ops_gen import *
from ._nvvm_enum_gen import *
+
+
+ at register_attribute_builder("builtin.LoadCacheModifierKind")
+def _loadcachemodifierkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MMAB1Op")
+def _mmab1op(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MMAFrag")
+def _mmafrag(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MMAIntOverflow")
+def _mmaintoverflow(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MMALayout")
+def _mmalayout(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MMATypes")
+def _mmatypes(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MemScopeKind")
+def _memscopekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.ProxyKind")
+def _proxykind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.ReduxKind")
+def _reduxkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.SetMaxRegisterAction")
+def _setmaxregisteraction(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.SharedSpace")
+def _sharedspace(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.ShflKind")
+def _shflkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.WGMMAScaleIn")
+def _wgmmascalein(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.WGMMAScaleOut")
+def _wgmmascaleout(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.WGMMATypes")
+def _wgmmatypes(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py
index 209ecc95fa8fc88..afaf076ca9b49f9 100644
--- a/mlir/python/mlir/dialects/sparse_tensor.py
+++ b/mlir/python/mlir/dialects/sparse_tensor.py
@@ -2,7 +2,18 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._sparse_tensor_ops_gen import *
from ._sparse_tensor_enum_gen import *
from .._mlir_libs._mlirDialectsSparseTensor import *
from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses
+
+
+ at register_attribute_builder("builtin.SparseTensorSortKindEnum")
+def _sparsetensorsortkindenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.SparseTensorStorageSpecifierKindEnum")
+def _sparsetensorstoragespecifierkindenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fdd9..5116b7c4309cbec 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -219,3 +219,13 @@ def __init__(
def any_op_t() -> AnyOpTypeT:
return AnyOpTypeT(AnyOpType.get())
+
+
+ at register_attribute_builder("builtin.FailurePropagationMode")
+def _failurepropagationmode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.MatchCmpIPredicateAttr")
+def _matchcmpipredicateattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a36f..3d7ffe1374890d7 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -43,6 +43,7 @@ def __init__(
self.parent = parent
self.children = children if children is not None else []
+
@ir.register_value_caster(AnyOpType.get_static_typeid())
@ir.register_value_caster(OperationType.get_static_typeid())
class OpHandle(Handle):
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 9121aa8e40237be..7f8cfedee5cb808 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -648,3 +648,13 @@ def __init__(
loc=loc,
ip=ip,
)
+
+
+ at register_attribute_builder("builtin.MatchInterfaceEnum")
+def _matchinterfaceenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.TransposeMatmulInput")
+def _transposematmulinput(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/vector.py b/mlir/python/mlir/dialects/transform/vector.py
index af2435cb26cc4b5..fd366e2970afba5 100644
--- a/mlir/python/mlir/dialects/transform/vector.py
+++ b/mlir/python/mlir/dialects/transform/vector.py
@@ -2,5 +2,26 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ...ir import IntegerAttr, IntegerType, register_attribute_builder
from .._vector_transform_enum_gen import *
from .._vector_transform_ops_gen import *
+
+
+ at register_attribute_builder("builtin.VectorContractLoweringAttr")
+def _vectorcontractloweringattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.VectorMultiReductionLoweringAttr")
+def _vectormultireductionloweringattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.VectorTransferSplitAttr")
+def _vectortransfersplitattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.VectorTransposeLoweringAttr")
+def _vectortransposeloweringattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py
index 7384e9a5aeef298..a0babf872918d04 100644
--- a/mlir/python/mlir/dialects/vector.py
+++ b/mlir/python/mlir/dialects/vector.py
@@ -2,5 +2,21 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._vector_ops_gen import *
from ._vector_enum_gen import *
+
+
+ at register_attribute_builder("builtin.CombiningKind")
+def _combiningkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.PrintPunctuation")
+def _printpunctuation(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+ at register_attribute_builder("builtin.Vector_IteratorType")
+def _vector_iteratortype(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 9a6ce462047ad2d..fa3b2178f2acfa8 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -17,127 +17,127 @@ def decorator_builder(func):
return decorator_builder
- at register_attribute_builder("AffineMapAttr")
+ at register_attribute_builder("builtin.AffineMapAttr")
def _affineMapAttr(x, context):
return AffineMapAttr.get(x)
- at register_attribute_builder("IntegerSetAttr")
+ at register_attribute_builder("builtin.IntegerSetAttr")
def _integerSetAttr(x, context):
return IntegerSetAttr.get(x)
- at register_attribute_builder("BoolAttr")
+ at register_attribute_builder("builtin.BoolAttr")
def _boolAttr(x, context):
return BoolAttr.get(x, context=context)
- at register_attribute_builder("DictionaryAttr")
+ at register_attribute_builder("builtin.DictionaryAttr")
def _dictAttr(x, context):
return DictAttr.get(x, context=context)
- at register_attribute_builder("IndexAttr")
+ at register_attribute_builder("builtin.IndexAttr")
def _indexAttr(x, context):
return IntegerAttr.get(IndexType.get(context=context), x)
- at register_attribute_builder("I1Attr")
+ at register_attribute_builder("builtin.I1Attr")
def _i1Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(1, context=context), x)
- at register_attribute_builder("I8Attr")
+ at register_attribute_builder("builtin.I8Attr")
def _i8Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(8, context=context), x)
- at register_attribute_builder("I16Attr")
+ at register_attribute_builder("builtin.I16Attr")
def _i16Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
- at register_attribute_builder("I32Attr")
+ at register_attribute_builder("builtin.I32Attr")
def _i32Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
- at register_attribute_builder("I64Attr")
+ at register_attribute_builder("builtin.I64Attr")
def _i64Attr(x, context):
return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
- at register_attribute_builder("SI1Attr")
+ at register_attribute_builder("builtin.SI1Attr")
def _si1Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(1, context=context), x)
- at register_attribute_builder("SI8Attr")
+ at register_attribute_builder("builtin.SI8Attr")
def _si8Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(8, context=context), x)
- at register_attribute_builder("SI16Attr")
+ at register_attribute_builder("builtin.SI16Attr")
def _si16Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
- at register_attribute_builder("SI32Attr")
+ at register_attribute_builder("builtin.SI32Attr")
def _si32Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
- at register_attribute_builder("SI64Attr")
+ at register_attribute_builder("builtin.SI64Attr")
def _si64Attr(x, context):
return IntegerAttr.get(IntegerType.get_signed(64, context=context), x)
- at register_attribute_builder("UI1Attr")
+ at register_attribute_builder("builtin.UI1Attr")
def _ui1Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x)
- at register_attribute_builder("UI8Attr")
+ at register_attribute_builder("builtin.UI8Attr")
def _ui8Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x)
- at register_attribute_builder("UI16Attr")
+ at register_attribute_builder("builtin.UI16Attr")
def _ui16Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x)
- at register_attribute_builder("UI32Attr")
+ at register_attribute_builder("builtin.UI32Attr")
def _ui32Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x)
- at register_attribute_builder("UI64Attr")
+ at register_attribute_builder("builtin.UI64Attr")
def _ui64Attr(x, context):
return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x)
- at register_attribute_builder("F32Attr")
+ at register_attribute_builder("builtin.F32Attr")
def _f32Attr(x, context):
return FloatAttr.get_f32(x, context=context)
- at register_attribute_builder("F64Attr")
+ at register_attribute_builder("builtin.F64Attr")
def _f64Attr(x, context):
return FloatAttr.get_f64(x, context=context)
- at register_attribute_builder("StrAttr")
+ at register_attribute_builder("builtin.StrAttr")
def _stringAttr(x, context):
return StringAttr.get(x, context=context)
- at register_attribute_builder("SymbolNameAttr")
+ at register_attribute_builder("builtin.SymbolNameAttr")
def _symbolNameAttr(x, context):
return StringAttr.get(x, context=context)
- at register_attribute_builder("SymbolRefAttr")
+ at register_attribute_builder("builtin.SymbolRefAttr")
def _symbolRefAttr(x, context):
if isinstance(x, list):
return SymbolRefAttr.get(x, context=context)
@@ -145,12 +145,12 @@ def _symbolRefAttr(x, context):
return FlatSymbolRefAttr.get(x, context=context)
- at register_attribute_builder("FlatSymbolRefAttr")
+ at register_attribute_builder("builtin.FlatSymbolRefAttr")
def _flatSymbolRefAttr(x, context):
return FlatSymbolRefAttr.get(x, context=context)
- at register_attribute_builder("UnitAttr")
+ at register_attribute_builder("builtin.UnitAttr")
def _unitAttr(x, context):
if x:
return UnitAttr.get(context=context)
@@ -158,117 +158,117 @@ def _unitAttr(x, context):
return None
- at register_attribute_builder("ArrayAttr")
+ at register_attribute_builder("builtin.ArrayAttr")
def _arrayAttr(x, context):
return ArrayAttr.get(x, context=context)
- at register_attribute_builder("AffineMapArrayAttr")
+ at register_attribute_builder("builtin.AffineMapArrayAttr")
def _affineMapArrayAttr(x, context):
return ArrayAttr.get([_affineMapAttr(v, context) for v in x])
- at register_attribute_builder("BoolArrayAttr")
+ at register_attribute_builder("builtin.BoolArrayAttr")
def _boolArrayAttr(x, context):
return ArrayAttr.get([_boolAttr(v, context) for v in x])
- at register_attribute_builder("DictArrayAttr")
+ at register_attribute_builder("builtin.DictArrayAttr")
def _dictArrayAttr(x, context):
return ArrayAttr.get([_dictAttr(v, context) for v in x])
- at register_attribute_builder("FlatSymbolRefArrayAttr")
+ at register_attribute_builder("builtin.FlatSymbolRefArrayAttr")
def _flatSymbolRefArrayAttr(x, context):
return ArrayAttr.get([_flatSymbolRefAttr(v, context) for v in x])
- at register_attribute_builder("I32ArrayAttr")
+ at register_attribute_builder("builtin.I32ArrayAttr")
def _i32ArrayAttr(x, context):
return ArrayAttr.get([_i32Attr(v, context) for v in x])
- at register_attribute_builder("I64ArrayAttr")
+ at register_attribute_builder("builtin.I64ArrayAttr")
def _i64ArrayAttr(x, context):
return ArrayAttr.get([_i64Attr(v, context) for v in x])
- at register_attribute_builder("I64SmallVectorArrayAttr")
+ at register_attribute_builder("builtin.I64SmallVectorArrayAttr")
def _i64SmallVectorArrayAttr(x, context):
return _i64ArrayAttr(x, context=context)
- at register_attribute_builder("IndexListArrayAttr")
+ at register_attribute_builder("builtin.IndexListArrayAttr")
def _indexListArrayAttr(x, context):
return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x])
- at register_attribute_builder("F32ArrayAttr")
+ at register_attribute_builder("builtin.F32ArrayAttr")
def _f32ArrayAttr(x, context):
return ArrayAttr.get([_f32Attr(v, context) for v in x])
- at register_attribute_builder("F64ArrayAttr")
+ at register_attribute_builder("builtin.F64ArrayAttr")
def _f64ArrayAttr(x, context):
return ArrayAttr.get([_f64Attr(v, context) for v in x])
- at register_attribute_builder("StrArrayAttr")
+ at register_attribute_builder("builtin.StrArrayAttr")
def _strArrayAttr(x, context):
return ArrayAttr.get([_stringAttr(v, context) for v in x])
- at register_attribute_builder("SymbolRefArrayAttr")
+ at register_attribute_builder("builtin.SymbolRefArrayAttr")
def _symbolRefArrayAttr(x, context):
return ArrayAttr.get([_symbolRefAttr(v, context) for v in x])
- at register_attribute_builder("DenseF32ArrayAttr")
+ at register_attribute_builder("builtin.DenseF32ArrayAttr")
def _denseF32ArrayAttr(x, context):
return DenseF32ArrayAttr.get(x, context=context)
- at register_attribute_builder("DenseF64ArrayAttr")
+ at register_attribute_builder("builtin.DenseF64ArrayAttr")
def _denseF64ArrayAttr(x, context):
return DenseF64ArrayAttr.get(x, context=context)
- at register_attribute_builder("DenseI8ArrayAttr")
+ at register_attribute_builder("builtin.DenseI8ArrayAttr")
def _denseI8ArrayAttr(x, context):
return DenseI8ArrayAttr.get(x, context=context)
- at register_attribute_builder("DenseI16ArrayAttr")
+ at register_attribute_builder("builtin.DenseI16ArrayAttr")
def _denseI16ArrayAttr(x, context):
return DenseI16ArrayAttr.get(x, context=context)
- at register_attribute_builder("DenseI32ArrayAttr")
+ at register_attribute_builder("builtin.DenseI32ArrayAttr")
def _denseI32ArrayAttr(x, context):
return DenseI32ArrayAttr.get(x, context=context)
- at register_attribute_builder("DenseI64ArrayAttr")
+ at register_attribute_builder("builtin.DenseI64ArrayAttr")
def _denseI64ArrayAttr(x, context):
return DenseI64ArrayAttr.get(x, context=context)
- at register_attribute_builder("DenseBoolArrayAttr")
+ at register_attribute_builder("builtin.DenseBoolArrayAttr")
def _denseBoolArrayAttr(x, context):
return DenseBoolArrayAttr.get(x, context=context)
- at register_attribute_builder("TypeAttr")
+ at register_attribute_builder("builtin.TypeAttr")
def _typeAttr(x, context):
return TypeAttr.get(x, context=context)
- at register_attribute_builder("TypeArrayAttr")
+ at register_attribute_builder("builtin.TypeArrayAttr")
def _typeArrayAttr(x, context):
return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
- at register_attribute_builder("MemRefTypeAttr")
+ at register_attribute_builder("builtin.MemRefTypeAttr")
def _memref_type_attr(x, context):
return _typeAttr(x, context)
@@ -276,7 +276,7 @@ def _memref_type_attr(x, context):
try:
import numpy as np
- @register_attribute_builder("F64ElementsAttr")
+ @register_attribute_builder("builtin.F64ElementsAttr")
def _f64ElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.float64),
@@ -284,7 +284,7 @@ def _f64ElementsAttr(x, context):
context=context,
)
- @register_attribute_builder("I32ElementsAttr")
+ @register_attribute_builder("builtin.I32ElementsAttr")
def _i32ElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int32),
@@ -292,7 +292,7 @@ def _i32ElementsAttr(x, context):
context=context,
)
- @register_attribute_builder("I64ElementsAttr")
+ @register_attribute_builder("builtin.I64ElementsAttr")
def _i64ElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64),
@@ -300,7 +300,7 @@ def _i64ElementsAttr(x, context):
context=context,
)
- @register_attribute_builder("IndexElementsAttr")
+ @register_attribute_builder("builtin.IndexElementsAttr")
def _indexElementsAttr(x, context):
return DenseElementsAttr.get(
np.array(x, dtype=np.int64),
diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index 1c5567f54a5f4b9..ec27dd45903ebb5 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -35,10 +35,6 @@ def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>
// CHECK: return "negone"
// CHECK: raise ValueError("Unknown MyEnum enum entry.")
-// CHECK: @register_attribute_builder("MyEnum")
-// CHECK: def _myenum(x, context):
-// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
-
def TestMyEnum_Attr : EnumAttr<Test_Dialect, MyEnum, "enum">;
def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
@@ -58,10 +54,6 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
// CHECK: return "two"
// CHECK: raise ValueError("Unknown MyEnum64 enum entry.")
-// CHECK: @register_attribute_builder("MyEnum64")
-// CHECK: def _myenum64(x, context):
-// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
-
def TestBitEnum
: I32BitEnumAttr<"TestBitEnum", "", [
I32BitEnumAttrCaseBit<"User", 0, "user">,
@@ -96,14 +88,10 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
// CHECK: return "other"
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")
-// CHECK: @register_attribute_builder("TestBitEnum")
-// CHECK: def _testbitenum(x, context):
-// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
-
-// CHECK: @register_attribute_builder("TestBitEnum_Attr")
+// CHECK: @register_attribute_builder("TestDialect.TestBitEnum_Attr")
// CHECK: def _testbitenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
-// CHECK: @register_attribute_builder("TestMyEnum_Attr")
+// CHECK: @register_attribute_builder("TestDialect.TestMyEnum_Attr")
// CHECK: def _testmyenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 632046389e12cff..d8a662038df16ab 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -123,8 +123,8 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: regions = None
// CHECK: attributes["i32attr"] = (i32attr if (
// CHECK-NEXT: isinstance(i32attr, _ods_ir.Attribute) or
- // CHECK-NEXT: not _ods_ir.AttrBuilder.contains('I32Attr')
- // CHECK-NEXT: _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
+ // CHECK-NEXT: not _ods_ir.AttrBuilder.contains('builtin.I32Attr')
+ // CHECK-NEXT: _ods_ir.AttrBuilder.get('builtin.I32Attr')(i32attr, context=_ods_context)
// CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
diff --git a/mlir/test/python/dialects/index_dialect.py b/mlir/test/python/dialects/index_dialect.py
index 9db883469792c5c..8da6a262cc441e0 100644
--- a/mlir/test/python/dialects/index_dialect.py
+++ b/mlir/test/python/dialects/index_dialect.py
@@ -94,7 +94,7 @@ def testCeilDivUOp(ctx):
def testCmpOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
- pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx)
+ pred = AttrBuilder.get("index.IndexCmpPredicateAttr")("slt", context=ctx)
r = index.CmpOp(pred, lhs=a, rhs=b)
# CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}})
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index fb4c75b5337928b..63d30b498c8b0cb 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -586,7 +586,7 @@ def testMatchInterfaceEnum(target):
@run
@create_sequence
def testMatchInterfaceEnumReplaceAttributeBuilder(target):
- @register_attribute_builder("MatchInterfaceEnum", replace=True)
+ @register_attribute_builder("builtin.MatchInterfaceEnum", replace=True)
def match_interface_enum(x, context):
if x == "LinalgOp":
y = 0
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 3f660ae151c7492..9cfe55399745aa6 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -16,6 +16,8 @@
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/GenInfo.h"
+
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
@@ -36,6 +38,8 @@ _ods_ir = _ods_cext.ir
)Py";
+extern llvm::cl::opt<std::string> clDialectName;
+
/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
static std::string makePythonEnumCaseName(StringRef name) {
if (isPythonReserved(name.str()))
@@ -106,7 +110,7 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true;
}
- os << formatv("@register_attribute_builder(\"{0}\")\n",
+ os << formatv("@register_attribute_builder(\"builtin.{0}\")\n",
enumAttr.getAttrDefName());
os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
os << formatv(" return "
@@ -119,10 +123,12 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
+static bool emitDialectEnumAttributeBuilder(StringRef dialect,
+ StringRef attrDefName,
StringRef formatString,
raw_ostream &os) {
- os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
+ os << formatv("@register_attribute_builder(\"{0}.{1}\")\n", dialect,
+ attrDefName);
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
os << formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
@@ -138,11 +144,15 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
EnumAttr enumAttr(*it);
emitEnumClass(enumAttr, os);
- emitAttributeBuilder(enumAttr, os);
+ if (clDialectName.empty())
+ emitAttributeBuilder(enumAttr, os);
}
for (const Record *it :
records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it);
+ StringRef dialect = attr.getDialect().getName();
+ if (!clDialectName.empty() && dialect != clDialectName)
+ continue;
if (!attr.getMnemonic()) {
llvm::errs() << "enum case " << attr
<< " needs mnemonic for python enum bindings generation";
@@ -150,14 +160,13 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
}
StringRef mnemonic = attr.getMnemonic().value();
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
- StringRef dialect = attr.getDialect().getName();
if (assemblyFormat == "`<` $value `>`") {
emitDialectEnumAttributeBuilder(
- attr.getName(),
+ dialect, attr.getName(),
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
} else if (assemblyFormat == "$value") {
emitDialectEnumAttributeBuilder(
- attr.getName(),
+ dialect, attr.getName(),
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
} else {
llvm::errs()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 5019b69d91127e8..003749b5d32be96 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -283,7 +283,7 @@ def {0}({2}) -> {4}:
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
-static llvm::cl::opt<std::string>
+llvm::cl::opt<std::string>
clDialectName("bind-dialect",
llvm::cl::desc("The dialect to run the generator for"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
@@ -672,12 +672,15 @@ populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
continue;
}
-
+ Dialect maybeDialect = attribute->attr.getBaseAttr().getDialect();
+ auto disambigAttrName = formatv(
+ "{0}.{1}", bool(maybeDialect) ? maybeDialect.getName() : "builtin",
+ attribute->attr.getAttrDefName());
builderLines.push_back(formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
- argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+ argNames[i], attribute->name, disambigAttrName));
}
}
More information about the Mlir-commits
mailing list