[Mlir-commits] [mlir] 9223306 - [mlir][python bindings] generate all the enums

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 23 13:05:42 PDT 2023


Author: max
Date: 2023-08-23T15:03:55-05:00
New Revision: 92233062c17590d3157bdc6db430fcdfc54312fe

URL: https://github.com/llvm/llvm-project/commit/92233062c17590d3157bdc6db430fcdfc54312fe
DIFF: https://github.com/llvm/llvm-project/commit/92233062c17590d3157bdc6db430fcdfc54312fe.diff

LOG: [mlir][python bindings] generate all the enums

This PR implements python enum bindings for *all* the enums: this includes `I*Attrs` (including positional/bit) and `Dialect/EnumAttr`.
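For the bit (positional) enums, the generated classes derive from `IntFlag` and render a combined value by joining the assembly strings of the set cases with the enum's separator. The sketch below shows that idea in simplified, stand-alone form; it is not the generated code. The class is an illustrative subset of arith's fast-math flags. Its case values and the `","` separator are assumptions, with the separator chosen to match the `fastmath<nnan,ninf>` output in the arith test.

```python
from enum import IntFlag


class FastMathFlags(IntFlag):
    """Illustrative subset of a bit enum; case values are assumed."""

    nnan = 2
    ninf = 4
    nsz = 8

    def __str__(self):
        # Join the names of every single-bit case that is set, in declaration
        # order; the "," separator is an assumption for this sketch.
        return ",".join(case.name for case in type(self) if case in self)


flags = FastMathFlags.nnan | FastMathFlags.ninf
print(int(flags), str(flags))  # prints: 6 nnan,ninf
```

The integer value is what an `IntegerAttr`-style builder would store. The string is what a parse-based builder would splice into the attribute's textual form.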

There are a few parts to this:

1. CMake: a small addition to `declare_mlir_dialect_python_bindings` and `declare_mlir_dialect_extension_python_bindings` to generate the enum bindings, a boolean arg `GEN_ENUM_BINDINGS` to make it opt-in (even though it works for basically all of the dialects), and an optional `GEN_ENUM_BINDINGS_TD_FILE` for handling corner cases.
2. EnumPythonBindingGen.cpp: there are two subtle aspects here that took some investigation:
    1. If an enum attribute is not a `Dialect/EnumAttr`, then the `EnumAttrInfo` record is canonical for both the cases of the enum **and the `AttrDefName`**. On the other hand, if an enum is a `Dialect/EnumAttr`, then the `EnumAttr` record has the correct `AttrDefName` ("load bearing", i.e., it populates `ods.ir.AttributeBuilder('<NAME>')`), but the cases live in its `enum` field, which is an instance of `EnumAttrInfo`. The solution is to generate one enum class for both `Dialect/EnumAttr` and "independent" `EnumAttrInfo`, but to make that class interoperable with two builder registrations that both do the right thing (see next sub-bullet).
    2. Because we don't have a good connection to the C++ `EnumAttr` (i.e., only the `enum class` getters are exposed, like `DimensionAttr::get(Dimension value)`), we have to resort to parsing, e.g., `Attribute.parse(f'#gpu<dim {x}>')`. This means that the set of supported `assemblyFormat`s (for the enum) is fixed when MLIR is compiled (currently two, the only two I saw). More could be done here, but it would require quite a bit more C API work to support generically (e.g., casting ints to enum cases and binding all the getters, or going generically through the `symbolize*` methods, like `symbolizeDimension(uint32_t)` or `symbolizeDimension(StringRef)`).

A few small changes:

1. In addition, since this patch registers default builders for attributes where people might've had their own builders already written, I added a `replace` param to `AttrBuilder.insert` (`False` by default).
2. `makePythonEnumCaseName` can't handle all the different ways in which people write their enum cases, e.g., `llvm.CConv.Intel_OCL_BI`, which gets turned into `INTEL_O_C_L_B_I` (because `llvm::convertToSnakeFromCamelCase` doesn't look for runs of caps). So I dropped it. On the other hand, regularization does need to be done, because some enums have `None` as a case (and others might have other python keywords).
3. I turned on `llvm` dialect generation here in order to test `nvvm.WGMMAScaleIn`, which is an enum with [[ https://github.com/llvm/llvm-project/blob/d7e26b56207cbd8995296c5bb7c11ce676b649da/mlir/include/mlir/IR/EnumAttr.td#L22-L25 | no explicit discriminator ]] for the `neg` case.
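The case-naming problem and the keyword fix can be illustrated in a few lines. `per_capital_snake_upper` is a simplified model of a per-capital camel-to-snake conversion, not LLVM's implementation. It does reproduce both the old `CASE_ONE64` spelling and the `INTEL_O_C_L_B_I` mangling. `python_case_name` keeps the ODS spelling and only escapes Python keywords. The trailing-underscore convention is an assumption for this sketch.

```python
import keyword
import re


def per_capital_snake_upper(name):
    # Put "_" before every capital that is not at the start of the name and
    # not already preceded by "_", then uppercase everything.
    return re.sub(r"(?<!^)(?<!_)(?=[A-Z])", "_", name).upper()


def python_case_name(name):
    # Keep the ODS case name as written; only escape Python keywords.
    return name + "_" if keyword.iskeyword(name) else name


print(per_capital_snake_upper("CaseOne64"))  # CASE_ONE64
print(per_capital_snake_upper("Intel_OCL_BI"))  # INTEL_O_C_L_B_I
print(python_case_name("Intel_OCL_BI"))  # Intel_OCL_BI
print(python_case_name("None"))  # None_
```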

Note: dialects that didn't get `GEN_ENUM_BINDINGS` (or `GEN_ENUM_BINDINGS_TD_FILE`) don't have any enums to generate.

Let me know if I should add more tests (the three trivial ones I added exercise both the supported `assemblyFormat`s and `replace=True`).
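The `replace` behavior exercised by those tests can be modeled in pure Python. The builder map refuses to overwrite an existing builder unless `replace` is true, and a missing kind is now reported in the `KeyError`. `AttrBuilderRegistry` below is a hypothetical stand-in for the C++-backed map, not the real binding.

```python
class AttrBuilderRegistry:
    """Hypothetical stand-in for the C++-backed attribute-builder map."""

    def __init__(self):
        self._builders = {}

    def insert(self, attribute_kind, attr_builder, replace=False):
        found = self._builders.get(attribute_kind)
        # Refuse to overwrite an existing builder unless replace=True.
        if found is not None and not replace:
            raise RuntimeError(
                f"Attribute builder for '{attribute_kind}' is already "
                f"registered with func: {found}"
            )
        self._builders[attribute_kind] = attr_builder

    def get(self, attribute_kind):
        if attribute_kind not in self._builders:
            # Name the missing kind in the error.
            raise KeyError(attribute_kind)
        return self._builders[attribute_kind]


registry = AttrBuilderRegistry()
registry.insert("MyEnum", lambda x, context: f"default({int(x)})")
try:
    registry.insert("MyEnum", lambda x, context: f"custom({int(x)})")
except RuntimeError as err:
    print("rejected:", err)
registry.insert("MyEnum", lambda x, context: f"custom({int(x)})", replace=True)
print(registry.get("MyEnum")(1, None))  # custom(1)
```

Defaulting `replace` to `False` keeps accidental double registration loud. An explicit opt-in lets user code supersede the newly generated default builders.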

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D157934

Added: 
    mlir/python/mlir/dialects/LLVMOps.td
    mlir/python/mlir/dialects/llvm.py
    mlir/test/python/dialects/llvm.py

Modified: 
    mlir/cmake/modules/AddMLIRPython.cmake
    mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
    mlir/lib/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.cpp
    mlir/python/CMakeLists.txt
    mlir/python/mlir/dialects/amdgpu.py
    mlir/python/mlir/dialects/arith.py
    mlir/python/mlir/dialects/bufferization.py
    mlir/python/mlir/dialects/gpu/__init__.py
    mlir/python/mlir/dialects/linalg/__init__.py
    mlir/python/mlir/dialects/nvgpu.py
    mlir/python/mlir/dialects/nvvm.py
    mlir/python/mlir/dialects/sparse_tensor.py
    mlir/python/mlir/dialects/transform/bufferization.py
    mlir/python/mlir/dialects/transform/structured.py
    mlir/python/mlir/dialects/vector.py
    mlir/python/mlir/ir.py
    mlir/test/mlir-tblgen/enums-python-bindings.td
    mlir/test/python/dialects/arith_dialect.py
    mlir/test/python/dialects/gpu.py
    mlir/test/python/dialects/nvvm.py
    mlir/test/python/dialects/transform.py
    mlir/test/python/dialects/transform_bufferization_ext.py
    mlir/test/python/dialects/transform_gpu_ext.py
    mlir/test/python/dialects/transform_loop_ext.py
    mlir/test/python/dialects/transform_memref_ext.py
    mlir/test/python/dialects/transform_structured_ext.py
    mlir/test/python/dialects/transform_tensor_ext.py
    mlir/test/python/dialects/transform_vector_ext.py
    mlir/test/python/dialects/vector.py
    mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
    mlir/tools/mlir-tblgen/OpGenHelpers.cpp
    mlir/tools/mlir-tblgen/OpGenHelpers.h
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 5349bb8302ab94..012380603a4c45 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -272,6 +272,11 @@ endfunction()
 #   SOURCES: Same as declare_mlir_python_sources().
 #   SOURCES_GLOB: Same as declare_mlir_python_sources().
 #   DEPENDS: Additional dependency targets.
+#   GEN_ENUM_BINDINGS: Generate enum bindings.
+#   GEN_ENUM_BINDINGS_TD_FILE: Optional Tablegen file to generate enums for (relative to ROOT_DIR).
+#     This file is where the *EnumAttrs are defined, not where the *Enums are defined.
+#     **WARNING**: This arg will shortly be removed when the just-below TODO is satisfied. Use at your
+#     risk.
 #
 # TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we
 #       use its path to determine where to place the generated python file. If
@@ -279,9 +284,9 @@ endfunction()
 #       need for the separate "wrapper" .td files
 function(declare_mlir_dialect_python_bindings)
   cmake_parse_arguments(ARG
-    ""
+    "GEN_ENUM_BINDINGS"
     "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME"
-    "SOURCES;SOURCES_GLOB;DEPENDS"
+    "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE"
     ${ARGN})
   # Sources.
   set(_dialect_target "${ARG_ADD_TO_PARENT}.${ARG_DIALECT_NAME}")
@@ -306,11 +311,22 @@ function(declare_mlir_dialect_python_bindings)
     )
     add_public_tablegen_target(${tblgen_target})
 
+    set(_sources ${dialect_filename})
+    if(ARG_GEN_ENUM_BINDINGS OR ARG_GEN_ENUM_BINDINGS_TD_FILE)
+      if(ARG_GEN_ENUM_BINDINGS_TD_FILE)
+        set(td_file "${ARG_ROOT_DIR}/${ARG_GEN_ENUM_BINDINGS_TD_FILE}")
+        set(LLVM_TARGET_DEFINITIONS ${td_file})
+      endif()
+      set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      list(APPEND _sources ${enum_filename})
+    endif()
+
     # Generated.
     declare_mlir_python_sources("${_dialect_target}.ops_gen"
       ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
       ADD_TO_PARENT "${_dialect_target}"
-      SOURCES "${dialect_filename}"
+      SOURCES ${_sources}
     )
   endif()
 endfunction()
@@ -331,11 +347,16 @@ endfunction()
 #   SOURCES: Same as declare_mlir_python_sources().
 #   SOURCES_GLOB: Same as declare_mlir_python_sources().
 #   DEPENDS: Additional dependency targets.
+#   GEN_ENUM_BINDINGS: Generate enum bindings.
+#   GEN_ENUM_BINDINGS_TD_FILE: Optional Tablegen file to generate enums for (relative to ROOT_DIR).
+#     This file is where the *Attrs are defined, not where the *Enums are defined.
+#     **WARNING**: This arg will shortly be removed when the TODO for
+#     declare_mlir_dialect_python_bindings is satisfied. Use at your risk.
 function(declare_mlir_dialect_extension_python_bindings)
   cmake_parse_arguments(ARG
-    ""
+    "GEN_ENUM_BINDINGS"
     "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
-    "SOURCES;SOURCES_GLOB;DEPENDS"
+    "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE"
     ${ARGN})
   # Source files.
   set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}")
@@ -362,10 +383,21 @@ function(declare_mlir_dialect_extension_python_bindings)
       add_dependencies(${tblgen_target} ${ARG_DEPENDS})
     endif()
 
+    set(_sources ${output_filename})
+    if(ARG_GEN_ENUM_BINDINGS OR ARG_GEN_ENUM_BINDINGS_TD_FILE)
+      if(ARG_GEN_ENUM_BINDINGS_TD_FILE)
+        set(td_file "${ARG_ROOT_DIR}/${ARG_GEN_ENUM_BINDINGS_TD_FILE}")
+        set(LLVM_TARGET_DEFINITIONS ${td_file})
+      endif()
+      set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
+      mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+      list(APPEND _sources ${enum_filename})
+    endif()
+
     declare_mlir_python_sources("${_extension_target}.ops_gen"
       ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
       ADD_TO_PARENT "${_extension_target}"
-      SOURCES "${output_filename}"
+      SOURCES ${_sources}
     )
   endif()
 endfunction()

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index a778f305b6c1b0..59f909aed8f61a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -1,4 +1,4 @@
-//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===//
+//===- LinalgEnums.td - Linalg dialect base support ---------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

diff  --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 0fc7614ccad52c..97cd70089a2e96 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -58,10 +58,11 @@ class PyGlobals {
   void loadDialectModule(llvm::StringRef dialectNamespace);
 
   /// Adds a user-friendly Attribute builder.
-  /// Raises an exception if the mapping already exists.
+  /// Raises an exception if the mapping already exists and replace == false.
   /// This is intended to be called by implementation code.
   void registerAttributeBuilder(const std::string &attributeKind,
-                                pybind11::function pyFunc);
+                                pybind11::function pyFunc,
+                                bool replace = false);
 
   /// Adds a user-friendly type caster. Raises an exception if the mapping
   /// already exists and replace == false. This is intended to be called by

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index e1b8d296a7d1e5..b06937bc285e20 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -242,19 +242,23 @@ struct PyAttrBuilderMap {
   static py::function dundeGetItemNamed(const std::string &attributeKind) {
     auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
     if (!builder)
-      throw py::key_error();
+      throw py::key_error(attributeKind);
     return *builder;
   }
   static void dundeSetItemNamed(const std::string &attributeKind,
-                                py::function func) {
-    PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
+                                py::function func, bool replace) {
+    PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
+                                              replace);
   }
 
   static void bind(py::module &m) {
     py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
         .def_static("contains", &PyAttrBuilderMap::dunderContains)
         .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
-        .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed);
+        .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed,
+                    "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
+                    "Register an attribute builder for building MLIR "
+                    "attributes from python values.");
   }
 };
 

diff  --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index d9a66bce0fecb1..2cc66277abee0f 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -63,11 +63,13 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
 }
 
 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
-                                         py::function pyFunc) {
+                                         py::function pyFunc, bool replace) {
   py::object &found = attributeBuilderMap[attributeKind];
-  if (found) {
+  if (found && !found.is_none() && !replace) {
     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
-                              attributeKind + "' is already registered")
+                              attributeKind +
+                              "' is already registered with func: " +
+                              py::str(found).operator std::string())
                                  .str());
   }
   found = std::move(pyFunc);

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 05d09eaf7b8b25..225da778cf3b3a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -52,7 +52,8 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/AMDGPUOps.td
   SOURCES
     dialects/amdgpu.py
-  DIALECT_NAME amdgpu)
+  DIALECT_NAME amdgpu
+  GEN_ENUM_BINDINGS)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -68,7 +69,10 @@ declare_mlir_dialect_python_bindings(
   SOURCES
     dialects/bufferization.py
     dialects/_bufferization_ops_ext.py
-  DIALECT_NAME bufferization)
+  DIALECT_NAME bufferization
+  GEN_ENUM_BINDINGS_TD_FILE
+    "../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
+)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -109,7 +113,8 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/GPUOps.td
   SOURCES_GLOB dialects/gpu/*.py
-  DIALECT_NAME gpu)
+  DIALECT_NAME gpu
+  GEN_ENUM_BINDINGS)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -120,7 +125,17 @@ declare_mlir_dialect_python_bindings(
   SOURCES_GLOB
     dialects/linalg/*.py
   DIALECT_NAME linalg
-  DEPENDS LinalgOdsGen)
+  DEPENDS LinalgOdsGen
+  GEN_ENUM_BINDINGS)
+
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/LLVMOps.td
+  SOURCES
+    dialects/llvm.py
+  DIALECT_NAME llvm
+  GEN_ENUM_BINDINGS)
 
 declare_mlir_dialect_extension_python_bindings(
 ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -140,16 +155,10 @@ declare_mlir_dialect_python_bindings(
     dialects/_transform_ops_ext.py
     dialects/transform/__init__.py
     _mlir_libs/_mlir/dialects/transform/__init__.pyi
-  DIALECT_NAME transform)
-
-set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td")
-mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings)
-add_public_tablegen_target(MLIRTransformDialectPyEnumGen)
-declare_mlir_python_sources(
-  MLIRPythonSources.Dialects.transform.enum_gen
-  ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
-  ADD_TO_PARENT MLIRPythonSources.Dialects.transform
-  SOURCES "dialects/_transform_enum_gen.py")
+  DIALECT_NAME transform
+  GEN_ENUM_BINDINGS_TD_FILE
+    "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
+)
 
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -161,15 +170,6 @@ declare_mlir_dialect_extension_python_bindings(
   DIALECT_NAME transform
   EXTENSION_NAME bufferization_transform)
 
-set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/BufferizationTransformOps.td")
-mlir_tablegen("dialects/_bufferization_transform_enum_gen.py" -gen-python-enum-bindings)
-add_public_tablegen_target(MLIRBufferizationTransformDialectPyEnumGen)
-declare_mlir_python_sources(
-  MLIRPythonSources.Dialects.bufferization_transform.enum_gen
-  ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
-  ADD_TO_PARENT MLIRPythonSources.Dialects.bufferization_transform
-  SOURCES "dialects/_bufferization_transform_enum_gen.py")
-
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -208,7 +208,10 @@ declare_mlir_dialect_extension_python_bindings(
     dialects/_structured_transform_ops_ext.py
     dialects/transform/structured.py
   DIALECT_NAME transform
-  EXTENSION_NAME structured_transform)
+  EXTENSION_NAME structured_transform
+  GEN_ENUM_BINDINGS_TD_FILE
+    "../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
+)
 
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -227,16 +230,10 @@ declare_mlir_dialect_extension_python_bindings(
   SOURCES
     dialects/transform/vector.py
   DIALECT_NAME transform
-  EXTENSION_NAME vector_transform)
-
-set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td")
-mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings)
-add_public_tablegen_target(MLIRVectorTransformPyEnumGen)
-declare_mlir_python_sources(
-  MLIRPythonSources.Dialects.vector_transform.enum_gen
-  ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
-  ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform
-  SOURCES "dialects/_vector_transform_enum_gen.py" )
+  EXTENSION_NAME vector_transform
+  GEN_ENUM_BINDINGS_TD_FILE
+    "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
+)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -252,7 +249,8 @@ declare_mlir_dialect_python_bindings(
   SOURCES
     dialects/arith.py
     dialects/_arith_ops_ext.py
-  DIALECT_NAME arith)
+  DIALECT_NAME arith
+  GEN_ENUM_BINDINGS)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -278,7 +276,8 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/NVGPUOps.td
   SOURCES
     dialects/nvgpu.py
-  DIALECT_NAME nvgpu)
+  DIALECT_NAME nvgpu
+  GEN_ENUM_BINDINGS)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -286,7 +285,8 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/NVVMOps.td
   SOURCES
     dialects/nvvm.py
-  DIALECT_NAME nvvm)
+  DIALECT_NAME nvvm
+  GEN_ENUM_BINDINGS)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -300,6 +300,7 @@ declare_mlir_python_sources(
   MLIRPythonSources.Dialects.quant
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  GEN_ENUM_BINDINGS
   SOURCES
     dialects/quant.py
     _mlir_libs/_mlir/dialects/quant.pyi)
@@ -335,7 +336,10 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/SparseTensorOps.td
   SOURCES dialects/sparse_tensor.py
-  DIALECT_NAME sparse_tensor)
+  DIALECT_NAME sparse_tensor
+  GEN_ENUM_BINDINGS_TD_FILE
+    "../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
+)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -351,14 +355,16 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/TosaOps.td
   SOURCES dialects/tosa.py
-  DIALECT_NAME tosa)
+  DIALECT_NAME tosa
+)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/VectorOps.td
   SOURCES dialects/vector.py
-  DIALECT_NAME vector)
+  DIALECT_NAME vector
+  GEN_ENUM_BINDINGS)
 
 ################################################################################
 # Python extensions.

diff  --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td
new file mode 100644
index 00000000000000..dcf2f4245cf49f
--- /dev/null
+++ b/mlir/python/mlir/dialects/LLVMOps.td
@@ -0,0 +1,14 @@
+//===-- LlvmOps.td - Entry point for llvm bind ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_LLVM_OPS
+#define PYTHON_BINDINGS_LLVM_OPS
+
+include "mlir/Dialect/LLVMIR/LLVMOps.td"
+
+#endif

diff  --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 35283278e8fb03..43d905d0c481cc 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -3,3 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._amdgpu_ops_gen import *
+from ._amdgpu_enum_gen import *

diff  --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 77318b2869fc69..fb13beb63ca66c 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -3,3 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._arith_ops_gen import *
+from ._arith_enum_gen import *

diff  --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 2121122f12764e..759b6aa24a9ff7 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -3,3 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._bufferization_ops_gen import *
+from ._bufferization_enum_gen import *

diff  --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 67bf7bd854e15e..033386b0f803b2 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -3,3 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .._gpu_ops_gen import *
+from .._gpu_enum_gen import *

diff  --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index eadb8420c06a9f..1353870ec7257a 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -9,6 +9,7 @@
 # definitions following these steps:
 #   DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
 from .._linalg_ops_gen import *
+from .._linalg_enum_gen import *
 
 # These are the ground truth functions defined as:
 # ```

diff  --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
new file mode 100644
index 00000000000000..77025438c37a4f
--- /dev/null
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -0,0 +1,6 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._llvm_ops_gen import *
+from ._llvm_enum_gen import *

diff  --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py
index afd570cae5300f..2f6993b768ca53 100644
--- a/mlir/python/mlir/dialects/nvgpu.py
+++ b/mlir/python/mlir/dialects/nvgpu.py
@@ -3,3 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._nvgpu_ops_gen import *
+from ._nvgpu_enum_gen import *

diff  --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py
index 87b2a4fd6bf853..9477de39c9ead7 100644
--- a/mlir/python/mlir/dialects/nvvm.py
+++ b/mlir/python/mlir/dialects/nvvm.py
@@ -3,3 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._nvvm_ops_gen import *
+from ._nvvm_enum_gen import *

diff  --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py
index 769418e049693d..209ecc95fa8fc8 100644
--- a/mlir/python/mlir/dialects/sparse_tensor.py
+++ b/mlir/python/mlir/dialects/sparse_tensor.py
@@ -3,5 +3,6 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._sparse_tensor_ops_gen import *
+from ._sparse_tensor_enum_gen import *
 from .._mlir_libs._mlirDialectsSparseTensor import *
 from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses

diff  --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py
index 1891bc0e1d6a1c..eb77b746cf864f 100644
--- a/mlir/python/mlir/dialects/transform/bufferization.py
+++ b/mlir/python/mlir/dialects/transform/bufferization.py
@@ -2,5 +2,4 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .._bufferization_transform_enum_gen import *
 from .._bufferization_transform_ops_gen import *

diff  --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index b8ee48c42945af..cb3812301dbd4b 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -3,3 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .._structured_transform_ops_gen import *
+from .._structured_transform_enum_gen import *

diff  --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py
index 610c0b204c6be9..7384e9a5aeef29 100644
--- a/mlir/python/mlir/dialects/vector.py
+++ b/mlir/python/mlir/dialects/vector.py
@@ -3,3 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._vector_ops_gen import *
+from ._vector_enum_gen import *

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index e36736f2974f0c..36c49fe6f1d6bd 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -8,9 +8,9 @@
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
-def register_attribute_builder(kind):
+def register_attribute_builder(kind, replace=False):
     def decorator_builder(func):
-        AttrBuilder.insert(kind, func)
+        AttrBuilder.insert(kind, func, replace=replace)
         return func
 
     return decorator_builder

diff  --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index 5272eba50f0e7a..1c5567f54a5f4b 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -2,56 +2,108 @@
 
 include "mlir/IR/EnumAttr.td"
 
+def Test_Dialect : Dialect {
+  let name = "TestDialect";
+  let cppNamespace = "::test";
+}
+
 // CHECK: Autogenerated by mlir-tblgen; don't manually edit.
 
-// CHECK: from enum import Enum
+// CHECK: from enum import IntEnum, auto, IntFlag
 // CHECK: from ._ods_common import _cext as _ods_cext
+// CHECK: from ..ir import register_attribute_builder
 // CHECK: _ods_ir = _ods_cext.ir
 
 def One : I32EnumAttrCase<"CaseOne", 1, "one">;
 def Two : I32EnumAttrCase<"CaseTwo", 2, "two">;
+def NegOne : I32EnumAttrCase<"CaseNegOne", -1, "negone">;
 
-def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>;
-// CHECK: def _register_attribute_builder(kind):
-// CHECK:     def decorator_builder(func):
-// CHECK:         _ods_ir.AttrBuilder.insert(kind, func)
-// CHECK:         return func
-// CHECK:     return decorator_builder
-
-// CHECK-LABEL: class MyEnum(Enum):
+def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>;
+// CHECK-LABEL: class MyEnum(IntEnum):
 // CHECK:     """An example 32-bit enum"""
 
-// CHECK:     CASE_ONE = 1
-// CHECK:     CASE_TWO = 2
+// CHECK:     CaseOne = 1
+// CHECK:     CaseTwo = 2
+// CHECK:     CaseNegOne = auto()
+
+// CHECK:     def __str__(self):
+// CHECK:         if self is MyEnum.CaseOne:
+// CHECK:             return "one"
+// CHECK:         if self is MyEnum.CaseTwo:
+// CHECK:             return "two"
+// CHECK:         if self is MyEnum.CaseNegOne:
+// CHECK:             return "negone"
+// CHECK:         raise ValueError("Unknown MyEnum enum entry.")
+
+// CHECK: @register_attribute_builder("MyEnum")
+// CHECK: def _myenum(x, context):
+// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
 
-// CHECK:     def _as_int(self):
-// CHECK:         if self is MyEnum.CASE_ONE:
-// CHECK:             return 1
-// CHECK:         if self is MyEnum.CASE_TWO:
-// CHECK:             return 2
-// CHECK:         assert False, "Unknown MyEnum enum entry."
+def TestMyEnum_Attr : EnumAttr<Test_Dialect, MyEnum, "enum">;
 
 def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
 def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">;
 
 def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>;
-// CHECK: @_register_attribute_builder("MyEnum")
-// CHECK: def _my_enum(x, context):
-// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int())
-
-// CHECK-LABEL: class MyEnum64(Enum):
+// CHECK-LABEL: class MyEnum64(IntEnum):
 // CHECK:     """An example 64-bit enum"""
 
-// CHECK:     CASE_ONE64 = 1
-// CHECK:     CASE_TWO64 = 2
+// CHECK:     CaseOne64 = 1
+// CHECK:     CaseTwo64 = 2
+
+// CHECK:     def __str__(self):
+// CHECK:         if self is MyEnum64.CaseOne64:
+// CHECK:             return "one"
+// CHECK:         if self is MyEnum64.CaseTwo64:
+// CHECK:             return "two"
+// CHECK:         raise ValueError("Unknown MyEnum64 enum entry.")
+
+// CHECK: @register_attribute_builder("MyEnum64")
+// CHECK: def _myenum64(x, context):
+// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
+
+def TestBitEnum
+    : I32BitEnumAttr<"TestBitEnum", "", [
+        I32BitEnumAttrCaseBit<"User", 0, "user">,
+        I32BitEnumAttrCaseBit<"Group", 1, "group">,
+        I32BitEnumAttrCaseBit<"Other", 2, "other">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let separator = " | ";
+}
+
+def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
+
+// CHECK-LABEL: class TestBitEnum(IntFlag):
+
+// CHECK:     User = 1
+// CHECK:     Group = 2
+// CHECK:     Other = 4
+
+// CHECK:     def __iter__(self):
+// CHECK:         return iter([case for case in type(self) if (self & case) is case])
+// CHECK:     def __len__(self):
+// CHECK:         return bin(self).count("1")
+
+// CHECK:     def __str__(self):
+// CHECK:         if len(self) > 1:
+// CHECK:             return " | ".join(map(str, self))
+// CHECK:         if self is TestBitEnum.User:
+// CHECK:             return "user"
+// CHECK:         if self is TestBitEnum.Group:
+// CHECK:             return "group"
+// CHECK:         if self is TestBitEnum.Other:
+// CHECK:             return "other"
+// CHECK:         raise ValueError("Unknown TestBitEnum enum entry.")
+
+// CHECK: @register_attribute_builder("TestBitEnum")
+// CHECK: def _testbitenum(x, context):
+// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
 
-// CHECK:     def _as_int(self):
-// CHECK:         if self is MyEnum64.CASE_ONE64:
-// CHECK:             return 1
-// CHECK:         if self is MyEnum64.CASE_TWO64:
-// CHECK:             return 2
-// CHECK:         assert False, "Unknown MyEnum64 enum entry."
+// CHECK: @register_attribute_builder("TestBitEnum_Attr")
+// CHECK: def _testbitenum_attr(x, context):
+// CHECK:     return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
 
-// CHECK: @_register_attribute_builder("MyEnum64")
-// CHECK: def _my_enum64(x, context):
-// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int())
+// CHECK: @register_attribute_builder("TestMyEnum_Attr")
+// CHECK: def _testmyenum_attr(x, context):
+// CHECK:     return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)

diff  --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 8e9613d0524663..f4a793aee4aa14 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -19,3 +19,17 @@ def testConstantOps():
             arith.ConstantOp(value=42.42, result=F32Type.get())
         # CHECK:         %cst = arith.constant 4.242000e+01 : f32
         print(module)
+
+
+# CHECK-LABEL: TEST: testFastMathFlags
+@run
+def testFastMathFlags():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            a = arith.ConstantOp(value=42.42, result=F32Type.get())
+            r = arith.AddFOp(
+                a, a, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
+            )
+            # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
+            print(r)

diff  --git a/mlir/test/python/dialects/gpu.py b/mlir/test/python/dialects/gpu.py
index 7eefaed711c2c2..0293e8f276be6b 100644
--- a/mlir/test/python/dialects/gpu.py
+++ b/mlir/test/python/dialects/gpu.py
@@ -1,22 +1,32 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-import mlir.dialects.gpu
+import mlir.dialects.gpu as gpu
 import mlir.dialects.gpu.passes
 from mlir.passmanager import *
 
 
 def run(f):
     print("\nTEST:", f.__name__)
-    f()
+    with Context(), Location.unknown():
+        f()
+    return f
 
 
+# CHECK-LABEL: testGPUPass
+#       CHECK: SUCCESS
+@run
 def testGPUPass():
-    with Context() as context:
-        PassManager.parse("any(gpu-kernel-outlining)")
+    PassManager.parse("any(gpu-kernel-outlining)")
     print("SUCCESS")
 
 
-# CHECK-LABEL: testGPUPass
-#       CHECK: SUCCESS
-run(testGPUPass)
+# CHECK-LABEL: testMMAElementWiseAttr
+@run
+def testMMAElementWiseAttr():
+    module = Module.create()
+    with InsertionPoint(module.body):
+        gpu.BlockDimOp(gpu.Dimension.y)
+    # CHECK: %0 = gpu.block_dim  y
+    print(module)
+    pass

diff  --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
new file mode 100644
index 00000000000000..2d207ae14eecd2
--- /dev/null
+++ b/mlir/test/python/dialects/llvm.py
@@ -0,0 +1,25 @@
+# RUN: %PYTHON %s | FileCheck %s
+# This is just a smoke test that the dialect is functional.
+
+from mlir.ir import *
+from mlir.dialects import llvm
+
+
+def constructAndPrintInModule(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
+
+
+# CHECK-LABEL: testSmoke
+@constructAndPrintInModule
+def testSmoke():
+    mat64f32_t = Type.parse(
+        "!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
+    )
+    result = llvm.UndefOp(mat64f32_t)
+    # CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>

diff  --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index 7d68a151345ca2..36aaaea79b1866 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -3,6 +3,8 @@
 
 from mlir.ir import *
 from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import func
 
 
 def constructAndPrintInModule(f):
@@ -18,5 +20,30 @@ def constructAndPrintInModule(f):
 # CHECK-LABEL: testSmoke
 @constructAndPrintInModule
 def testSmoke():
-    # CHECK: nvvm.cp.async.wait.group 5
-    nvvm.CpAsyncWaitGroupOp(5)
+    i64 = IntegerType.get_signless(64)
+    mat64f32_t = Type.parse(
+        "!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
+    )
+    shape_attr = Attribute.parse("#nvvm.shape<m = 64, n = 32, k = 16>")
+    # CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64)
+    @func.FuncOp.from_py_func(i64, i64)
+    def wgmma_f32_f16_f16(desc_a, desc_b):
+        # CHECK: nvvm.cp.async.wait.group 5
+        nvvm.CpAsyncWaitGroupOp(5)
+        # CHECK: %0 = llvm.mlir.undef : [[MAT_T:.*]]
+        result = llvm.UndefOp(mat64f32_t)
+        # CHECK: %1 = nvvm.wgmma.mma_async %arg0, %arg1, <m = 64, n = 32, k = 16>, D[%0, <zero>], A[<f16>, <neg>, <col>], B[<f16>, <neg>, <col>] : [[MAT_T]] -> [[MAT_T]]
+        result1 = nvvm.WgmmaMmaAsyncOp(
+            results_=mat64f32_t,
+            inouts=result,
+            descriptorA=desc_a,
+            descriptorB=desc_b,
+            shape=shape_attr,
+            typeA=nvvm.WGMMATypes.f16,
+            typeB=nvvm.WGMMATypes.f16,
+            scaleD=nvvm.WGMMAScaleOut.zero,
+            scaleA=nvvm.WGMMAScaleIn.neg,
+            scaleB=nvvm.WGMMAScaleIn.neg,
+            layoutA=nvvm.MMALayout.col,
+            layoutB=nvvm.MMALayout.col,
+        )

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 9881a929c8dc56..5df125694256a4 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -36,7 +36,7 @@ def testTypes():
 @run
 def testSequenceOp():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
         transform.AnyOpType.get(),
     )
@@ -52,15 +52,15 @@ def testSequenceOp():
 @run
 def testNestedSequenceOp():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         nested = transform.SequenceOp(
-            transform.FailurePropagationMode.PROPAGATE, [], sequence.bodyTarget
+            transform.FailurePropagationMode.Propagate, [], sequence.bodyTarget
         )
         with InsertionPoint(nested.body):
             doubly_nested = transform.SequenceOp(
-                transform.FailurePropagationMode.PROPAGATE,
+                transform.FailurePropagationMode.Propagate,
                 [transform.AnyOpType.get()],
                 nested.bodyTarget,
             )
@@ -84,7 +84,7 @@ def testNestedSequenceOp():
 @run
 def testSequenceOpWithExtras():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.AnyOpType.get(),
         [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
@@ -99,14 +99,14 @@ def testSequenceOpWithExtras():
 @run
 def testNestedSequenceOpWithExtras():
   sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.AnyOpType.get(),
         [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
     )
   with InsertionPoint(sequence.body):
     nested = transform.SequenceOp(
-            transform.FailurePropagationMode.PROPAGATE,
+            transform.FailurePropagationMode.Propagate,
             [],
             sequence.bodyTarget,
             sequence.bodyExtraArgs,
@@ -125,7 +125,7 @@ def testTransformPDLOps():
   withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
   with InsertionPoint(withPdl.body):
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [transform.AnyOpType.get()],
         withPdl.bodyTarget,
     )
@@ -148,7 +148,7 @@ def testTransformPDLOps():
 @run
 def testGetParentOp():
   sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+      transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
   with InsertionPoint(sequence.body):
     transform.GetParentOp(
@@ -164,7 +164,7 @@ def testGetParentOp():
 @run
 def testMergeHandlesOp():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         transform.MergeHandlesOp([sequence.bodyTarget])
@@ -178,7 +178,7 @@ def testMergeHandlesOp():
 @run
 def testApplyPatternsOpCompact():
   sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+      transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
   )
   with InsertionPoint(sequence.body):
     with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
@@ -193,7 +193,7 @@ def testApplyPatternsOpCompact():
 @run
 def testApplyPatternsOpWithType():
   sequence = transform.SequenceOp(
-      transform.FailurePropagationMode.PROPAGATE, [],
+      transform.FailurePropagationMode.Propagate, [],
       transform.OperationType.get('test.dummy')
   )
   with InsertionPoint(sequence.body):
@@ -211,7 +211,7 @@ def testReplicateOp():
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
     with InsertionPoint(with_pdl.body):
         sequence = transform.SequenceOp(
-            transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
+            transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
         )
         with InsertionPoint(sequence.body):
             m1 = transform_pdl.PDLMatchOp(

diff  --git a/mlir/test/python/dialects/transform_bufferization_ext.py b/mlir/test/python/dialects/transform_bufferization_ext.py
index fad256513cbd1d..733bd3a2cab6fe 100644
--- a/mlir/test/python/dialects/transform_bufferization_ext.py
+++ b/mlir/test/python/dialects/transform_bufferization_ext.py
@@ -3,6 +3,7 @@
 from mlir.ir import *
 from mlir.dialects import transform
 from mlir.dialects.transform import bufferization
+from mlir.dialects.bufferization import LayoutMapOption
 
 
 def run(f):
@@ -18,7 +19,7 @@ def run(f):
 @run
 def testEmptyTensorToAllocTensorOpCompact():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("tensor.empty"),
     )
@@ -33,7 +34,7 @@ def testEmptyTensorToAllocTensorOpCompact():
 @run
 def testEmptyTensorToAllocTensorOpTyped():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("tensor.empty"),
     )
@@ -51,7 +52,7 @@ def testEmptyTensorToAllocTensorOpTyped():
 @run
 def testOneShotBufferizeOpCompact():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         bufferization.OneShotBufferizeOp(sequence.bodyTarget)
@@ -64,7 +65,7 @@ def testOneShotBufferizeOpCompact():
 @run
 def testOneShotBufferizeOpTyped():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         bufferization.OneShotBufferizeOp(
@@ -80,7 +81,7 @@ def testOneShotBufferizeOpTyped():
 @run
 def testOneShotBufferizeOpAttributes():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         bufferization.OneShotBufferizeOp(
@@ -89,7 +90,7 @@ def testOneShotBufferizeOpAttributes():
             allow_unknown_ops=True,
             bufferize_function_boundaries=True,
             create_deallocs=False,
-            function_boundary_type_conversion=bufferization.LayoutMapOption.IDENTITY_LAYOUT_MAP,
+            function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap,
             memcpy_op="linalg.copy",
             print_conflicts=True,
             test_analysis_only=True,

diff  --git a/mlir/test/python/dialects/transform_gpu_ext.py b/mlir/test/python/dialects/transform_gpu_ext.py
index 630a224fe6fea1..db2899592609c5 100644
--- a/mlir/test/python/dialects/transform_gpu_ext.py
+++ b/mlir/test/python/dialects/transform_gpu_ext.py
@@ -10,7 +10,7 @@ def run(f):
         module = Module.create()
         with InsertionPoint(module.body):
             sequence = transform.SequenceOp(
-                transform.FailurePropagationMode.PROPAGATE,
+                transform.FailurePropagationMode.Propagate,
                 [],
                 transform.AnyOpType.get(),
             )

diff  --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py
index 28a022a400fe61..95339b2363e67c 100644
--- a/mlir/test/python/dialects/transform_loop_ext.py
+++ b/mlir/test/python/dialects/transform_loop_ext.py
@@ -19,7 +19,7 @@ def run(f):
 @run
 def getParentLoop():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         loop.GetParentForOp(
@@ -34,7 +34,7 @@ def getParentLoop():
 @run
 def loopOutline():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("scf.for"),
     )
@@ -54,7 +54,7 @@ def loopOutline():
 @run
 def loopPeel():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("scf.for"),
     )
@@ -68,7 +68,7 @@ def loopPeel():
 @run
 def loopPipeline():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("scf.for"),
     )
@@ -86,7 +86,7 @@ def loopPipeline():
 @run
 def loopUnroll():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("scf.for"),
     )

diff  --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index f130fbd829a997..f89005cb2f86d1 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -19,7 +19,7 @@ def run(f):
 @run
 def testMemRefMultiBufferOpCompact():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("memref.alloc"),
     )
@@ -35,7 +35,7 @@ def testMemRefMultiBufferOpCompact():
 @run
 def testMemRefMultiBufferOpTyped():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("memref.alloc"),
     )
@@ -53,7 +53,7 @@ def testMemRefMultiBufferOpTyped():
 @run
 def testMemRefMultiBufferOpAttributes():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("memref.alloc"),
     )

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 2e3198b03d1d74..8cb16e7f3ebde1 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -21,7 +21,7 @@ def run(f):
 @run
 def testBufferizeToAllocationOpCompact():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.BufferizeToAllocationOp(sequence.bodyTarget)
@@ -34,7 +34,7 @@ def testBufferizeToAllocationOpCompact():
 @run
 def testBufferizeToAllocationOpArgs():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.BufferizeToAllocationOp(
@@ -57,7 +57,7 @@ def testBufferizeToAllocationOpArgs():
 @run
 def testDecompose():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.DecomposeOp(sequence.bodyTarget)
@@ -70,7 +70,7 @@ def testDecompose():
 @run
 def testFuseIntoContainingOpTypes():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
@@ -92,7 +92,7 @@ def testFuseIntoContainingOpTypes():
 @run
 def testFuseIntoContainingOpCompact():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
@@ -109,7 +109,7 @@ def testFuseIntoContainingOpCompact():
 @run
 def testGeneralize():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.GeneralizeOp(sequence.bodyTarget)
@@ -122,7 +122,7 @@ def testGeneralize():
 @run
 def testInterchange():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.InterchangeOp(sequence.bodyTarget, iterator_interchange=[1, 0])
@@ -136,7 +136,7 @@ def testInterchange():
 @run
 def testMapCopyToThreadsOpCompact():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MapCopyToThreadsOp(
@@ -153,7 +153,7 @@ def testMapCopyToThreadsOpCompact():
 @run
 def testMapCopyToThreadsOpTypes():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MapCopyToThreadsOp(
@@ -174,7 +174,7 @@ def testMapCopyToThreadsOpTypes():
 @run
 def testMatchOpNamesString():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy")
@@ -188,7 +188,7 @@ def testMatchOpNamesString():
 @run
 def testMatchOpNamesList():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
@@ -202,7 +202,7 @@ def testMatchOpNamesList():
 @run
 def testMaskedVectorizeStatic():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
@@ -216,7 +216,7 @@ def testMaskedVectorizeStatic():
 @run
 def testMaskedVectorizeArray():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         sizes = Attribute.parse("[16, 4]")
@@ -231,7 +231,7 @@ def testMaskedVectorizeArray():
 @run
 def testMaskedVectorizeMixed():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
@@ -248,7 +248,7 @@ def testMaskedVectorizeMixed():
 @run
 def testMaskedVectorizeScalable():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
@@ -265,7 +265,7 @@ def testMaskedVectorizeScalable():
 @run
 def testMaskedVectorizeArgs():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MaskedVectorizeOp(
@@ -281,7 +281,7 @@ def testMaskedVectorizeArgs():
 @run
 def testMatchOpNamesTyped():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MatchOp.match_op_names(
@@ -299,7 +299,7 @@ def testMatchOpNamesTyped():
 @run
 def testMultitileSizes():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.MultiTileSizesOp(
@@ -316,7 +316,7 @@ def testMultitileSizes():
 @run
 def testPad():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.PadOp(
@@ -343,7 +343,7 @@ def testPad():
 @run
 def testScalarize():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.ScalarizeOp(sequence.bodyTarget)
@@ -355,7 +355,7 @@ def testScalarize():
 @run
 def testSplit():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
@@ -369,7 +369,7 @@ def testSplit():
 @run
 def testTileCompact():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
@@ -383,7 +383,7 @@ def testTileCompact():
 @run
 def testTileAttributes():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     attr = DenseI64ArrayAttr.get([4, 8])
     ichange = DenseI64ArrayAttr.get([0, 1])
@@ -399,7 +399,7 @@ def testTileAttributes():
 @run
 def testTileZero():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.TileOp(
@@ -417,7 +417,7 @@ def testTileDynamic():
     with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
     with InsertionPoint(with_pdl.body):
         sequence = transform.SequenceOp(
-            transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget
+            transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
         )
         with InsertionPoint(sequence.body):
             m1 = transform_pdl.PDLMatchOp(
@@ -437,7 +437,7 @@ def testTileDynamic():
 @run
 def testTileExplicitLoopTypeSingle():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.TileOp(
@@ -452,7 +452,7 @@ def testTileExplicitLoopTypeSingle():
 @run
 def testTileExplicitLoopTypeAll():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     types = [
         transform.OperationType.get(x)
@@ -470,7 +470,7 @@ def testTileExplicitLoopTypeAll():
 @run
 def testTileToForallCompact():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE,
+        transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("linalg.matmul"),
     )
@@ -486,7 +486,7 @@ def testTileToForallCompact():
 @run
 def testTileToForallLoopsAndTileOpTypes():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.TileToForallOp(
@@ -505,7 +505,7 @@ def testTileToForallLoopsAndTileOpTypes():
 @run
 def testTileToForallTileSizes():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4])
@@ -518,7 +518,7 @@ def testTileToForallTileSizes():
 @run
 def testTileToForallMixedDynamic():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
@@ -532,7 +532,7 @@ def testTileToForallMixedDynamic():
 @run
 def testTileToForallPackedDynamic():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
@@ -546,7 +546,7 @@ def testTileToForallPackedDynamic():
 @run
 def testTileToForallMapping():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
         mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
@@ -562,7 +562,7 @@ def testTileToForallMapping():
 @run
 def testVectorize():
     sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
     )
     with InsertionPoint(sequence.body):
         structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
@@ -571,3 +571,53 @@ def testVectorize():
     # CHECK: transform.sequence
     # CHECK: = transform.structured.vectorize
     # CHECK: {vectorize_padding}
+
+
+@run
+def testMatchInterfaceEnum():
+    names = ArrayAttr.get([StringAttr.get("test.dummy")])
+    result_type = transform.AnyOpType.get()
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        fused = structured.MatchOp.__base__(
+            result_type,
+            sequence.bodyTarget,
+            ops=names,
+            interface=structured.MatchInterfaceEnum.LinalgOp,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMatchInterfaceEnum
+    # CHECK: transform.sequence
+    # CHECK: = transform.structured.match
+    # CHECK: interface{LinalgOp}
+
+
+@run
+def testMatchInterfaceEnumReplaceAttributeBuilder():
+    @register_attribute_builder("MatchInterfaceEnum", replace=True)
+    def match_interface_enum(x, context):
+        if x == "LinalgOp":
+            y = 0
+        elif x == "TilingInterface":
+            y = 1
+        return IntegerAttr.get(IntegerType.get_signless(32, context=context), y)
+
+    names = ArrayAttr.get([StringAttr.get("test.dummy")])
+    result_type = transform.AnyOpType.get()
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        fused = structured.MatchOp.__base__(
+            result_type,
+            sequence.bodyTarget,
+            ops=names,
+            interface="TilingInterface",
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder
+    # CHECK: transform.sequence
+    # CHECK: = transform.structured.match
+    # CHECK: interface{TilingInterface}

diff  --git a/mlir/test/python/dialects/transform_tensor_ext.py b/mlir/test/python/dialects/transform_tensor_ext.py
index 601d551ede5e58..a2e7aa242b9da6 100644
--- a/mlir/test/python/dialects/transform_tensor_ext.py
+++ b/mlir/test/python/dialects/transform_tensor_ext.py
@@ -11,7 +11,7 @@ def run(f):
         module = Module.create()
         with InsertionPoint(module.body):
             sequence = transform.SequenceOp(
-                transform.FailurePropagationMode.PROPAGATE,
+                transform.FailurePropagationMode.Propagate,
                 [],
                 transform.AnyOpType.get(),
             )

diff  --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index a0808e549454bb..1a0a9e1d6ecbde 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -10,7 +10,7 @@ def run_apply_patterns(f):
         module = Module.create()
         with InsertionPoint(module.body):
             sequence = transform.SequenceOp(
-                transform.FailurePropagationMode.PROPAGATE,
+                transform.FailurePropagationMode.Propagate,
                 [],
                 transform.AnyOpType.get(),
             )
@@ -72,12 +72,12 @@ def enum_configurable_patterns():
     # CHECK: transform.apply_patterns.vector.lower_contraction
     # CHECK-SAME: lowering_strategy = matmulintrinsics
     vector.ApplyLowerContractionPatternsOp(
-        lowering_strategy=vector.VectorContractLowering.MATMUL
+        lowering_strategy=vector.VectorContractLowering.Matmul
     )
     # CHECK: transform.apply_patterns.vector.lower_contraction
     # CHECK-SAME: lowering_strategy = parallelarith
     vector.ApplyLowerContractionPatternsOp(
-        lowering_strategy=vector.VectorContractLowering.PARALLEL_ARITH
+        lowering_strategy=vector.VectorContractLowering.ParallelArith
     )
 
     # CHECK: transform.apply_patterns.vector.lower_multi_reduction
@@ -85,12 +85,12 @@ def enum_configurable_patterns():
     # CHECK: transform.apply_patterns.vector.lower_multi_reduction
     # This is the default mode, not printed.
     vector.ApplyLowerMultiReductionPatternsOp(
-        lowering_strategy=vector.VectorMultiReductionLowering.INNER_PARALLEL
+        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
     )
     # CHECK: transform.apply_patterns.vector.lower_multi_reduction
     # CHECK-SAME: lowering_strategy = innerreduction
     vector.ApplyLowerMultiReductionPatternsOp(
-        lowering_strategy=vector.VectorMultiReductionLowering.INNER_REDUCTION
+        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 
     # CHECK: transform.apply_patterns.vector.lower_transpose
@@ -101,31 +101,31 @@ def enum_configurable_patterns():
     # CHECK-SAME: lowering_strategy = eltwise
     # CHECK-SAME: avx2_lowering_strategy = false
     vector.ApplyLowerTransposePatternsOp(
-        lowering_strategy=vector.VectorTransposeLowering.ELT_WISE
+        lowering_strategy=vector.VectorTransposeLowering.EltWise
     )
     # CHECK: transform.apply_patterns.vector.lower_transpose
     # CHECK-SAME: lowering_strategy = flat_transpose
     # CHECK-SAME: avx2_lowering_strategy = false
     vector.ApplyLowerTransposePatternsOp(
-        lowering_strategy=vector.VectorTransposeLowering.FLAT
+        lowering_strategy=vector.VectorTransposeLowering.Flat
     )
     # CHECK: transform.apply_patterns.vector.lower_transpose
     # CHECK-SAME: lowering_strategy = shuffle_1d
     # CHECK-SAME: avx2_lowering_strategy = false
     vector.ApplyLowerTransposePatternsOp(
-        lowering_strategy=vector.VectorTransposeLowering.SHUFFLE1_D
+        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
     )
     # CHECK: transform.apply_patterns.vector.lower_transpose
     # CHECK-SAME: lowering_strategy = shuffle_16x16
     # CHECK-SAME: avx2_lowering_strategy = false
     vector.ApplyLowerTransposePatternsOp(
-        lowering_strategy=vector.VectorTransposeLowering.SHUFFLE16X16
+        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
     )
     # CHECK: transform.apply_patterns.vector.lower_transpose
     # CHECK-SAME: lowering_strategy = flat_transpose
     # CHECK-SAME: avx2_lowering_strategy = true
     vector.ApplyLowerTransposePatternsOp(
-        lowering_strategy=vector.VectorTransposeLowering.FLAT,
+        lowering_strategy=vector.VectorTransposeLowering.Flat,
         avx2_lowering_strategy=True,
     )
 
@@ -134,20 +134,20 @@ def enum_configurable_patterns():
     # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
     # CHECK-SAME: split_transfer_strategy = none
     vector.ApplySplitTransferFullPartialPatternsOp(
-        split_transfer_strategy=vector.VectorTransferSplit.NONE
+        split_transfer_strategy=vector.VectorTransferSplit.None_
     )
     # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
     # CHECK-SAME: split_transfer_strategy = "vector-transfer"
     vector.ApplySplitTransferFullPartialPatternsOp(
-        split_transfer_strategy=vector.VectorTransferSplit.VECTOR_TRANSFER
+        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
     )
     # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
     # This is the default mode, not printed.
     vector.ApplySplitTransferFullPartialPatternsOp(
-        split_transfer_strategy=vector.VectorTransferSplit.LINALG_COPY
+        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
     )
     # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
     # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
     vector.ApplySplitTransferFullPartialPatternsOp(
-        split_transfer_strategy=vector.VectorTransferSplit.FORCE_IN_BOUNDS
+        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
     )

diff  --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py
index 36896cd4dc98d8..dafb2bfde8982d 100644
--- a/mlir/test/python/dialects/vector.py
+++ b/mlir/test/python/dialects/vector.py
@@ -64,3 +64,21 @@ def testTransferReadOp():
     # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
     # CHECK-NOT: %[[MASK]]
     print(module)
+
+
+# CHECK-LABEL: TEST: testBitEnumCombiningKind
+@run
+def testBitEnumCombiningKind():
+    module = Module.create()
+    with InsertionPoint(module.body):
+        f32 = F32Type.get()
+        vector_type = VectorType.get([16], f32)
+
+        @func.FuncOp.from_py_func(vector_type)
+        def reduction(arg):
+            v = vector.ReductionOp(f32, vector.CombiningKind.ADD, arg)
+            return v
+
+    # CHECK: func.func @reduction(%[[VEC:.*]]: vector<16xf32>) -> f32 {
+    # CHECK: %0 = vector.reduction <add>, %[[VEC]] : vector<16xf32> into f32
+    print(module)

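The `EnumPythonBindingGen.cpp` changes that follow switch the generated classes from `Enum` to `IntEnum`, or to `IntFlag` for bit enums. Each class also gets a `__str__` that returns the case's assembly string. A hand-written sketch of roughly what that output looks like for a bit enum is below. `ExampleBits`, its cases, their spellings and the `|` separator are all hypothetical, and the `register_attribute_builder` registration is left out so the snippet runs without the MLIR bindings.

```python
from enum import IntFlag


class ExampleBits(IntFlag):
    """Hypothetical bit enum, shaped like mlir-tblgen's generated output."""

    Read = 1
    Write = 2
    Exec = 4

    # Iterate over the single-bit cases that are set in this value.
    def __iter__(self):
        return iter([case for case in type(self) if (self & case) is case])

    # Number of set bits, i.e. how many cases this value combines.
    def __len__(self):
        return bin(self).count("1")

    # Return the assembly spelling; combined flags are joined by the separator.
    def __str__(self):
        if len(self) > 1:
            return "|".join(map(str, self))
        if self is ExampleBits.Read:
            return "read"
        if self is ExampleBits.Write:
            return "write"
        if self is ExampleBits.Exec:
            return "exec"
        raise ValueError("Unknown ExampleBits enum entry.")
```

Because the members are real ints, the generated attribute builder can pass `int(x)` directly to `IntegerAttr.get`. That is what replaces the old `_as_int()` helper in the diff.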
diff  --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 9748e33e2ebe8a..f4ced0803772ed 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -10,10 +10,12 @@
 // generate the corresponding Python binding classes.
 //
 //===----------------------------------------------------------------------===//
+#include "OpGenHelpers.h"
 
+#include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Dialect.h"
 #include "mlir/TableGen/GenInfo.h"
-#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
 
@@ -24,48 +26,61 @@ using namespace mlir::tblgen;
 constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
-from enum import Enum
+from enum import IntEnum, auto, IntFlag
 from ._ods_common import _cext as _ods_cext
+from ..ir import register_attribute_builder
 _ods_ir = _ods_cext.ir
 
-# Convenience decorator for registering user-friendly Attribute builders.
-def _register_attribute_builder(kind):
-    def decorator_builder(func):
-        _ods_ir.AttrBuilder.insert(kind, func)
-        return func
-
-    return decorator_builder
-
 )Py";
 
 /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
 static std::string makePythonEnumCaseName(StringRef name) {
-  return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper();
+  if (isPythonReserved(name.str()))
+    return (name + "_").str();
+  return name.str();
 }
 
 /// Emits the Python class for the given enum.
-static void emitEnumClass(StringRef enumName, StringRef description,
-                          ArrayRef<EnumAttrCase> cases, raw_ostream &os) {
-  os << llvm::formatv("class {0}(Enum):\n", enumName);
-  if (!description.empty())
-    os << llvm::formatv("    \"\"\"{0}\"\"\"\n", description);
+static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
+  os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
+                      enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
+  if (!enumAttr.getSummary().empty())
+    os << llvm::formatv("    \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
   os << "\n";
 
-  for (const EnumAttrCase &enumCase : cases) {
-    os << llvm::formatv("    {0} = {1}\n",
-                        makePythonEnumCaseName(enumCase.getSymbol()),
-                        enumCase.getValue());
+  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
+    os << llvm::formatv(
+        "    {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
+        enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
+                                 : "auto()");
   }
 
   os << "\n";
-  os << llvm::formatv("    def _as_int(self):\n");
-  for (const EnumAttrCase &enumCase : cases) {
-    os << llvm::formatv("        if self is {0}.{1}:\n", enumName,
+
+  if (enumAttr.isBitEnum()) {
+    os << llvm::formatv("    def __iter__(self):\n"
+                        "        return iter([case for case in type(self) if "
+                        "(self & case) is case])\n");
+    os << llvm::formatv("    def __len__(self):\n"
+                        "        return bin(self).count(\"1\")\n");
+    os << "\n";
+  }
+
+  os << llvm::formatv("    def __str__(self):\n");
+  if (enumAttr.isBitEnum())
+    os << llvm::formatv("        if len(self) > 1:\n"
+                        "            return \"{0}\".join(map(str, self))\n",
+                        enumAttr.getDef().getValueAsString("separator"));
+  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
+    os << llvm::formatv("        if self is {0}.{1}:\n",
+                        enumAttr.getEnumClassName(),
                         makePythonEnumCaseName(enumCase.getSymbol()));
-    os << llvm::formatv("            return {0}\n", enumCase.getValue());
+    os << llvm::formatv("            return \"{0}\"\n", enumCase.getStr());
   }
-  os << llvm::formatv("        assert False, \"Unknown {0} enum entry.\"\n\n\n",
-                      enumName);
+  os << llvm::formatv(
+      "        raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
+      enumAttr.getEnumClassName());
+  os << "\n";
 }
 
 /// Attempts to extract the bitwidth B from string "uintB_t" describing the
@@ -90,36 +105,68 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
     return true;
   }
 
-  os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n",
+  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
                       enumAttr.getAttrDefName());
-  os << llvm::formatv(
-      "def _{0}(x, context):\n",
-      llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName()));
+  os << llvm::formatv("def _{0}(x, context):\n",
+                      enumAttr.getAttrDefName().lower());
   os << llvm::formatv(
       "    return "
       "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
-      "context=context), x._as_int())\n\n",
+      "context=context), int(x))\n\n",
       bitwidth);
   return false;
 }
 
+/// Emits an attribute builder for the given dialect enum attribute to support
+/// automatic conversion between enum values and attributes in Python. Returns
+/// `false` on success, `true` on failure.
+static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
+                                            StringRef formatString,
+                                            raw_ostream &os) {
+  os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
+  os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
+  os << llvm::formatv("    return "
+                      "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
+                      formatString);
+  return false;
+}
+
 /// Emits Python bindings for all enums in the record keeper. Returns
 /// `false` on success, `true` on failure.
 static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
                             raw_ostream &os) {
   os << fileHeader;
-  std::vector<llvm::Record *> defs =
-      recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
-  for (const llvm::Record *def : defs) {
-    EnumAttr enumAttr(*def);
-    if (enumAttr.isBitEnum()) {
-      llvm::errs() << "bit enums not supported\n";
+  for (auto &it :
+       recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
+    EnumAttr enumAttr(*it);
+    emitEnumClass(enumAttr, os);
+    emitAttributeBuilder(enumAttr, os);
+  }
+  for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
+    AttrOrTypeDef attr(&*it);
+    if (!attr.getMnemonic()) {
+      llvm::errs() << "enum case " << attr
+                   << " needs mnemonic for python enum bindings generation";
+      return true;
+    }
+    StringRef mnemonic = attr.getMnemonic().value();
+    std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
+    StringRef dialect = attr.getDialect().getName();
+    if (assemblyFormat == "`<` $value `>`") {
+      emitDialectEnumAttributeBuilder(
+          attr.getName(),
+          llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
+    } else if (assemblyFormat == "$value") {
+      emitDialectEnumAttributeBuilder(
+          attr.getName(),
+          llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
+    } else {
+      llvm::errs()
+          << "unsupported assembly format for python enum bindings generation";
       return true;
     }
-    emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
-                  enumAttr.getAllCases(), os);
-    emitAttributeBuilder(enumAttr, os);
   }
+
   return false;
 }
 

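For `Dialect/EnumAttr` definitions, `emitPythonEnums` has no typed C++ getter to call. Instead it emits a builder that formats the case into textual attribute syntax and hands it to `Attribute.parse`. The two supported `assemblyFormat`s map to text roughly as in the pure-Python sketch below. The helper name `dialect_enum_attr_text` is hypothetical, and the dialect/mnemonic pairs in the comments are only illustrations.

```python
# Hypothetical helper mirroring the format strings chosen in emitPythonEnums.
def dialect_enum_attr_text(dialect, mnemonic, assembly_format, case_str):
    """Return the textual attribute that the generated builder would parse."""
    if assembly_format == "`<` $value `>`":
        # Illustrative result: #vector.kind<add>
        return f"#{dialect}.{mnemonic}<{case_str}>"
    if assembly_format == "$value":
        # Illustrative result: #gpu<dim x>
        return f"#{dialect}<{mnemonic} {case_str}>"
    # The real generator reports any other format as an error at tblgen time.
    raise ValueError(
        "unsupported assembly format for python enum bindings generation"
    )
```

In the generated code `case_str` is `str(x)`. This is why the enum classes override `__str__` to return the assembly spelling rather than the Python member name.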
diff  --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index b08f3fb7768c9c..7fd34df8460d39 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "OpGenHelpers.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Regex.h"
@@ -63,3 +64,19 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
 
   return defs;
 }
+
+bool mlir::tblgen::isPythonReserved(StringRef str) {
+  static llvm::StringSet<> reserved({
+      "False",  "None",   "True",    "and",      "as",       "assert", "async",
+      "await",  "break",  "class",   "continue", "def",      "del",    "elif",
+      "else",   "except", "finally", "for",      "from",     "global", "if",
+      "import", "in",     "is",      "lambda",   "nonlocal", "not",    "or",
+      "pass",   "raise",  "return",  "try",      "while",    "with",   "yield",
+  });
+  // These aren't Python keywords but builtin functions that shouldn't/can't be
+  // shadowed.
+  reserved.insert("callable");
+  reserved.insert("issubclass");
+  reserved.insert("type");
+  return reserved.contains(str);
+}
\ No newline at end of file

diff  --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 166d24725ac6fc..3dcff14d1221ee 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -24,6 +24,10 @@ namespace tblgen {
 std::vector<llvm::Record *>
 getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
 
+/// Checks whether `str` is a Python keyword or would shadow builtin function.
+/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
+bool isPythonReserved(llvm::StringRef str);
+
 } // namespace tblgen
 } // namespace mlir
 

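The renames in the test files follow from the new `makePythonEnumCaseName`. Examples include `PROPAGATE` to `Propagate`, `SHUFFLE16X16` to `Shuffle16x16`, and `NONE` to `None_`. The TableGen case symbol is kept verbatim, and only a Python keyword or one of the listed builtins gets a trailing underscore. A small Python sketch of that rule is below. It assumes `keyword.iskeyword` matches the hardcoded list, which the header comment suggests was generated from `keyword.kwlist`. The function names are hypothetical mirrors of the C++ helpers.

```python
import keyword

# Builtins that isPythonReserved also refuses to shadow.
_EXTRA_RESERVED = {"callable", "issubclass", "type"}


def is_python_reserved(name):
    """Hypothetical Python mirror of mlir::tblgen::isPythonReserved."""
    return keyword.iskeyword(name) or name in _EXTRA_RESERVED


def make_python_enum_case_name(name):
    """Keep the TableGen case symbol, appending "_" only if it is reserved."""
    return name + "_" if is_python_reserved(name) else name
```

Keeping the symbols verbatim aligns the Python case names with the C++ enum cases. The trade-off is that the old UPPER_SNAKE_CASE spellings used by existing callers stop working.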
diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 7c7b991fb7b07a..0b5df7ab70dddb 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -11,6 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "OpGenHelpers.h"
+
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/StringSet.h"
@@ -278,18 +280,6 @@ static llvm::cl::opt<std::string> clDialectExtensionName(
 
 using AttributeClasses = DenseMap<StringRef, StringRef>;
 
-/// Checks whether `str` is a Python keyword or would shadow builtin function.
-static bool isPythonReserved(StringRef str) {
-  static llvm::StringSet<> reserved(
-      {"and",      "as",    "assert", "break",      "callable", "class",
-       "continue", "def",   "del",    "elif",       "else",     "except",
-       "finally",  "for",   "from",   "global",     "if",       "import",
-       "in",       "is",    "lambda", "nonlocal",   "not",      "or",
-       "pass",     "raise", "return", "issubclass", "try",      "type",
-       "while",    "with",  "yield"});
-  return reserved.contains(str);
-}
-
 /// Checks whether `str` would shadow a generated variable or attribute
 /// part of the OpView API.
 static bool isODSReserved(StringRef str) {

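With `isPythonReserved` moved into `OpGenHelpers`, the enum and op generators share one reserved-name check. The list is also extended with `False`, `None`, `True`, `async` and `await`, which the old op-only copy lacked. A quick way to see which names the shared list adds over the removed one is the set difference below. It is a throwaway sketch that copies both lists from the diff by hand.

```python
# Names copied from the removed OpPythonBindingGen.cpp list.
old_reserved = {
    "and", "as", "assert", "break", "callable", "class", "continue", "def",
    "del", "elif", "else", "except", "finally", "for", "from", "global", "if",
    "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise",
    "return", "issubclass", "try", "type", "while", "with", "yield",
}

# Names copied from the new OpGenHelpers.cpp list, plus the three builtins.
new_reserved = {
    "False", "None", "True", "and", "as", "assert", "async", "await", "break",
    "class", "continue", "def", "del", "elif", "else", "except", "finally",
    "for", "from", "global", "if", "import", "in", "is", "lambda", "nonlocal",
    "not", "or", "pass", "raise", "return", "try", "while", "with", "yield",
    "callable", "issubclass", "type",
}

# Names newly treated as reserved by the shared helper.
added = sorted(new_reserved - old_reserved)
```

Here `added` evaluates to `["False", "None", "True", "async", "await"]`. The `None` entry is what makes `VectorTransferSplit.None_` necessary.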