[Mlir-commits] [mlir] [python] fix enum collision (PR #117918)
    Maksim Levental 
    llvmlistbot at llvm.org
       
    Wed Nov 27 15:28:47 PST 2024
    
    
  
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/117918
>From 90c0b70dd2b76247f87d0a4a54c37ce472a92acc Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Wed, 27 Nov 2024 17:50:37 -0500
Subject: [PATCH] [python] fix enum ambiguity
---
 mlir/cmake/modules/AddMLIRPython.cmake        |   4 +-
 mlir/python/CMakeLists.txt                    |   3 +-
 mlir/python/mlir/dialects/_ods_common.py      |   1 +
 mlir/python/mlir/dialects/amdgpu.py           |  16 +++
 mlir/python/mlir/dialects/arith.py            |  35 ++++++
 mlir/python/mlir/dialects/bufferization.py    |   6 +
 mlir/python/mlir/dialects/gpu/__init__.py     |  56 +++++++++
 mlir/python/mlir/dialects/index.py            |   6 +
 mlir/python/mlir/dialects/linalg/__init__.py  |  25 ++++
 .../linalg/opdsl/ops/core_named_ops.py        |   1 +
 mlir/python/mlir/dialects/llvm.py             | 107 ++++++++++++++++-
 mlir/python/mlir/dialects/nvgpu.py            |  26 +++++
 mlir/python/mlir/dialects/nvvm.py             |  76 ++++++++++++
 mlir/python/mlir/dialects/sparse_tensor.py    |  16 +++
 .../mlir/dialects/transform/__init__.py       |  10 ++
 .../dialects/transform/extras/__init__.py     |   1 +
 .../mlir/dialects/transform/structured.py     |  10 ++
 mlir/python/mlir/dialects/transform/vector.py |  21 ++++
 mlir/python/mlir/dialects/vector.py           |  16 +++
 mlir/python/mlir/ir.py                        | 108 +++++++++---------
 .../test/mlir-tblgen/enums-python-bindings.td |  16 +--
 mlir/test/mlir-tblgen/op-python-bindings.td   |   4 +-
 mlir/test/python/dialects/index_dialect.py    |   2 +-
 .../dialects/transform_structured_ext.py      |   2 +-
 .../mlir-tblgen/EnumPythonBindingGen.cpp      |  23 ++--
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |   9 +-
 26 files changed, 513 insertions(+), 87 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 7b91f43e2d57fd..d06fc927ea44d4 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
         set(LLVM_TARGET_DEFINITIONS ${td_file})
       endif()
       set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
-      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
       list(APPEND _sources ${enum_filename})
     endif()
 
@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
         set(LLVM_TARGET_DEFINITIONS ${td_file})
       endif()
       set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
-      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
       list(APPEND _sources ${enum_filename})
     endif()
 
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 23187f256455bb..9949743b9bf09c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/AffineOps.td
   SOURCES
     dialects/affine.py
-  DIALECT_NAME affine
-  GEN_ENUM_BINDINGS)
+  DIALECT_NAME affine)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index d40d936cdc83d6..22fc588e180b1d 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -143,6 +143,7 @@ def get_op_result_or_op_results(
         else op
     )
 
+
 ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
 ResultValueT = _Union[ResultValueTypeTuple]
 VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 43d905d0c481cc..9b8beaa5571a26 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -2,5 +2,21 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._amdgpu_ops_gen import *
 from ._amdgpu_enum_gen import *
+
+
+@register_attribute_builder("builtin.AMDGPU_DPPPerm")
+def _amdgpu_dppperm(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AMDGPU_MFMAPermB")
+def _amdgpu_mfmapermb(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AMDGPU_SchedBarrierOpOpt")
+def _amdgpu_schedbarrieropopt(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce665..32ba832260d129 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -108,3 +108,38 @@ def constant(
     result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
 ) -> Value:
     return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
+
+
+@register_attribute_builder("builtin.Arith_CmpFPredicateAttr")
+def _arith_cmpfpredicateattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_CmpIPredicateAttr")
+def _arith_cmpipredicateattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_DenormalMode")
+def _arith_denormalmode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_IntegerOverflowFlags")
+def _arith_integeroverflowflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Arith_RoundingModeAttr")
+def _arith_roundingmodeattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicRMWKindAttr")
+def _atomicrmwkindattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FastMathFlags")
+def _fastmathflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 759b6aa24a9ff7..6ad76c729ed2dc 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -2,5 +2,11 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._bufferization_ops_gen import *
 from ._bufferization_enum_gen import *
+
+
+@register_attribute_builder("builtin.LayoutMapOption")
+def _layoutmapoption(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa8b7ca85..e0bb07c5dad8be 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -2,6 +2,62 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ...ir import IntegerAttr, IntegerType, register_attribute_builder
 from .._gpu_ops_gen import *
 from .._gpu_enum_gen import *
 from ..._mlir_libs._mlirDialectsGPU import *
+
+
+@register_attribute_builder("builtin.GPU_AddressSpaceEnum")
+def _gpu_addressspaceenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_AllReduceOperation")
+def _gpu_allreduceoperation(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_CompilationTargetEnum")
+def _gpu_compilationtargetenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_Dimension")
+def _gpu_dimension(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_Prune2To4SpMatFlag")
+def _gpu_prune2to4spmatflag(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_ShuffleMode")
+def _gpu_shufflemode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_SpGEMMWorkEstimationOrComputeKind")
+def _gpu_spgemmworkestimationorcomputekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.GPU_TransposeMode")
+def _gpu_transposemode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAElementWise")
+def _mmaelementwise(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MappingIdEnum")
+def _mappingidenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ProcessorEnum")
+def _processorenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/index.py b/mlir/python/mlir/dialects/index.py
index 73708c7d71a8c8..f00c397965c97c 100644
--- a/mlir/python/mlir/dialects/index.py
+++ b/mlir/python/mlir/dialects/index.py
@@ -2,5 +2,11 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._index_ops_gen import *
 from ._index_enum_gen import *
+
+
+@register_attribute_builder("builtin.IndexCmpPredicate")
+def _indexcmppredicate(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..4fe9cc40ee910a 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -102,3 +102,28 @@ def broadcast(
     )
     fill_builtin_region(op.operation)
     return op
+
+
+@register_attribute_builder("builtin.BinaryFn")
+def _binaryfn(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.IteratorType")
+def _iteratortype(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TernaryFn")
+def _ternaryfn(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TypeFn")
+def _typefn(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.UnaryFn")
+def _unaryfn(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c95cd5eecfffca..f87b25e8416023 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -888,6 +888,7 @@ def conv_2d_nchw_fchw_q(
         - TypeFn.cast_signed(U, IZp)
     ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
 
+
 @linalg_structured_op
 def conv_2d_nchw_fchw(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 941a584966dcde..456df39ae650ff 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -5,7 +5,7 @@
 from ._llvm_ops_gen import *
 from ._llvm_enum_gen import *
 from .._mlir_libs._mlirDialectsLLVM import *
-from ..ir import Value
+from ..ir import Value, IntegerAttr, IntegerType, register_attribute_builder
 from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
 
 
@@ -13,3 +13,108 @@ def mlir_constant(value, *, loc=None, ip=None) -> Value:
     return _get_op_result_or_op_results(
         ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
     )
+
+
+@register_attribute_builder("builtin.AsmATTOrIntel")
+def _asmattorintel(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicBinOp")
+def _atomicbinop(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.AtomicOrdering")
+def _atomicordering(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.CConvEnum")
+def _cconvenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Comdat")
+def _comdat(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.DIFlags")
+def _diflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.DISubprogramFlags")
+def _disubprogramflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FCmpPredicate")
+def _fcmppredicate(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FPExceptionBehaviorAttr")
+def _fpexceptionbehaviorattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FastmathFlags")
+def _fastmathflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.FramePointerKindEnum")
+def _framepointerkindenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ICmpPredicate")
+def _icmppredicate(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.IntegerOverflowFlags")
+def _integeroverflowflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LLVM_DIEmissionKind")
+def _llvm_diemissionkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LLVM_DINameTableKind")
+def _llvm_dinametablekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.LinkageEnum")
+def _linkageenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ModRefInfoEnum")
+def _modrefinfoenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.RoundingModeAttr")
+def _roundingmodeattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TailCallKindEnum")
+def _tailcallkindenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.UnnamedAddr")
+def _unnamedaddr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Visibility")
+def _visibility(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py
index d6a54f2772f40d..eea132adb0484e 100644
--- a/mlir/python/mlir/dialects/nvgpu.py
+++ b/mlir/python/mlir/dialects/nvgpu.py
@@ -2,6 +2,32 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._nvgpu_ops_gen import *
 from ._nvgpu_enum_gen import *
 from .._mlir_libs._mlirDialectsNVGPU import *
+
+
+@register_attribute_builder("builtin.RcpRoundingMode")
+def _rcproundingmode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapInterleaveKind")
+def _tensormapinterleavekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapL2PromoKind")
+def _tensormapl2promokind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapOOBKind")
+def _tensormapoobkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TensorMapSwizzleKind")
+def _tensormapswizzlekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py
index 9477de39c9ead7..21bf24cb73fdab 100644
--- a/mlir/python/mlir/dialects/nvvm.py
+++ b/mlir/python/mlir/dialects/nvvm.py
@@ -2,5 +2,81 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._nvvm_ops_gen import *
 from ._nvvm_enum_gen import *
+
+
+@register_attribute_builder("builtin.LoadCacheModifierKind")
+def _loadcachemodifierkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAB1Op")
+def _mmab1op(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAFrag")
+def _mmafrag(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMAIntOverflow")
+def _mmaintoverflow(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMALayout")
+def _mmalayout(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MMATypes")
+def _mmatypes(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MemScopeKind")
+def _memscopekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ProxyKind")
+def _proxykind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ReduxKind")
+def _reduxkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.SetMaxRegisterAction")
+def _setmaxregisteraction(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.SharedSpace")
+def _sharedspace(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.ShflKind")
+def _shflkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMAScaleIn")
+def _wgmmascalein(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMAScaleOut")
+def _wgmmascaleout(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.WGMMATypes")
+def _wgmmatypes(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py
index 209ecc95fa8fc8..8f1b83f9d514fd 100644
--- a/mlir/python/mlir/dialects/sparse_tensor.py
+++ b/mlir/python/mlir/dialects/sparse_tensor.py
@@ -2,7 +2,23 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._sparse_tensor_ops_gen import *
 from ._sparse_tensor_enum_gen import *
 from .._mlir_libs._mlirDialectsSparseTensor import *
 from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses
+
+
+@register_attribute_builder("builtin.SparseTensorCrdTransDirectionEnum")
+def _sparsetensorcrdtransdirectionenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.SparseTensorSortKindEnum")
+def _sparsetensorsortkindenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.SparseTensorStorageSpecifierKindEnum")
+def _sparsetensorstoragespecifierkindenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fdd..5116b7c4309cbe 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -219,3 +219,13 @@ def __init__(
 
 def any_op_t() -> AnyOpTypeT:
     return AnyOpTypeT(AnyOpType.get())
+
+
+@register_attribute_builder("builtin.FailurePropagationMode")
+def _failurepropagationmode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.MatchCmpIPredicateAttr")
+def _matchcmpipredicateattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py
index 8d045cad7a4a36..3d7ffe1374890d 100644
--- a/mlir/python/mlir/dialects/transform/extras/__init__.py
+++ b/mlir/python/mlir/dialects/transform/extras/__init__.py
@@ -43,6 +43,7 @@ def __init__(
         self.parent = parent
         self.children = children if children is not None else []
 
+
 @ir.register_value_caster(AnyOpType.get_static_typeid())
 @ir.register_value_caster(OperationType.get_static_typeid())
 class OpHandle(Handle):
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 9121aa8e40237b..7f8cfedee5cb80 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -648,3 +648,13 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@register_attribute_builder("builtin.MatchInterfaceEnum")
+def _matchinterfaceenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.TransposeMatmulInput")
+def _transposematmulinput(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/vector.py b/mlir/python/mlir/dialects/transform/vector.py
index af2435cb26cc4b..fd366e2970afba 100644
--- a/mlir/python/mlir/dialects/transform/vector.py
+++ b/mlir/python/mlir/dialects/transform/vector.py
@@ -2,5 +2,26 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ...ir import IntegerAttr, IntegerType, register_attribute_builder
 from .._vector_transform_enum_gen import *
 from .._vector_transform_ops_gen import *
+
+
+@register_attribute_builder("builtin.VectorContractLoweringAttr")
+def _vectorcontractloweringattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.VectorMultiReductionLoweringAttr")
+def _vectormultireductionloweringattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.VectorTransferSplitAttr")
+def _vectortransfersplitattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.VectorTransposeLoweringAttr")
+def _vectortransposeloweringattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py
index 7384e9a5aeef29..a0babf872918d0 100644
--- a/mlir/python/mlir/dialects/vector.py
+++ b/mlir/python/mlir/dialects/vector.py
@@ -2,5 +2,21 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
 from ._vector_ops_gen import *
 from ._vector_enum_gen import *
+
+
+@register_attribute_builder("builtin.CombiningKind")
+def _combiningkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.PrintPunctuation")
+def _printpunctuation(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+@register_attribute_builder("builtin.Vector_IteratorType")
+def _vector_iteratortype(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 9a6ce462047ad2..fa3b2178f2acfa 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -17,127 +17,127 @@ def decorator_builder(func):
     return decorator_builder
 
 
-@register_attribute_builder("AffineMapAttr")
+@register_attribute_builder("builtin.AffineMapAttr")
 def _affineMapAttr(x, context):
     return AffineMapAttr.get(x)
 
 
-@register_attribute_builder("IntegerSetAttr")
+@register_attribute_builder("builtin.IntegerSetAttr")
 def _integerSetAttr(x, context):
     return IntegerSetAttr.get(x)
 
 
-@register_attribute_builder("BoolAttr")
+@register_attribute_builder("builtin.BoolAttr")
 def _boolAttr(x, context):
     return BoolAttr.get(x, context=context)
 
 
-@register_attribute_builder("DictionaryAttr")
+@register_attribute_builder("builtin.DictionaryAttr")
 def _dictAttr(x, context):
     return DictAttr.get(x, context=context)
 
 
-@register_attribute_builder("IndexAttr")
+@register_attribute_builder("builtin.IndexAttr")
 def _indexAttr(x, context):
     return IntegerAttr.get(IndexType.get(context=context), x)
 
 
-@register_attribute_builder("I1Attr")
+@register_attribute_builder("builtin.I1Attr")
 def _i1Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signless(1, context=context), x)
 
 
-@register_attribute_builder("I8Attr")
+@register_attribute_builder("builtin.I8Attr")
 def _i8Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signless(8, context=context), x)
 
 
-@register_attribute_builder("I16Attr")
+@register_attribute_builder("builtin.I16Attr")
 def _i16Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
 
 
-@register_attribute_builder("I32Attr")
+@register_attribute_builder("builtin.I32Attr")
 def _i32Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signless(32, context=context), x)
 
 
-@register_attribute_builder("I64Attr")
+@register_attribute_builder("builtin.I64Attr")
 def _i64Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
 
 
-@register_attribute_builder("SI1Attr")
+@register_attribute_builder("builtin.SI1Attr")
 def _si1Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signed(1, context=context), x)
 
 
-@register_attribute_builder("SI8Attr")
+@register_attribute_builder("builtin.SI8Attr")
 def _si8Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signed(8, context=context), x)
 
 
-@register_attribute_builder("SI16Attr")
+@register_attribute_builder("builtin.SI16Attr")
 def _si16Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
 
 
-@register_attribute_builder("SI32Attr")
+@register_attribute_builder("builtin.SI32Attr")
 def _si32Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
 
 
-@register_attribute_builder("SI64Attr")
+@register_attribute_builder("builtin.SI64Attr")
 def _si64Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signed(64, context=context), x)
 
 
-@register_attribute_builder("UI1Attr")
+@register_attribute_builder("builtin.UI1Attr")
 def _ui1Attr(x, context):
     return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x)
 
 
-@register_attribute_builder("UI8Attr")
+@register_attribute_builder("builtin.UI8Attr")
 def _ui8Attr(x, context):
     return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x)
 
 
-@register_attribute_builder("UI16Attr")
+@register_attribute_builder("builtin.UI16Attr")
 def _ui16Attr(x, context):
     return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x)
 
 
-@register_attribute_builder("UI32Attr")
+@register_attribute_builder("builtin.UI32Attr")
 def _ui32Attr(x, context):
     return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x)
 
 
-@register_attribute_builder("UI64Attr")
+@register_attribute_builder("builtin.UI64Attr")
 def _ui64Attr(x, context):
     return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x)
 
 
-@register_attribute_builder("F32Attr")
+@register_attribute_builder("builtin.F32Attr")
 def _f32Attr(x, context):
     return FloatAttr.get_f32(x, context=context)
 
 
-@register_attribute_builder("F64Attr")
+@register_attribute_builder("builtin.F64Attr")
 def _f64Attr(x, context):
     return FloatAttr.get_f64(x, context=context)
 
 
-@register_attribute_builder("StrAttr")
+@register_attribute_builder("builtin.StrAttr")
 def _stringAttr(x, context):
     return StringAttr.get(x, context=context)
 
 
-@register_attribute_builder("SymbolNameAttr")
+@register_attribute_builder("builtin.SymbolNameAttr")
 def _symbolNameAttr(x, context):
     return StringAttr.get(x, context=context)
 
 
-@register_attribute_builder("SymbolRefAttr")
+@register_attribute_builder("builtin.SymbolRefAttr")
 def _symbolRefAttr(x, context):
     if isinstance(x, list):
         return SymbolRefAttr.get(x, context=context)
@@ -145,12 +145,12 @@ def _symbolRefAttr(x, context):
         return FlatSymbolRefAttr.get(x, context=context)
 
 
-@register_attribute_builder("FlatSymbolRefAttr")
+@register_attribute_builder("builtin.FlatSymbolRefAttr")
 def _flatSymbolRefAttr(x, context):
     return FlatSymbolRefAttr.get(x, context=context)
 
 
-@register_attribute_builder("UnitAttr")
+@register_attribute_builder("builtin.UnitAttr")
 def _unitAttr(x, context):
     if x:
         return UnitAttr.get(context=context)
@@ -158,117 +158,117 @@ def _unitAttr(x, context):
         return None
 
 
-@register_attribute_builder("ArrayAttr")
+@register_attribute_builder("builtin.ArrayAttr")
 def _arrayAttr(x, context):
     return ArrayAttr.get(x, context=context)
 
 
-@register_attribute_builder("AffineMapArrayAttr")
+@register_attribute_builder("builtin.AffineMapArrayAttr")
 def _affineMapArrayAttr(x, context):
     return ArrayAttr.get([_affineMapAttr(v, context) for v in x])
 
 
-@register_attribute_builder("BoolArrayAttr")
+@register_attribute_builder("builtin.BoolArrayAttr")
 def _boolArrayAttr(x, context):
     return ArrayAttr.get([_boolAttr(v, context) for v in x])
 
 
-@register_attribute_builder("DictArrayAttr")
+@register_attribute_builder("builtin.DictArrayAttr")
 def _dictArrayAttr(x, context):
     return ArrayAttr.get([_dictAttr(v, context) for v in x])
 
 
-@register_attribute_builder("FlatSymbolRefArrayAttr")
+@register_attribute_builder("builtin.FlatSymbolRefArrayAttr")
 def _flatSymbolRefArrayAttr(x, context):
     return ArrayAttr.get([_flatSymbolRefAttr(v, context) for v in x])
 
 
-@register_attribute_builder("I32ArrayAttr")
+@register_attribute_builder("builtin.I32ArrayAttr")
 def _i32ArrayAttr(x, context):
     return ArrayAttr.get([_i32Attr(v, context) for v in x])
 
 
-@register_attribute_builder("I64ArrayAttr")
+@register_attribute_builder("builtin.I64ArrayAttr")
 def _i64ArrayAttr(x, context):
     return ArrayAttr.get([_i64Attr(v, context) for v in x])
 
 
-@register_attribute_builder("I64SmallVectorArrayAttr")
+@register_attribute_builder("builtin.I64SmallVectorArrayAttr")
 def _i64SmallVectorArrayAttr(x, context):
     return _i64ArrayAttr(x, context=context)
 
 
-@register_attribute_builder("IndexListArrayAttr")
+@register_attribute_builder("builtin.IndexListArrayAttr")
 def _indexListArrayAttr(x, context):
     return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x])
 
 
-@register_attribute_builder("F32ArrayAttr")
+@register_attribute_builder("builtin.F32ArrayAttr")
 def _f32ArrayAttr(x, context):
     return ArrayAttr.get([_f32Attr(v, context) for v in x])
 
 
-@register_attribute_builder("F64ArrayAttr")
+@register_attribute_builder("builtin.F64ArrayAttr")
 def _f64ArrayAttr(x, context):
     return ArrayAttr.get([_f64Attr(v, context) for v in x])
 
 
-@register_attribute_builder("StrArrayAttr")
+@register_attribute_builder("builtin.StrArrayAttr")
 def _strArrayAttr(x, context):
     return ArrayAttr.get([_stringAttr(v, context) for v in x])
 
 
-@register_attribute_builder("SymbolRefArrayAttr")
+@register_attribute_builder("builtin.SymbolRefArrayAttr")
 def _symbolRefArrayAttr(x, context):
     return ArrayAttr.get([_symbolRefAttr(v, context) for v in x])
 
 
-@register_attribute_builder("DenseF32ArrayAttr")
+@register_attribute_builder("builtin.DenseF32ArrayAttr")
 def _denseF32ArrayAttr(x, context):
     return DenseF32ArrayAttr.get(x, context=context)
 
 
-@register_attribute_builder("DenseF64ArrayAttr")
+@register_attribute_builder("builtin.DenseF64ArrayAttr")
 def _denseF64ArrayAttr(x, context):
     return DenseF64ArrayAttr.get(x, context=context)
 
 
-@register_attribute_builder("DenseI8ArrayAttr")
+@register_attribute_builder("builtin.DenseI8ArrayAttr")
 def _denseI8ArrayAttr(x, context):
     return DenseI8ArrayAttr.get(x, context=context)
 
 
-@register_attribute_builder("DenseI16ArrayAttr")
+@register_attribute_builder("builtin.DenseI16ArrayAttr")
 def _denseI16ArrayAttr(x, context):
     return DenseI16ArrayAttr.get(x, context=context)
 
 
-@register_attribute_builder("DenseI32ArrayAttr")
+@register_attribute_builder("builtin.DenseI32ArrayAttr")
 def _denseI32ArrayAttr(x, context):
     return DenseI32ArrayAttr.get(x, context=context)
 
 
-@register_attribute_builder("DenseI64ArrayAttr")
+@register_attribute_builder("builtin.DenseI64ArrayAttr")
 def _denseI64ArrayAttr(x, context):
     return DenseI64ArrayAttr.get(x, context=context)
 
 
-@register_attribute_builder("DenseBoolArrayAttr")
+@register_attribute_builder("builtin.DenseBoolArrayAttr")
 def _denseBoolArrayAttr(x, context):
     return DenseBoolArrayAttr.get(x, context=context)
 
 
-@register_attribute_builder("TypeAttr")
+@register_attribute_builder("builtin.TypeAttr")
 def _typeAttr(x, context):
     return TypeAttr.get(x, context=context)
 
 
-@register_attribute_builder("TypeArrayAttr")
+@register_attribute_builder("builtin.TypeArrayAttr")
 def _typeArrayAttr(x, context):
     return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context)
 
 
-@register_attribute_builder("MemRefTypeAttr")
+@register_attribute_builder("builtin.MemRefTypeAttr")
 def _memref_type_attr(x, context):
     return _typeAttr(x, context)
 
@@ -276,7 +276,7 @@ def _memref_type_attr(x, context):
 try:
     import numpy as np
 
-    @register_attribute_builder("F64ElementsAttr")
+    @register_attribute_builder("builtin.F64ElementsAttr")
     def _f64ElementsAttr(x, context):
         return DenseElementsAttr.get(
             np.array(x, dtype=np.float64),
@@ -284,7 +284,7 @@ def _f64ElementsAttr(x, context):
             context=context,
         )
 
-    @register_attribute_builder("I32ElementsAttr")
+    @register_attribute_builder("builtin.I32ElementsAttr")
     def _i32ElementsAttr(x, context):
         return DenseElementsAttr.get(
             np.array(x, dtype=np.int32),
@@ -292,7 +292,7 @@ def _i32ElementsAttr(x, context):
             context=context,
         )
 
-    @register_attribute_builder("I64ElementsAttr")
+    @register_attribute_builder("builtin.I64ElementsAttr")
     def _i64ElementsAttr(x, context):
         return DenseElementsAttr.get(
             np.array(x, dtype=np.int64),
@@ -300,7 +300,7 @@ def _i64ElementsAttr(x, context):
             context=context,
         )
 
-    @register_attribute_builder("IndexElementsAttr")
+    @register_attribute_builder("builtin.IndexElementsAttr")
     def _indexElementsAttr(x, context):
         return DenseElementsAttr.get(
             np.array(x, dtype=np.int64),
diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index 1c5567f54a5f4b..ec27dd45903ebb 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -35,10 +35,6 @@ def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>
 // CHECK:             return "negone"
 // CHECK:         raise ValueError("Unknown MyEnum enum entry.")
 
-// CHECK: @register_attribute_builder("MyEnum")
-// CHECK: def _myenum(x, context):
-// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
-
 def TestMyEnum_Attr : EnumAttr<Test_Dialect, MyEnum, "enum">;
 
 def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
@@ -58,10 +54,6 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
 // CHECK:             return "two"
 // CHECK:         raise ValueError("Unknown MyEnum64 enum entry.")
 
-// CHECK: @register_attribute_builder("MyEnum64")
-// CHECK: def _myenum64(x, context):
-// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
-
 def TestBitEnum
     : I32BitEnumAttr<"TestBitEnum", "", [
         I32BitEnumAttrCaseBit<"User", 0, "user">,
@@ -96,14 +88,10 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
 // CHECK:             return "other"
 // CHECK:         raise ValueError("Unknown TestBitEnum enum entry.")
 
-// CHECK: @register_attribute_builder("TestBitEnum")
-// CHECK: def _testbitenum(x, context):
-// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
-
-// CHECK: @register_attribute_builder("TestBitEnum_Attr")
+// CHECK: @register_attribute_builder("TestDialect.TestBitEnum_Attr")
 // CHECK: def _testbitenum_attr(x, context):
 // CHECK:     return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
 
-// CHECK: @register_attribute_builder("TestMyEnum_Attr")
+// CHECK: @register_attribute_builder("TestDialect.TestMyEnum_Attr")
 // CHECK: def _testmyenum_attr(x, context):
 // CHECK:     return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 632046389e12cf..d8a662038df16a 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -123,8 +123,8 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:   regions = None
   // CHECK:   attributes["i32attr"] = (i32attr if (
   // CHECK-NEXT:   isinstance(i32attr, _ods_ir.Attribute) or
-  // CHECK-NEXT:   not _ods_ir.AttrBuilder.contains('I32Attr')
-  // CHECK-NEXT:   _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
+  // CHECK-NEXT:   not _ods_ir.AttrBuilder.contains('builtin.I32Attr')
+  // CHECK-NEXT:   _ods_ir.AttrBuilder.get('builtin.I32Attr')(i32attr, context=_ods_context)
   // CHECK:   if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
   // CHECK:   if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
diff --git a/mlir/test/python/dialects/index_dialect.py b/mlir/test/python/dialects/index_dialect.py
index 9db883469792c5..8da6a262cc441e 100644
--- a/mlir/test/python/dialects/index_dialect.py
+++ b/mlir/test/python/dialects/index_dialect.py
@@ -94,7 +94,7 @@ def testCeilDivUOp(ctx):
 def testCmpOp(ctx):
     a = index.ConstantOp(value=42)
     b = index.ConstantOp(value=23)
-    pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx)
+    pred = AttrBuilder.get("index.IndexCmpPredicateAttr")("slt", context=ctx)
     r = index.CmpOp(pred, lhs=a, rhs=b)
     # CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}})
 
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index fb4c75b5337928..63d30b498c8b0c 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -586,7 +586,7 @@ def testMatchInterfaceEnum(target):
 @run
 @create_sequence
 def testMatchInterfaceEnumReplaceAttributeBuilder(target):
-    @register_attribute_builder("MatchInterfaceEnum", replace=True)
+    @register_attribute_builder("builtin.MatchInterfaceEnum", replace=True)
     def match_interface_enum(x, context):
         if x == "LinalgOp":
             y = 0
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 3f660ae151c749..9cfe55399745aa 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -16,6 +16,8 @@
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Dialect.h"
 #include "mlir/TableGen/GenInfo.h"
+
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
 
@@ -36,6 +38,8 @@ _ods_ir = _ods_cext.ir
 
 )Py";
 
+extern llvm::cl::opt<std::string> clDialectName;
+
 /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
 static std::string makePythonEnumCaseName(StringRef name) {
   if (isPythonReserved(name.str()))
@@ -106,7 +110,7 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
     return true;
   }
 
-  os << formatv("@register_attribute_builder(\"{0}\")\n",
+  os << formatv("@register_attribute_builder(\"builtin.{0}\")\n",
                 enumAttr.getAttrDefName());
   os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
   os << formatv("    return "
@@ -119,10 +123,12 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
 /// Emits an attribute builder for the given dialect enum attribute to support
 /// automatic conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
+static bool emitDialectEnumAttributeBuilder(StringRef dialect,
+                                            StringRef attrDefName,
                                             StringRef formatString,
                                             raw_ostream &os) {
-  os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
+  os << formatv("@register_attribute_builder(\"{0}.{1}\")\n", dialect,
+                attrDefName);
   os << formatv("def _{0}(x, context):\n", attrDefName.lower());
   os << formatv("    return "
                 "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
@@ -138,11 +144,15 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
        records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
     EnumAttr enumAttr(*it);
     emitEnumClass(enumAttr, os);
-    emitAttributeBuilder(enumAttr, os);
+    if (clDialectName.empty())
+      emitAttributeBuilder(enumAttr, os);
   }
   for (const Record *it :
        records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
     AttrOrTypeDef attr(&*it);
+    StringRef dialect = attr.getDialect().getName();
+    if (!clDialectName.empty() && dialect != clDialectName)
+      continue;
     if (!attr.getMnemonic()) {
       llvm::errs() << "enum case " << attr
                    << " needs mnemonic for python enum bindings generation";
@@ -150,14 +160,13 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
     }
     StringRef mnemonic = attr.getMnemonic().value();
     std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
-    StringRef dialect = attr.getDialect().getName();
     if (assemblyFormat == "`<` $value `>`") {
       emitDialectEnumAttributeBuilder(
-          attr.getName(),
+          dialect, attr.getName(),
           formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
     } else if (assemblyFormat == "$value") {
       emitDialectEnumAttributeBuilder(
-          attr.getName(),
+          dialect, attr.getName(),
           formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
     } else {
       llvm::errs()
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 5019b69d91127e..003749b5d32be9 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -283,7 +283,7 @@ def {0}({2}) -> {4}:
 static llvm::cl::OptionCategory
     clOpPythonBindingCat("Options for -gen-python-op-bindings");
 
-static llvm::cl::opt<std::string>
+llvm::cl::opt<std::string>
     clDialectName("bind-dialect",
                   llvm::cl::desc("The dialect to run the generator for"),
                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
@@ -672,12 +672,15 @@ populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
           formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
       continue;
     }
-
+    Dialect maybeDialect = attribute->attr.getBaseAttr().getDialect();
+    auto disambigAttrName = formatv(
+        "{0}.{1}", bool(maybeDialect) ? maybeDialect.getName() : "builtin",
+        attribute->attr.getAttrDefName());
     builderLines.push_back(formatv(
         attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
             ? initOptionalAttributeWithBuilderTemplate
             : initAttributeWithBuilderTemplate,
-        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
+        argNames[i], attribute->name, disambigAttrName));
   }
 }
 
    
    
More information about the Mlir-commits
mailing list