[Mlir-commits] [mlir] [python] fix enum ambiguity (PR #117918)
Maksim Levental
llvmlistbot at llvm.org
Wed Nov 27 12:23:21 PST 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/117918
From 67d9107a2fda091ee16de40d3dce68413198589f Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental@gmail.com>
Date: Wed, 27 Nov 2024 15:06:12 -0500
Subject: [PATCH] [python] fix enum ambiguity
---
mlir/cmake/modules/AddMLIRPython.cmake | 4 +-
mlir/python/CMakeLists.txt | 23 +-
mlir/python/mlir/dialects/amdgpu.py | 145 +++
mlir/python/mlir/dialects/arith.py | 291 ++++++
mlir/python/mlir/dialects/bufferization.py | 25 +-
mlir/python/mlir/dialects/gpu/__init__.py | 352 +++++++
mlir/python/mlir/dialects/index.py | 45 +
mlir/python/mlir/dialects/linalg/__init__.py | 150 +++
mlir/python/mlir/dialects/llvm.py | 938 +++++++++++++++++-
mlir/python/mlir/dialects/nvgpu.py | 121 +++
mlir/python/mlir/dialects/nvvm.py | 395 ++++++++
mlir/python/mlir/dialects/sparse_tensor.py | 78 +-
.../mlir/dialects/transform/__init__.py | 52 +-
.../mlir/dialects/transform/structured.py | 43 +-
mlir/python/mlir/dialects/transform/vector.py | 96 ++
mlir/python/mlir/dialects/vector.py | 101 ++
.../mlir-tblgen/EnumPythonBindingGen.cpp | 21 +-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 2 +-
18 files changed, 2849 insertions(+), 33 deletions(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 7b91f43e2d57fd..d06fc927ea44d4 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -318,7 +318,7 @@ function(declare_mlir_dialect_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
- mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+ mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
@@ -390,7 +390,7 @@ function(declare_mlir_dialect_extension_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
- mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
+ mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 23187f256455bb..aca9bfa08b8032 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -63,8 +63,7 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
- DIALECT_NAME affine
- GEN_ENUM_BINDINGS)
+ DIALECT_NAME affine)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -88,10 +87,7 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BufferizationOps.td
SOURCES
dialects/bufferization.py
- DIALECT_NAME bufferization
- GEN_ENUM_BINDINGS_TD_FILE
- "../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
-)
+ DIALECT_NAME bufferization)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -178,10 +174,7 @@ declare_mlir_dialect_python_bindings(
SOURCES
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
- DIALECT_NAME transform
- GEN_ENUM_BINDINGS_TD_FILE
- "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
-)
+ DIALECT_NAME transform)
declare_mlir_python_sources(
MLIRPythonSources.Dialects.transform.extras
@@ -250,10 +243,7 @@ declare_mlir_dialect_extension_python_bindings(
SOURCES
dialects/transform/structured.py
DIALECT_NAME transform
- EXTENSION_NAME structured_transform
- GEN_ENUM_BINDINGS_TD_FILE
- "../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
-)
+ EXTENSION_NAME structured_transform)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -280,10 +270,7 @@ declare_mlir_dialect_extension_python_bindings(
SOURCES
dialects/transform/vector.py
DIALECT_NAME transform
- EXTENSION_NAME vector_transform
- GEN_ENUM_BINDINGS_TD_FILE
- "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
-)
+ EXTENSION_NAME vector_transform)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 43d905d0c481cc..39803747f1b1f4 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -1,6 +1,151 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum, IntFlag
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._amdgpu_ops_gen import *
from ._amdgpu_enum_gen import *
+
+
+class DPPPerm(IntEnum):
+ """The possible permutations for a DPP operation"""
+
+ quad_perm = 0
+ row_shl = 1
+ row_shr = 2
+ row_ror = 3
+ wave_shl = 4
+ wave_shr = 5
+ wave_ror = 6
+ wave_rol = 7
+ row_mirror = 8
+ row_half_mirror = 9
+ row_bcast_15 = 10
+ row_bcast_31 = 11
+
+ def __str__(self):
+ if self is DPPPerm.quad_perm:
+ return "quad_perm"
+ if self is DPPPerm.row_shl:
+ return "row_shl"
+ if self is DPPPerm.row_shr:
+ return "row_shr"
+ if self is DPPPerm.row_ror:
+ return "row_ror"
+ if self is DPPPerm.wave_shl:
+ return "wave_shl"
+ if self is DPPPerm.wave_shr:
+ return "wave_shr"
+ if self is DPPPerm.wave_ror:
+ return "wave_ror"
+ if self is DPPPerm.wave_rol:
+ return "wave_rol"
+ if self is DPPPerm.row_mirror:
+ return "row_mirror"
+ if self is DPPPerm.row_half_mirror:
+ return "row_half_mirror"
+ if self is DPPPerm.row_bcast_15:
+ return "row_bcast_15"
+ if self is DPPPerm.row_bcast_31:
+ return "row_bcast_31"
+ raise ValueError("Unknown DPPPerm enum entry.")
+
+
+@register_attribute_builder("AMDGPU_DPPPerm")
+def _amdgpu_dppperm(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MFMAPermB(IntEnum):
+ """The possible permutations of the lanes storing B available in an MFMA"""
+
+ none = 0
+ bcast_first_32 = 1
+ bcast_second_32 = 2
+ rotate_16_right = 3
+ bcast_first_16 = 4
+ bcast_second_16 = 5
+ bcast_third_16 = 6
+ bcast_fourth_16 = 7
+
+ def __str__(self):
+ if self is MFMAPermB.none:
+ return "none"
+ if self is MFMAPermB.bcast_first_32:
+ return "bcast_first_32"
+ if self is MFMAPermB.bcast_second_32:
+ return "bcast_second_32"
+ if self is MFMAPermB.rotate_16_right:
+ return "rotate_16_right"
+ if self is MFMAPermB.bcast_first_16:
+ return "bcast_first_16"
+ if self is MFMAPermB.bcast_second_16:
+ return "bcast_second_16"
+ if self is MFMAPermB.bcast_third_16:
+ return "bcast_third_16"
+ if self is MFMAPermB.bcast_fourth_16:
+ return "bcast_fourth_16"
+ raise ValueError("Unknown MFMAPermB enum entry.")
+
+
+@register_attribute_builder("AMDGPU_MFMAPermB")
+def _amdgpu_mfmapermb(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class sched_barrier_opt_enum(IntFlag):
+ """The possible options for scheduling barriers"""
+
+ none = 0
+ non_mem_non_sideffect = 1
+ valu = 2
+ salu = 4
+ mfma_wmma = 8
+ all_vmem = 16
+ vmem_read = 32
+ vmem_write = 64
+ all_ds = 128
+ ds_read = 256
+ ds_write = 512
+ transcendental = 1024
+
+ def __iter__(self):
+ return iter([case for case in type(self) if (self & case) is case])
+
+ def __len__(self):
+ return bin(self).count("1")
+
+ def __str__(self):
+ if len(self) > 1:
+ return "|".join(map(str, self))
+ if self is sched_barrier_opt_enum.none:
+ return "none"
+ if self is sched_barrier_opt_enum.non_mem_non_sideffect:
+ return "non_mem_non_sideffect"
+ if self is sched_barrier_opt_enum.valu:
+ return "valu"
+ if self is sched_barrier_opt_enum.salu:
+ return "salu"
+ if self is sched_barrier_opt_enum.mfma_wmma:
+ return "mfma_wmma"
+ if self is sched_barrier_opt_enum.all_vmem:
+ return "all_vmem"
+ if self is sched_barrier_opt_enum.vmem_read:
+ return "vmem_read"
+ if self is sched_barrier_opt_enum.vmem_write:
+ return "vmem_write"
+ if self is sched_barrier_opt_enum.all_ds:
+ return "all_ds"
+ if self is sched_barrier_opt_enum.ds_read:
+ return "ds_read"
+ if self is sched_barrier_opt_enum.ds_write:
+ return "ds_write"
+ if self is sched_barrier_opt_enum.transcendental:
+ return "transcendental"
+ raise ValueError("Unknown sched_barrier_opt_enum enum entry.")
+
+
+@register_attribute_builder("AMDGPU_SchedBarrierOpOpt")
+def _amdgpu_schedbarrieropopt(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce665..8b04df846adda0 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -1,6 +1,7 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum, IntFlag
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
@@ -108,3 +109,293 @@ def constant(
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
+
+
+class CmpFPredicate(IntEnum):
+ """allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15"""
+
+ AlwaysFalse = 0
+ OEQ = 1
+ OGT = 2
+ OGE = 3
+ OLT = 4
+ OLE = 5
+ ONE = 6
+ ORD = 7
+ UEQ = 8
+ UGT = 9
+ UGE = 10
+ ULT = 11
+ ULE = 12
+ UNE = 13
+ UNO = 14
+ AlwaysTrue = 15
+
+ def __str__(self):
+ if self is CmpFPredicate.AlwaysFalse:
+ return "false"
+ if self is CmpFPredicate.OEQ:
+ return "oeq"
+ if self is CmpFPredicate.OGT:
+ return "ogt"
+ if self is CmpFPredicate.OGE:
+ return "oge"
+ if self is CmpFPredicate.OLT:
+ return "olt"
+ if self is CmpFPredicate.OLE:
+ return "ole"
+ if self is CmpFPredicate.ONE:
+ return "one"
+ if self is CmpFPredicate.ORD:
+ return "ord"
+ if self is CmpFPredicate.UEQ:
+ return "ueq"
+ if self is CmpFPredicate.UGT:
+ return "ugt"
+ if self is CmpFPredicate.UGE:
+ return "uge"
+ if self is CmpFPredicate.ULT:
+ return "ult"
+ if self is CmpFPredicate.ULE:
+ return "ule"
+ if self is CmpFPredicate.UNE:
+ return "une"
+ if self is CmpFPredicate.UNO:
+ return "uno"
+ if self is CmpFPredicate.AlwaysTrue:
+ return "true"
+ raise ValueError("Unknown CmpFPredicate enum entry.")
+
+
+@register_attribute_builder("Arith_CmpFPredicateAttr")
+def _arith_cmpfpredicateattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class CmpIPredicate(IntEnum):
+ """allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9"""
+
+ eq = 0
+ ne = 1
+ slt = 2
+ sle = 3
+ sgt = 4
+ sge = 5
+ ult = 6
+ ule = 7
+ ugt = 8
+ uge = 9
+
+ def __str__(self):
+ if self is CmpIPredicate.eq:
+ return "eq"
+ if self is CmpIPredicate.ne:
+ return "ne"
+ if self is CmpIPredicate.slt:
+ return "slt"
+ if self is CmpIPredicate.sle:
+ return "sle"
+ if self is CmpIPredicate.sgt:
+ return "sgt"
+ if self is CmpIPredicate.sge:
+ return "sge"
+ if self is CmpIPredicate.ult:
+ return "ult"
+ if self is CmpIPredicate.ule:
+ return "ule"
+ if self is CmpIPredicate.ugt:
+ return "ugt"
+ if self is CmpIPredicate.uge:
+ return "uge"
+ raise ValueError("Unknown CmpIPredicate enum entry.")
+
+
+@register_attribute_builder("Arith_CmpIPredicateAttr")
+def _arith_cmpipredicateattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class DenormalMode(IntEnum):
+ """denormal mode arith"""
+
+ ieee = 0
+ preserve_sign = 1
+ positive_zero = 2
+
+ def __str__(self):
+ if self is DenormalMode.ieee:
+ return "ieee"
+ if self is DenormalMode.preserve_sign:
+ return "preserve_sign"
+ if self is DenormalMode.positive_zero:
+ return "positive_zero"
+ raise ValueError("Unknown DenormalMode enum entry.")
+
+
+@register_attribute_builder("Arith_DenormalMode")
+def _arith_denormalmode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class IntegerOverflowFlags(IntFlag):
+ """Integer overflow arith flags"""
+
+ none = 0
+ nsw = 1
+ nuw = 2
+
+ def __iter__(self):
+ return iter([case for case in type(self) if (self & case) is case])
+
+ def __len__(self):
+ return bin(self).count("1")
+
+ def __str__(self):
+ if len(self) > 1:
+ return ", ".join(map(str, self))
+ if self is IntegerOverflowFlags.none:
+ return "none"
+ if self is IntegerOverflowFlags.nsw:
+ return "nsw"
+ if self is IntegerOverflowFlags.nuw:
+ return "nuw"
+ raise ValueError("Unknown IntegerOverflowFlags enum entry.")
+
+
+@register_attribute_builder("Arith_IntegerOverflowFlags")
+def _arith_integeroverflowflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class RoundingMode(IntEnum):
+ """Floating point rounding mode"""
+
+ to_nearest_even = 0
+ downward = 1
+ upward = 2
+ toward_zero = 3
+ to_nearest_away = 4
+
+ def __str__(self):
+ if self is RoundingMode.to_nearest_even:
+ return "to_nearest_even"
+ if self is RoundingMode.downward:
+ return "downward"
+ if self is RoundingMode.upward:
+ return "upward"
+ if self is RoundingMode.toward_zero:
+ return "toward_zero"
+ if self is RoundingMode.to_nearest_away:
+ return "to_nearest_away"
+ raise ValueError("Unknown RoundingMode enum entry.")
+
+
+@register_attribute_builder("Arith_RoundingModeAttr")
+def _arith_roundingmodeattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class AtomicRMWKind(IntEnum):
+ """allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14"""
+
+ addf = 0
+ addi = 1
+ assign = 2
+ maximumf = 3
+ maxs = 4
+ maxu = 5
+ minimumf = 6
+ mins = 7
+ minu = 8
+ mulf = 9
+ muli = 10
+ ori = 11
+ andi = 12
+ maxnumf = 13
+ minnumf = 14
+
+ def __str__(self):
+ if self is AtomicRMWKind.addf:
+ return "addf"
+ if self is AtomicRMWKind.addi:
+ return "addi"
+ if self is AtomicRMWKind.assign:
+ return "assign"
+ if self is AtomicRMWKind.maximumf:
+ return "maximumf"
+ if self is AtomicRMWKind.maxs:
+ return "maxs"
+ if self is AtomicRMWKind.maxu:
+ return "maxu"
+ if self is AtomicRMWKind.minimumf:
+ return "minimumf"
+ if self is AtomicRMWKind.mins:
+ return "mins"
+ if self is AtomicRMWKind.minu:
+ return "minu"
+ if self is AtomicRMWKind.mulf:
+ return "mulf"
+ if self is AtomicRMWKind.muli:
+ return "muli"
+ if self is AtomicRMWKind.ori:
+ return "ori"
+ if self is AtomicRMWKind.andi:
+ return "andi"
+ if self is AtomicRMWKind.maxnumf:
+ return "maxnumf"
+ if self is AtomicRMWKind.minnumf:
+ return "minnumf"
+ raise ValueError("Unknown AtomicRMWKind enum entry.")
+
+
+@register_attribute_builder("AtomicRMWKindAttr")
+def _atomicrmwkindattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class FastMathFlags(IntFlag):
+ """Floating point fast math flags"""
+
+ none = 0
+ reassoc = 1
+ nnan = 2
+ ninf = 4
+ nsz = 8
+ arcp = 16
+ contract = 32
+ afn = 64
+ fast = 127
+
+ def __iter__(self):
+ return iter([case for case in type(self) if (self & case) is case])
+
+ def __len__(self):
+ return bin(self).count("1")
+
+ def __str__(self):
+ if len(self) > 1:
+ return ",".join(map(str, self))
+ if self is FastMathFlags.none:
+ return "none"
+ if self is FastMathFlags.reassoc:
+ return "reassoc"
+ if self is FastMathFlags.nnan:
+ return "nnan"
+ if self is FastMathFlags.ninf:
+ return "ninf"
+ if self is FastMathFlags.nsz:
+ return "nsz"
+ if self is FastMathFlags.arcp:
+ return "arcp"
+ if self is FastMathFlags.contract:
+ return "contract"
+ if self is FastMathFlags.afn:
+ return "afn"
+ if self is FastMathFlags.fast:
+ return "fast"
+ raise ValueError("Unknown FastMathFlags enum entry.")
+
+
+@register_attribute_builder("FastMathFlags")
+def _fastmathflags(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 759b6aa24a9ff7..a0eb1b4bb07830 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -1,6 +1,29 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._bufferization_ops_gen import *
-from ._bufferization_enum_gen import *
+
+
+class LayoutMapOption(IntEnum):
+ """option for map layout"""
+
+ InferLayoutMap = 0
+ IdentityLayoutMap = 1
+ FullyDynamicLayoutMap = 2
+
+ def __str__(self):
+ if self is LayoutMapOption.InferLayoutMap:
+ return "InferLayoutMap"
+ if self is LayoutMapOption.IdentityLayoutMap:
+ return "IdentityLayoutMap"
+ if self is LayoutMapOption.FullyDynamicLayoutMap:
+ return "FullyDynamicLayoutMap"
+ raise ValueError("Unknown LayoutMapOption enum entry.")
+
+
+@register_attribute_builder("LayoutMapOption")
+def _layoutmapoption(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa8b7ca85..f70d61741ffe7c 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -1,7 +1,359 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
+from ...ir import IntegerAttr, IntegerType, register_attribute_builder
from .._gpu_ops_gen import *
from .._gpu_enum_gen import *
from ..._mlir_libs._mlirDialectsGPU import *
+
+
+class AddressSpace(IntEnum):
+ """GPU address space"""
+
+ Global = 1
+ Workgroup = 2
+ Private = 3
+
+ def __str__(self):
+ if self is AddressSpace.Global:
+ return "global"
+ if self is AddressSpace.Workgroup:
+ return "workgroup"
+ if self is AddressSpace.Private:
+ return "private"
+ raise ValueError("Unknown AddressSpace enum entry.")
+
+
+@register_attribute_builder("GPU_AddressSpaceEnum")
+def _gpu_addressspaceenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class AllReduceOperation(IntEnum):
+ """built-in reduction operations supported by gpu.allreduce."""
+
+ ADD = 0
+ MUL = 1
+ MINUI = 2
+ MINSI = 3
+ MINNUMF = 4
+ MAXUI = 5
+ MAXSI = 6
+ MAXNUMF = 7
+ AND = 8
+ OR = 9
+ XOR = 10
+ MINIMUMF = 11
+ MAXIMUMF = 12
+
+ def __str__(self):
+ if self is AllReduceOperation.ADD:
+ return "add"
+ if self is AllReduceOperation.MUL:
+ return "mul"
+ if self is AllReduceOperation.MINUI:
+ return "minui"
+ if self is AllReduceOperation.MINSI:
+ return "minsi"
+ if self is AllReduceOperation.MINNUMF:
+ return "minnumf"
+ if self is AllReduceOperation.MAXUI:
+ return "maxui"
+ if self is AllReduceOperation.MAXSI:
+ return "maxsi"
+ if self is AllReduceOperation.MAXNUMF:
+ return "maxnumf"
+ if self is AllReduceOperation.AND:
+ return "and"
+ if self is AllReduceOperation.OR:
+ return "or"
+ if self is AllReduceOperation.XOR:
+ return "xor"
+ if self is AllReduceOperation.MINIMUMF:
+ return "minimumf"
+ if self is AllReduceOperation.MAXIMUMF:
+ return "maximumf"
+ raise ValueError("Unknown AllReduceOperation enum entry.")
+
+
+@register_attribute_builder("GPU_AllReduceOperation")
+def _gpu_allreduceoperation(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class CompilationTarget(IntEnum):
+ """GPU compilation format"""
+
+ Offload = 1
+ Assembly = 2
+ Binary = 3
+ Fatbin = 4
+
+ def __str__(self):
+ if self is CompilationTarget.Offload:
+ return "offload"
+ if self is CompilationTarget.Assembly:
+ return "assembly"
+ if self is CompilationTarget.Binary:
+ return "bin"
+ if self is CompilationTarget.Fatbin:
+ return "fatbin"
+ raise ValueError("Unknown CompilationTarget enum entry.")
+
+
+@register_attribute_builder("GPU_CompilationTargetEnum")
+def _gpu_compilationtargetenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class Dimension(IntEnum):
+ """a dimension, either 'x', 'y', or 'z'"""
+
+ x = 0
+ y = 1
+ z = 2
+
+ def __str__(self):
+ if self is Dimension.x:
+ return "x"
+ if self is Dimension.y:
+ return "y"
+ if self is Dimension.z:
+ return "z"
+ raise ValueError("Unknown Dimension enum entry.")
+
+
+@register_attribute_builder("GPU_Dimension")
+def _gpu_dimension(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class Prune2To4SpMatFlag(IntEnum):
+ """pruning strategy for 2:4 sparse matrix"""
+
+ NONE = 0
+ PRUNE_ONLY = 1
+ PRUNE_AND_CHECK = 2
+
+ def __str__(self):
+ if self is Prune2To4SpMatFlag.NONE:
+ return "NONE"
+ if self is Prune2To4SpMatFlag.PRUNE_ONLY:
+ return "PRUNE_ONLY"
+ if self is Prune2To4SpMatFlag.PRUNE_AND_CHECK:
+ return "PRUNE_AND_CHECK"
+ raise ValueError("Unknown Prune2To4SpMatFlag enum entry.")
+
+
+@register_attribute_builder("GPU_Prune2To4SpMatFlag")
+def _gpu_prune2to4spmatflag(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class ShuffleMode(IntEnum):
+ """Indexing modes supported by gpu.shuffle."""
+
+ XOR = 0
+ UP = 2
+ DOWN = 1
+ IDX = 3
+
+ def __str__(self):
+ if self is ShuffleMode.XOR:
+ return "xor"
+ if self is ShuffleMode.UP:
+ return "up"
+ if self is ShuffleMode.DOWN:
+ return "down"
+ if self is ShuffleMode.IDX:
+ return "idx"
+ raise ValueError("Unknown ShuffleMode enum entry.")
+
+
+@register_attribute_builder("GPU_ShuffleMode")
+def _gpu_shufflemode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class SpGEMMWorkEstimationOrComputeKind(IntEnum):
+ """choose whether spgemm_work_estimation_or_compute does work estimation or compute"""
+
+ WORK_ESTIMATION = 0
+ COMPUTE = 1
+
+ def __str__(self):
+ if self is SpGEMMWorkEstimationOrComputeKind.WORK_ESTIMATION:
+ return "WORK_ESTIMATION"
+ if self is SpGEMMWorkEstimationOrComputeKind.COMPUTE:
+ return "COMPUTE"
+ raise ValueError("Unknown SpGEMMWorkEstimationOrComputeKind enum entry.")
+
+
+@register_attribute_builder("GPU_SpGEMMWorkEstimationOrComputeKind")
+def _gpu_spgemmworkestimationorcomputekind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class TransposeMode(IntEnum):
+ """transpose mode of sparse matrix supported by sparse tensor ops"""
+
+ NON_TRANSPOSE = 0
+ TRANSPOSE = 1
+ CONJUGATE_TRANSPOSE = 2
+
+ def __str__(self):
+ if self is TransposeMode.NON_TRANSPOSE:
+ return "NON_TRANSPOSE"
+ if self is TransposeMode.TRANSPOSE:
+ return "TRANSPOSE"
+ if self is TransposeMode.CONJUGATE_TRANSPOSE:
+ return "CONJUGATE_TRANSPOSE"
+ raise ValueError("Unknown TransposeMode enum entry.")
+
+
+@register_attribute_builder("GPU_TransposeMode")
+def _gpu_transposemode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MMAElementwiseOp(IntEnum):
+ """elementwise operation to apply to mma matrix"""
+
+ ADDF = 0
+ MULF = 1
+ SUBF = 2
+ MAXF = 3
+ MINF = 4
+ DIVF = 5
+ ADDI = 6
+ MULI = 7
+ SUBI = 8
+ DIVS = 9
+ DIVU = 10
+ NEGATEF = 11
+ NEGATES = 12
+ EXTF = 13
+
+ def __str__(self):
+ if self is MMAElementwiseOp.ADDF:
+ return "addf"
+ if self is MMAElementwiseOp.MULF:
+ return "mulf"
+ if self is MMAElementwiseOp.SUBF:
+ return "subf"
+ if self is MMAElementwiseOp.MAXF:
+ return "maxf"
+ if self is MMAElementwiseOp.MINF:
+ return "minf"
+ if self is MMAElementwiseOp.DIVF:
+ return "divf"
+ if self is MMAElementwiseOp.ADDI:
+ return "addi"
+ if self is MMAElementwiseOp.MULI:
+ return "muli"
+ if self is MMAElementwiseOp.SUBI:
+ return "subi"
+ if self is MMAElementwiseOp.DIVS:
+ return "divs"
+ if self is MMAElementwiseOp.DIVU:
+ return "divu"
+ if self is MMAElementwiseOp.NEGATEF:
+ return "negatef"
+ if self is MMAElementwiseOp.NEGATES:
+ return "negates"
+ if self is MMAElementwiseOp.EXTF:
+ return "extf"
+ raise ValueError("Unknown MMAElementwiseOp enum entry.")
+
+
+@register_attribute_builder("MMAElementWise")
+def _mmaelementwise(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MappingId(IntEnum):
+ """Mapping ids for loop mapping"""
+
+ DimX = 0
+ DimY = 1
+ DimZ = 2
+ LinearDim0 = 3
+ LinearDim1 = 4
+ LinearDim2 = 5
+ LinearDim3 = 6
+ LinearDim4 = 7
+ LinearDim5 = 8
+ LinearDim6 = 9
+ LinearDim7 = 10
+ LinearDim8 = 11
+ LinearDim9 = 12
+
+ def __str__(self):
+ if self is MappingId.DimX:
+ return "x"
+ if self is MappingId.DimY:
+ return "y"
+ if self is MappingId.DimZ:
+ return "z"
+ if self is MappingId.LinearDim0:
+ return "linear_dim_0"
+ if self is MappingId.LinearDim1:
+ return "linear_dim_1"
+ if self is MappingId.LinearDim2:
+ return "linear_dim_2"
+ if self is MappingId.LinearDim3:
+ return "linear_dim_3"
+ if self is MappingId.LinearDim4:
+ return "linear_dim_4"
+ if self is MappingId.LinearDim5:
+ return "linear_dim_5"
+ if self is MappingId.LinearDim6:
+ return "linear_dim_6"
+ if self is MappingId.LinearDim7:
+ return "linear_dim_7"
+ if self is MappingId.LinearDim8:
+ return "linear_dim_8"
+ if self is MappingId.LinearDim9:
+ return "linear_dim_9"
+ raise ValueError("Unknown MappingId enum entry.")
+
+
+@register_attribute_builder("MappingIdEnum")
+def _mappingidenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class Processor(IntEnum):
+ """processor for loop mapping"""
+
+ BlockX = 0
+ BlockY = 1
+ BlockZ = 2
+ ThreadX = 3
+ ThreadY = 4
+ ThreadZ = 5
+ Sequential = 6
+
+ def __str__(self):
+ if self is Processor.BlockX:
+ return "block_x"
+ if self is Processor.BlockY:
+ return "block_y"
+ if self is Processor.BlockZ:
+ return "block_z"
+ if self is Processor.ThreadX:
+ return "thread_x"
+ if self is Processor.ThreadY:
+ return "thread_y"
+ if self is Processor.ThreadZ:
+ return "thread_z"
+ if self is Processor.Sequential:
+ return "sequential"
+ raise ValueError("Unknown Processor enum entry.")
+
+
+@register_attribute_builder("ProcessorEnum")
+def _processorenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/index.py b/mlir/python/mlir/dialects/index.py
index 73708c7d71a8c8..91bc3b0c91b137 100644
--- a/mlir/python/mlir/dialects/index.py
+++ b/mlir/python/mlir/dialects/index.py
@@ -1,6 +1,51 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._index_ops_gen import *
from ._index_enum_gen import *
+
+
+class IndexCmpPredicate(IntEnum):
+ """index comparison predicate kind"""
+
+ EQ = 0
+ NE = 1
+ SLT = 2
+ SLE = 3
+ SGT = 4
+ SGE = 5
+ ULT = 6
+ ULE = 7
+ UGT = 8
+ UGE = 9
+
+ def __str__(self):
+ if self is IndexCmpPredicate.EQ:
+ return "eq"
+ if self is IndexCmpPredicate.NE:
+ return "ne"
+ if self is IndexCmpPredicate.SLT:
+ return "slt"
+ if self is IndexCmpPredicate.SLE:
+ return "sle"
+ if self is IndexCmpPredicate.SGT:
+ return "sgt"
+ if self is IndexCmpPredicate.SGE:
+ return "sge"
+ if self is IndexCmpPredicate.ULT:
+ return "ult"
+ if self is IndexCmpPredicate.ULE:
+ return "ule"
+ if self is IndexCmpPredicate.UGT:
+ return "ugt"
+ if self is IndexCmpPredicate.UGE:
+ return "uge"
+ raise ValueError("Unknown IndexCmpPredicate enum entry.")
+
+
+@register_attribute_builder("IndexCmpPredicate")
+def _indexcmppredicate(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..88e688535762fb 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -1,6 +1,7 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
# Re-export the objects provided by pybind.
from ..._mlir_libs._mlirDialectsLinalg import *
@@ -102,3 +103,152 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op
+
+
+class BinaryFn(IntEnum):
+ """allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9"""
+
+ add = 0
+ sub = 1
+ mul = 2
+ div = 3
+ div_unsigned = 4
+ max_signed = 5
+ min_signed = 6
+ max_unsigned = 7
+ min_unsigned = 8
+ powf = 9
+
+ def __str__(self):
+ if self is BinaryFn.add:
+ return "add"
+ if self is BinaryFn.sub:
+ return "sub"
+ if self is BinaryFn.mul:
+ return "mul"
+ if self is BinaryFn.div:
+ return "div"
+ if self is BinaryFn.div_unsigned:
+ return "div_unsigned"
+ if self is BinaryFn.max_signed:
+ return "max_signed"
+ if self is BinaryFn.min_signed:
+ return "min_signed"
+ if self is BinaryFn.max_unsigned:
+ return "max_unsigned"
+ if self is BinaryFn.min_unsigned:
+ return "min_unsigned"
+ if self is BinaryFn.powf:
+ return "powf"
+ raise ValueError("Unknown BinaryFn enum entry.")
+
+
+@register_attribute_builder("BinaryFn")
+def _binaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class IteratorType(IntEnum):
+ """Iterator type"""
+
+ parallel = 0
+ reduction = 1
+
+ def __str__(self):
+ if self is IteratorType.parallel:
+ return "parallel"
+ if self is IteratorType.reduction:
+ return "reduction"
+ raise ValueError("Unknown IteratorType enum entry.")
+
+
+@register_attribute_builder("IteratorType")
+def _iteratortype(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class TernaryFn(IntEnum):
+ """allowed 32-bit signless integer cases: 0"""
+
+ select = 0
+
+ def __str__(self):
+ if self is TernaryFn.select:
+ return "select"
+ raise ValueError("Unknown TernaryFn enum entry.")
+
+
+@register_attribute_builder("TernaryFn")
+def _ternaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class TypeFn(IntEnum):
+ """allowed 32-bit signless integer cases: 0, 1"""
+
+ cast_signed = 0
+ cast_unsigned = 1
+
+ def __str__(self):
+ if self is TypeFn.cast_signed:
+ return "cast_signed"
+ if self is TypeFn.cast_unsigned:
+ return "cast_unsigned"
+ raise ValueError("Unknown TypeFn enum entry.")
+
+
+@register_attribute_builder("TypeFn")
+def _typefn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class UnaryFn(IntEnum):
+ """allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12"""
+
+ exp = 0
+ log = 1
+ abs = 2
+ ceil = 3
+ floor = 4
+ negf = 5
+ reciprocal = 6
+ round = 7
+ sqrt = 8
+ rsqrt = 9
+ square = 10
+ tanh = 11
+ erf = 12
+
+ def __str__(self):
+ if self is UnaryFn.exp:
+ return "exp"
+ if self is UnaryFn.log:
+ return "log"
+ if self is UnaryFn.abs:
+ return "abs"
+ if self is UnaryFn.ceil:
+ return "ceil"
+ if self is UnaryFn.floor:
+ return "floor"
+ if self is UnaryFn.negf:
+ return "negf"
+ if self is UnaryFn.reciprocal:
+ return "reciprocal"
+ if self is UnaryFn.round:
+ return "round"
+ if self is UnaryFn.sqrt:
+ return "sqrt"
+ if self is UnaryFn.rsqrt:
+ return "rsqrt"
+ if self is UnaryFn.square:
+ return "square"
+ if self is UnaryFn.tanh:
+ return "tanh"
+ if self is UnaryFn.erf:
+ return "erf"
+ raise ValueError("Unknown UnaryFn enum entry.")
+
+
+@register_attribute_builder("UnaryFn")
+def _unaryfn(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 941a584966dcde..2a763c22f7c334 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -1,11 +1,12 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum, IntFlag, auto
from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
from .._mlir_libs._mlirDialectsLLVM import *
-from ..ir import Value
+from ..ir import Value, IntegerAttr, IntegerType, register_attribute_builder
from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
@@ -13,3 +14,938 @@ def mlir_constant(value, *, loc=None, ip=None) -> Value:
return _get_op_result_or_op_results(
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
)
+
+
+class AsmDialect(IntEnum):
+    """ATT (0) or Intel (1) asm dialect"""
+
+    AD_ATT = 0
+    AD_Intel = 1
+
+    def __str__(self):
+        if self is AsmDialect.AD_ATT:
+            return "att"
+        if self is AsmDialect.AD_Intel:
+            return "intel"
+        raise ValueError("Unknown AsmDialect enum entry.")
+
+
+@register_attribute_builder("AsmATTOrIntel")
+def _asmattorintel(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class AtomicBinOp(IntEnum):
+    """llvm.atomicrmw binary operations"""
+
+    xchg = 0
+    add = 1
+    sub = 2
+    _and = 3
+    nand = 4
+    _or = 5
+    _xor = 6
+    max = 7
+    min = 8
+    umax = 9
+    umin = 10
+    fadd = 11
+    fsub = 12
+    fmax = 13
+    fmin = 14
+    uinc_wrap = 15
+    udec_wrap = 16
+    usub_cond = 17
+    usub_sat = 18
+
+    def __str__(self):
+        if self is AtomicBinOp.xchg:
+            return "xchg"
+        if self is AtomicBinOp.add:
+            return "add"
+        if self is AtomicBinOp.sub:
+            return "sub"
+        if self is AtomicBinOp._and:
+            return "_and"
+        if self is AtomicBinOp.nand:
+            return "nand"
+        if self is AtomicBinOp._or:
+            return "_or"
+        if self is AtomicBinOp._xor:
+            return "_xor"
+        if self is AtomicBinOp.max:
+            return "max"
+        if self is AtomicBinOp.min:
+            return "min"
+        if self is AtomicBinOp.umax:
+            return "umax"
+        if self is AtomicBinOp.umin:
+            return "umin"
+        if self is AtomicBinOp.fadd:
+            return "fadd"
+        if self is AtomicBinOp.fsub:
+            return "fsub"
+        if self is AtomicBinOp.fmax:
+            return "fmax"
+        if self is AtomicBinOp.fmin:
+            return "fmin"
+        if self is AtomicBinOp.uinc_wrap:
+            return "uinc_wrap"
+        if self is AtomicBinOp.udec_wrap:
+            return "udec_wrap"
+        if self is AtomicBinOp.usub_cond:
+            return "usub_cond"
+        if self is AtomicBinOp.usub_sat:
+            return "usub_sat"
+        raise ValueError("Unknown AtomicBinOp enum entry.")
+
+
+@register_attribute_builder("AtomicBinOp")
+def _atomicbinop(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class AtomicOrdering(IntEnum):
+    """Atomic ordering for LLVM's memory model"""
+
+    not_atomic = 0
+    unordered = 1
+    monotonic = 2
+    acquire = 4
+    release = 5
+    acq_rel = 6
+    seq_cst = 7
+
+    def __str__(self):
+        if self is AtomicOrdering.not_atomic:
+            return "not_atomic"
+        if self is AtomicOrdering.unordered:
+            return "unordered"
+        if self is AtomicOrdering.monotonic:
+            return "monotonic"
+        if self is AtomicOrdering.acquire:
+            return "acquire"
+        if self is AtomicOrdering.release:
+            return "release"
+        if self is AtomicOrdering.acq_rel:
+            return "acq_rel"
+        if self is AtomicOrdering.seq_cst:
+            return "seq_cst"
+        raise ValueError("Unknown AtomicOrdering enum entry.")
+
+
+@register_attribute_builder("AtomicOrdering")
+def _atomicordering(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class CConv(IntEnum):
+    """Calling Conventions"""
+
+    C = 0
+    Fast = 8
+    Cold = 9
+    GHC = 10
+    HiPE = 11
+    AnyReg = 13
+    PreserveMost = 14
+    PreserveAll = 15
+    Swift = 16
+    CXX_FAST_TLS = 17
+    Tail = 18
+    CFGuard_Check = 19
+    SwiftTail = 20
+    X86_StdCall = 64
+    X86_FastCall = 65
+    ARM_APCS = 66
+    ARM_AAPCS = 67
+    ARM_AAPCS_VFP = 68
+    MSP430_INTR = 69
+    X86_ThisCall = 70
+    PTX_Kernel = 71
+    PTX_Device = 72
+    SPIR_FUNC = 75
+    SPIR_KERNEL = 76
+    Intel_OCL_BI = 77
+    X86_64_SysV = 78
+    Win64 = 79
+    X86_VectorCall = 80
+    DUMMY_HHVM = 81
+    DUMMY_HHVM_C = 82
+    X86_INTR = 83
+    AVR_INTR = 84
+    AVR_BUILTIN = 86
+    AMDGPU_VS = 87
+    AMDGPU_GS = 88
+    AMDGPU_CS = 90
+    AMDGPU_KERNEL = 91
+    X86_RegCall = 92
+    AMDGPU_HS = 93
+    MSP430_BUILTIN = 94
+    AMDGPU_LS = 95
+    AMDGPU_ES = 96
+    AArch64_VectorCall = 97
+    AArch64_SVE_VectorCall = 98
+    WASM_EmscriptenInvoke = 99
+    AMDGPU_Gfx = 100
+    M68k_INTR = 101
+
+    def __str__(self):
+        if self is CConv.C:
+            return "ccc"
+        if self is CConv.Fast:
+            return "fastcc"
+        if self is CConv.Cold:
+            return "coldcc"
+        if self is CConv.GHC:
+            return "cc_10"
+        if self is CConv.HiPE:
+            return "cc_11"
+        if self is CConv.AnyReg:
+            return "anyregcc"
+        if self is CConv.PreserveMost:
+            return "preserve_mostcc"
+        if self is CConv.PreserveAll:
+            return "preserve_allcc"
+        if self is CConv.Swift:
+            return "swiftcc"
+        if self is CConv.CXX_FAST_TLS:
+            return "cxx_fast_tlscc"
+        if self is CConv.Tail:
+            return "tailcc"
+        if self is CConv.CFGuard_Check:
+            return "cfguard_checkcc"
+        if self is CConv.SwiftTail:
+            return "swifttailcc"
+        if self is CConv.X86_StdCall:
+            return "x86_stdcallcc"
+        if self is CConv.X86_FastCall:
+            return "x86_fastcallcc"
+        if self is CConv.ARM_APCS:
+            return "arm_apcscc"
+        if self is CConv.ARM_AAPCS:
+            return "arm_aapcscc"
+        if self is CConv.ARM_AAPCS_VFP:
+            return "arm_aapcs_vfpcc"
+        if self is CConv.MSP430_INTR:
+            return "msp430_intrcc"
+        if self is CConv.X86_ThisCall:
+            return "x86_thiscallcc"
+        if self is CConv.PTX_Kernel:
+            return "ptx_kernelcc"
+        if self is CConv.PTX_Device:
+            return "ptx_devicecc"
+        if self is CConv.SPIR_FUNC:
+            return "spir_funccc"
+        if self is CConv.SPIR_KERNEL:
+            return "spir_kernelcc"
+        if self is CConv.Intel_OCL_BI:
+            return "intel_ocl_bicc"
+        if self is CConv.X86_64_SysV:
+            return "x86_64_sysvcc"
+        if self is CConv.Win64:
+            return "win64cc"
+        if self is CConv.X86_VectorCall:
+            return "x86_vectorcallcc"
+        if self is CConv.DUMMY_HHVM:
+            return "hhvmcc"
+        if self is CConv.DUMMY_HHVM_C:
+            return "hhvm_ccc"
+        if self is CConv.X86_INTR:
+            return "x86_intrcc"
+        if self is CConv.AVR_INTR:
+            return "avr_intrcc"
+        if self is CConv.AVR_BUILTIN:
+            return "avr_builtincc"
+        if self is CConv.AMDGPU_VS:
+            return "amdgpu_vscc"
+        if self is CConv.AMDGPU_GS:
+            return "amdgpu_gscc"
+        if self is CConv.AMDGPU_CS:
+            return "amdgpu_cscc"
+        if self is CConv.AMDGPU_KERNEL:
+            return "amdgpu_kernelcc"
+        if self is CConv.X86_RegCall:
+            return "x86_regcallcc"
+        if self is CConv.AMDGPU_HS:
+            return "amdgpu_hscc"
+        if self is CConv.MSP430_BUILTIN:
+            return "msp430_builtincc"
+        if self is CConv.AMDGPU_LS:
+            return "amdgpu_lscc"
+        if self is CConv.AMDGPU_ES:
+            return "amdgpu_escc"
+        if self is CConv.AArch64_VectorCall:
+            return "aarch64_vectorcallcc"
+        if self is CConv.AArch64_SVE_VectorCall:
+            return "aarch64_sve_vectorcallcc"
+        if self is CConv.WASM_EmscriptenInvoke:
+            return "wasm_emscripten_invokecc"
+        if self is CConv.AMDGPU_Gfx:
+            return "amdgpu_gfxcc"
+        if self is CConv.M68k_INTR:
+            return "m68k_intrcc"
+        raise ValueError("Unknown CConv enum entry.")
+
+
+@register_attribute_builder("CConvEnum")
+def _cconvenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class Comdat(IntEnum):
+    """LLVM Comdat Types"""
+
+    Any = 0
+    ExactMatch = 1
+    Largest = 2
+    NoDeduplicate = 3
+    SameSize = 4
+
+    def __str__(self):
+        if self is Comdat.Any:
+            return "any"
+        if self is Comdat.ExactMatch:
+            return "exactmatch"
+        if self is Comdat.Largest:
+            return "largest"
+        if self is Comdat.NoDeduplicate:
+            return "nodeduplicate"
+        if self is Comdat.SameSize:
+            return "samesize"
+        raise ValueError("Unknown Comdat enum entry.")
+
+
+@register_attribute_builder("Comdat")
+def _comdat(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class DIFlags(IntFlag):
+    """LLVM DI flags"""
+
+    Zero = 0
+    Bit0 = 1
+    Bit1 = 2
+    Private = 1
+    Protected = 2
+    Public = 3
+    FwdDecl = 4
+    AppleBlock = 8
+    ReservedBit4 = 16
+    Virtual = 32
+    Artificial = 64
+    Explicit = 128
+    Prototyped = 256
+    ObjcClassComplete = 512
+    ObjectPointer = 1024
+    Vector = 2048
+    StaticMember = 4096
+    LValueReference = 8192
+    RValueReference = 16384
+    ExportSymbols = 32768
+    SingleInheritance = 65536
+    MultipleInheritance = 131072  # 2 << 16; was 65536 (aliased Single)
+    VirtualInheritance = 196608  # 3 << 16 == Single | Multiple
+    IntroducedVirtual = 262144
+    BitField = 524288
+    NoReturn = 1048576
+    TypePassByValue = 4194304
+    TypePassByReference = 8388608
+    EnumClass = 16777216
+    Thunk = 33554432
+    NonTrivial = 67108864
+    BigEndian = 134217728
+    LittleEndian = 268435456
+    AllCallsDescribed = 536870912
+
+    def __iter__(self):
+        return iter([case for case in type(self) if case and (self & case) is case])
+
+    def __len__(self):
+        return bin(self).count("1")
+
+    def __str__(self):
+        if len(self) > 1:
+            return "|".join(map(str, self))
+        if self is DIFlags.Zero:
+            return "Zero"
+        if self is DIFlags.Bit0:
+            return "Bit0"
+        if self is DIFlags.Bit1:
+            return "Bit1"
+        if self is DIFlags.Private:
+            return "Private"
+        if self is DIFlags.Protected:
+            return "Protected"
+        if self is DIFlags.Public:
+            return "Public"
+        if self is DIFlags.FwdDecl:
+            return "FwdDecl"
+        if self is DIFlags.AppleBlock:
+            return "AppleBlock"
+        if self is DIFlags.ReservedBit4:
+            return "ReservedBit4"
+        if self is DIFlags.Virtual:
+            return "Virtual"
+        if self is DIFlags.Artificial:
+            return "Artificial"
+        if self is DIFlags.Explicit:
+            return "Explicit"
+        if self is DIFlags.Prototyped:
+            return "Prototyped"
+        if self is DIFlags.ObjcClassComplete:
+            return "ObjcClassComplete"
+        if self is DIFlags.ObjectPointer:
+            return "ObjectPointer"
+        if self is DIFlags.Vector:
+            return "Vector"
+        if self is DIFlags.StaticMember:
+            return "StaticMember"
+        if self is DIFlags.LValueReference:
+            return "LValueReference"
+        if self is DIFlags.RValueReference:
+            return "RValueReference"
+        if self is DIFlags.ExportSymbols:
+            return "ExportSymbols"
+        if self is DIFlags.SingleInheritance:
+            return "SingleInheritance"
+        if self is DIFlags.MultipleInheritance:
+            return "MultipleInheritance"
+        if self is DIFlags.VirtualInheritance:
+            return "VirtualInheritance"
+        if self is DIFlags.IntroducedVirtual:
+            return "IntroducedVirtual"
+        if self is DIFlags.BitField:
+            return "BitField"
+        if self is DIFlags.NoReturn:
+            return "NoReturn"
+        if self is DIFlags.TypePassByValue:
+            return "TypePassByValue"
+        if self is DIFlags.TypePassByReference:
+            return "TypePassByReference"
+        if self is DIFlags.EnumClass:
+            return "EnumClass"
+        if self is DIFlags.Thunk:
+            return "Thunk"
+        if self is DIFlags.NonTrivial:
+            return "NonTrivial"
+        if self is DIFlags.BigEndian:
+            return "BigEndian"
+        if self is DIFlags.LittleEndian:
+            return "LittleEndian"
+        if self is DIFlags.AllCallsDescribed:
+            return "AllCallsDescribed"
+        raise ValueError("Unknown DIFlags enum entry.")
+
+
+@register_attribute_builder("DIFlags")
+def _diflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class DISubprogramFlags(IntFlag):
+    """LLVM DISubprogram flags"""
+
+    Virtual = 1
+    PureVirtual = 2
+    LocalToUnit = 4
+    Definition = 8
+    Optimized = 16
+    Pure = 32
+    Elemental = 64
+    Recursive = 128
+    MainSubprogram = 256
+    Deleted = 512
+    ObjCDirect = 2048
+
+    def __iter__(self):
+        return iter([case for case in type(self) if (self & case) is case])
+
+    def __len__(self):
+        return bin(self).count("1")
+
+    def __str__(self):
+        if len(self) > 1:
+            return "|".join(map(str, self))
+        if self is DISubprogramFlags.Virtual:
+            return "Virtual"
+        if self is DISubprogramFlags.PureVirtual:
+            return "PureVirtual"
+        if self is DISubprogramFlags.LocalToUnit:
+            return "LocalToUnit"
+        if self is DISubprogramFlags.Definition:
+            return "Definition"
+        if self is DISubprogramFlags.Optimized:
+            return "Optimized"
+        if self is DISubprogramFlags.Pure:
+            return "Pure"
+        if self is DISubprogramFlags.Elemental:
+            return "Elemental"
+        if self is DISubprogramFlags.Recursive:
+            return "Recursive"
+        if self is DISubprogramFlags.MainSubprogram:
+            return "MainSubprogram"
+        if self is DISubprogramFlags.Deleted:
+            return "Deleted"
+        if self is DISubprogramFlags.ObjCDirect:
+            return "ObjCDirect"
+        raise ValueError("Unknown DISubprogramFlags enum entry.")
+
+
+@register_attribute_builder("DISubprogramFlags")
+def _disubprogramflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class FCmpPredicate(IntEnum):
+    """llvm.fcmp comparison predicate"""
+
+    _false = 0
+    oeq = 1
+    ogt = 2
+    oge = 3
+    olt = 4
+    ole = 5
+    one = 6
+    ord = 7
+    ueq = 8
+    ugt = 9
+    uge = 10
+    ult = 11
+    ule = 12
+    une = 13
+    uno = 14
+    _true = 15
+
+    def __str__(self):
+        if self is FCmpPredicate._false:
+            return "_false"
+        if self is FCmpPredicate.oeq:
+            return "oeq"
+        if self is FCmpPredicate.ogt:
+            return "ogt"
+        if self is FCmpPredicate.oge:
+            return "oge"
+        if self is FCmpPredicate.olt:
+            return "olt"
+        if self is FCmpPredicate.ole:
+            return "ole"
+        if self is FCmpPredicate.one:
+            return "one"
+        if self is FCmpPredicate.ord:
+            return "ord"
+        if self is FCmpPredicate.ueq:
+            return "ueq"
+        if self is FCmpPredicate.ugt:
+            return "ugt"
+        if self is FCmpPredicate.uge:
+            return "uge"
+        if self is FCmpPredicate.ult:
+            return "ult"
+        if self is FCmpPredicate.ule:
+            return "ule"
+        if self is FCmpPredicate.une:
+            return "une"
+        if self is FCmpPredicate.uno:
+            return "uno"
+        if self is FCmpPredicate._true:
+            return "_true"
+        raise ValueError("Unknown FCmpPredicate enum entry.")
+
+
+@register_attribute_builder("FCmpPredicate")
+def _fcmppredicate(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class FPExceptionBehavior(IntEnum):
+    """LLVM Exception Behavior"""
+
+    Ignore = 0
+    MayTrap = 1
+    Strict = 2
+
+    def __str__(self):
+        if self is FPExceptionBehavior.Ignore:
+            return "ignore"
+        if self is FPExceptionBehavior.MayTrap:
+            return "maytrap"
+        if self is FPExceptionBehavior.Strict:
+            return "strict"
+        raise ValueError("Unknown FPExceptionBehavior enum entry.")
+
+
+@register_attribute_builder("FPExceptionBehaviorAttr")
+def _fpexceptionbehaviorattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class FastmathFlags(IntFlag):
+    """LLVM fastmath flags"""
+
+    none = 0
+    nnan = 1
+    ninf = 2
+    nsz = 4
+    arcp = 8
+    contract = 16
+    afn = 32
+    reassoc = 64
+    fast = 127
+
+    def __iter__(self):
+        return iter([case for case in type(self) if case and (self & case) is case])
+
+    def __len__(self):
+        return bin(self).count("1")
+
+    def __str__(self):
+        if len(self) > 1:
+            return ", ".join(map(str, self))
+        if self is FastmathFlags.none:
+            return "none"
+        if self is FastmathFlags.nnan:
+            return "nnan"
+        if self is FastmathFlags.ninf:
+            return "ninf"
+        if self is FastmathFlags.nsz:
+            return "nsz"
+        if self is FastmathFlags.arcp:
+            return "arcp"
+        if self is FastmathFlags.contract:
+            return "contract"
+        if self is FastmathFlags.afn:
+            return "afn"
+        if self is FastmathFlags.reassoc:
+            return "reassoc"
+        if self is FastmathFlags.fast:
+            return "fast"
+        raise ValueError("Unknown FastmathFlags enum entry.")
+
+
+@register_attribute_builder("FastmathFlags")
+def _fastmathflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class FramePointerKind(IntEnum):
+    """LLVM FramePointerKind"""
+
+    None_ = 0
+    NonLeaf = 1
+    All = 2
+    Reserved = 3
+
+    def __str__(self):
+        if self is FramePointerKind.None_:
+            return "none"
+        if self is FramePointerKind.NonLeaf:
+            return "non-leaf"
+        if self is FramePointerKind.All:
+            return "all"
+        if self is FramePointerKind.Reserved:
+            return "reserved"
+        raise ValueError("Unknown FramePointerKind enum entry.")
+
+
+@register_attribute_builder("FramePointerKindEnum")
+def _framepointerkindenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class ICmpPredicate(IntEnum):
+    """llvm.icmp comparison predicate"""
+
+    eq = 0
+    ne = 1
+    slt = 2
+    sle = 3
+    sgt = 4
+    sge = 5
+    ult = 6
+    ule = 7
+    ugt = 8
+    uge = 9
+
+    def __str__(self):
+        if self is ICmpPredicate.eq:
+            return "eq"
+        if self is ICmpPredicate.ne:
+            return "ne"
+        if self is ICmpPredicate.slt:
+            return "slt"
+        if self is ICmpPredicate.sle:
+            return "sle"
+        if self is ICmpPredicate.sgt:
+            return "sgt"
+        if self is ICmpPredicate.sge:
+            return "sge"
+        if self is ICmpPredicate.ult:
+            return "ult"
+        if self is ICmpPredicate.ule:
+            return "ule"
+        if self is ICmpPredicate.ugt:
+            return "ugt"
+        if self is ICmpPredicate.uge:
+            return "uge"
+        raise ValueError("Unknown ICmpPredicate enum entry.")
+
+
+@register_attribute_builder("ICmpPredicate")
+def _icmppredicate(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class IntegerOverflowFlags(IntFlag):
+    """LLVM integer overflow flags"""
+
+    none = 0
+    nsw = 1
+    nuw = 2
+
+    def __iter__(self):
+        return iter([case for case in type(self) if case and (self & case) is case])
+
+    def __len__(self):
+        return bin(self).count("1")
+
+    def __str__(self):
+        if len(self) > 1:
+            return ", ".join(map(str, self))
+        if self is IntegerOverflowFlags.none:
+            return "none"
+        if self is IntegerOverflowFlags.nsw:
+            return "nsw"
+        if self is IntegerOverflowFlags.nuw:
+            return "nuw"
+        raise ValueError("Unknown IntegerOverflowFlags enum entry.")
+
+
+@register_attribute_builder("IntegerOverflowFlags")
+def _integeroverflowflags(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class DIEmissionKind(IntEnum):
+    """LLVM debug emission kind"""
+
+    None_ = 0
+    Full = 1
+    LineTablesOnly = 2
+    DebugDirectivesOnly = 3
+
+    def __str__(self):
+        if self is DIEmissionKind.None_:
+            return "None"
+        if self is DIEmissionKind.Full:
+            return "Full"
+        if self is DIEmissionKind.LineTablesOnly:
+            return "LineTablesOnly"
+        if self is DIEmissionKind.DebugDirectivesOnly:
+            return "DebugDirectivesOnly"
+        raise ValueError("Unknown DIEmissionKind enum entry.")
+
+
+@register_attribute_builder("LLVM_DIEmissionKind")
+def _llvm_diemissionkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class DINameTableKind(IntEnum):
+    """LLVM debug name table kind"""
+
+    Default = 0
+    GNU = 1
+    None_ = 2
+    Apple = 3
+
+    def __str__(self):
+        if self is DINameTableKind.Default:
+            return "Default"
+        if self is DINameTableKind.GNU:
+            return "GNU"
+        if self is DINameTableKind.None_:
+            return "None"
+        if self is DINameTableKind.Apple:
+            return "Apple"
+        raise ValueError("Unknown DINameTableKind enum entry.")
+
+
+@register_attribute_builder("LLVM_DINameTableKind")
+def _llvm_dinametablekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class Linkage(IntEnum):
+    """LLVM linkage types"""
+
+    Private = 0
+    Internal = 1
+    AvailableExternally = 2
+    Linkonce = 3
+    Weak = 4
+    Common = 5
+    Appending = 6
+    ExternWeak = 7
+    LinkonceODR = 8
+    WeakODR = 9
+    External = 10
+
+    def __str__(self):
+        if self is Linkage.Private:
+            return "private"
+        if self is Linkage.Internal:
+            return "internal"
+        if self is Linkage.AvailableExternally:
+            return "available_externally"
+        if self is Linkage.Linkonce:
+            return "linkonce"
+        if self is Linkage.Weak:
+            return "weak"
+        if self is Linkage.Common:
+            return "common"
+        if self is Linkage.Appending:
+            return "appending"
+        if self is Linkage.ExternWeak:
+            return "extern_weak"
+        if self is Linkage.LinkonceODR:
+            return "linkonce_odr"
+        if self is Linkage.WeakODR:
+            return "weak_odr"
+        if self is Linkage.External:
+            return "external"
+        raise ValueError("Unknown Linkage enum entry.")
+
+
+@register_attribute_builder("LinkageEnum")
+def _linkageenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class ModRefInfo(IntEnum):
+    """LLVM ModRefInfo"""
+
+    NoModRef = 0
+    Ref = 1
+    Mod = 2
+    ModRef = 3
+
+    def __str__(self):
+        if self is ModRefInfo.NoModRef:
+            return "none"
+        if self is ModRefInfo.Ref:
+            return "read"
+        if self is ModRefInfo.Mod:
+            return "write"
+        if self is ModRefInfo.ModRef:
+            return "readwrite"
+        raise ValueError("Unknown ModRefInfo enum entry.")
+
+
+@register_attribute_builder("ModRefInfoEnum")
+def _modrefinfoenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class RoundingMode(IntEnum):
+    """LLVM Rounding Mode"""
+
+    TowardZero = 0
+    NearestTiesToEven = 1
+    TowardPositive = 2
+    TowardNegative = 3
+    NearestTiesToAway = 4
+    Dynamic = 7
+    Invalid = -1  # llvm::RoundingMode::Invalid; auto() would yield 8
+
+    def __str__(self):
+        if self is RoundingMode.TowardZero:
+            return "towardzero"
+        if self is RoundingMode.NearestTiesToEven:
+            return "tonearest"
+        if self is RoundingMode.TowardPositive:
+            return "upward"
+        if self is RoundingMode.TowardNegative:
+            return "downward"
+        if self is RoundingMode.NearestTiesToAway:
+            return "tonearestaway"
+        if self is RoundingMode.Dynamic:
+            return "dynamic"
+        if self is RoundingMode.Invalid:
+            return "invalid"
+        raise ValueError("Unknown RoundingMode enum entry.")
+
+
+@register_attribute_builder("RoundingModeAttr")
+def _roundingmodeattr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class TailCallKind(IntEnum):
+    """Tail Call Kind"""
+
+    None_ = 0
+    NoTail = 3
+    MustTail = 2
+    Tail = 1
+
+    def __str__(self):
+        if self is TailCallKind.None_:
+            return "none"
+        if self is TailCallKind.NoTail:
+            return "notail"
+        if self is TailCallKind.MustTail:
+            return "musttail"
+        if self is TailCallKind.Tail:
+            return "tail"
+        raise ValueError("Unknown TailCallKind enum entry.")
+
+
+@register_attribute_builder("TailCallKindEnum")
+def _tailcallkindenum(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class UnnamedAddr(IntEnum):
+    """LLVM GlobalValue UnnamedAddr"""
+
+    None_ = 0
+    Local = 1
+    Global = 2
+
+    def __str__(self):
+        if self is UnnamedAddr.None_:
+            return ""
+        if self is UnnamedAddr.Local:
+            return "local_unnamed_addr"
+        if self is UnnamedAddr.Global:
+            return "unnamed_addr"
+        raise ValueError("Unknown UnnamedAddr enum entry.")
+
+
+@register_attribute_builder("UnnamedAddr")
+def _unnamedaddr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
+
+
+class Visibility(IntEnum):
+    """LLVM GlobalValue Visibility"""
+
+    Default = 0
+    Hidden = 1
+    Protected = 2
+
+    def __str__(self):
+        if self is Visibility.Default:
+            return ""
+        if self is Visibility.Hidden:
+            return "hidden"
+        if self is Visibility.Protected:
+            return "protected"
+        raise ValueError("Unknown Visibility enum entry.")
+
+
+@register_attribute_builder("Visibility")
+def _visibility(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(64, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py
index d6a54f2772f40d..3b5c1c9f61c09a 100644
--- a/mlir/python/mlir/dialects/nvgpu.py
+++ b/mlir/python/mlir/dialects/nvgpu.py
@@ -1,7 +1,128 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._nvgpu_ops_gen import *
from ._nvgpu_enum_gen import *
from .._mlir_libs._mlirDialectsNVGPU import *
+
+
+class RcpRoundingMode(IntEnum):
+    """Rounding mode of rcp"""
+
+    APPROX = 0
+    RN = 1
+    RZ = 2
+    RM = 3
+    RP = 4
+
+    def __str__(self):
+        if self is RcpRoundingMode.APPROX:
+            return "approx"
+        if self is RcpRoundingMode.RN:
+            return "rn"
+        if self is RcpRoundingMode.RZ:
+            return "rz"
+        if self is RcpRoundingMode.RM:
+            return "rm"
+        if self is RcpRoundingMode.RP:
+            return "rp"
+        raise ValueError("Unknown RcpRoundingMode enum entry.")
+
+
+@register_attribute_builder("RcpRoundingMode")
+def _rcproundingmode(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class TensorMapInterleaveKind(IntEnum):
+    """Tensor map interleave layout type"""
+
+    INTERLEAVE_NONE = 0
+    INTERLEAVE_16B = 1
+    INTERLEAVE_32B = 2
+
+    def __str__(self):
+        if self is TensorMapInterleaveKind.INTERLEAVE_NONE:
+            return "none"
+        if self is TensorMapInterleaveKind.INTERLEAVE_16B:
+            return "interleave_16b"
+        if self is TensorMapInterleaveKind.INTERLEAVE_32B:
+            return "interleave_32b"
+        raise ValueError("Unknown TensorMapInterleaveKind enum entry.")
+
+
+@register_attribute_builder("TensorMapInterleaveKind")
+def _tensormapinterleavekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class TensorMapL2PromoKind(IntEnum):
+    """Tensor map L2 promotion type"""
+
+    L2PROMO_NONE = 0
+    L2PROMO_64B = 1
+    L2PROMO_128B = 2
+    L2PROMO_256B = 3
+
+    def __str__(self):
+        if self is TensorMapL2PromoKind.L2PROMO_NONE:
+            return "none"
+        if self is TensorMapL2PromoKind.L2PROMO_64B:
+            return "l2promo_64b"
+        if self is TensorMapL2PromoKind.L2PROMO_128B:
+            return "l2promo_128b"
+        if self is TensorMapL2PromoKind.L2PROMO_256B:
+            return "l2promo_256b"
+        raise ValueError("Unknown TensorMapL2PromoKind enum entry.")
+
+
+@register_attribute_builder("TensorMapL2PromoKind")
+def _tensormapl2promokind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class TensorMapOOBKind(IntEnum):
+    """Tensor map out-of-bounds fill type"""
+
+    OOB_ZERO = 0
+    OOB_NAN = 1
+
+    def __str__(self):
+        if self is TensorMapOOBKind.OOB_ZERO:
+            return "zero"
+        if self is TensorMapOOBKind.OOB_NAN:
+            return "nan"
+        raise ValueError("Unknown TensorMapOOBKind enum entry.")
+
+
+@register_attribute_builder("TensorMapOOBKind")
+def _tensormapoobkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class TensorMapSwizzleKind(IntEnum):
+    """Tensor map swizzling mode of shared memory banks"""
+
+    SWIZZLE_NONE = 0
+    SWIZZLE_32B = 1
+    SWIZZLE_64B = 2
+    SWIZZLE_128B = 3
+
+    def __str__(self):
+        if self is TensorMapSwizzleKind.SWIZZLE_NONE:
+            return "none"
+        if self is TensorMapSwizzleKind.SWIZZLE_32B:
+            return "swizzle_32b"
+        if self is TensorMapSwizzleKind.SWIZZLE_64B:
+            return "swizzle_64b"
+        if self is TensorMapSwizzleKind.SWIZZLE_128B:
+            return "swizzle_128b"
+        raise ValueError("Unknown TensorMapSwizzleKind enum entry.")
+
+
+@register_attribute_builder("TensorMapSwizzleKind")
+def _tensormapswizzlekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py
index 9477de39c9ead7..640d82ae9d753b 100644
--- a/mlir/python/mlir/dialects/nvvm.py
+++ b/mlir/python/mlir/dialects/nvvm.py
@@ -1,6 +1,401 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum, auto
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._nvvm_ops_gen import *
from ._nvvm_enum_gen import *
+
+
+class LoadCacheModifierKind(IntEnum):
+    """NVVM load cache modifier kind"""
+
+    CA = 0
+    CG = 1
+    CS = 2
+    LU = 3
+    CV = 4
+
+    def __str__(self):
+        if self is LoadCacheModifierKind.CA:
+            return "ca"
+        if self is LoadCacheModifierKind.CG:
+            return "cg"
+        if self is LoadCacheModifierKind.CS:
+            return "cs"
+        if self is LoadCacheModifierKind.LU:
+            return "lu"
+        if self is LoadCacheModifierKind.CV:
+            return "cv"
+        raise ValueError("Unknown LoadCacheModifierKind enum entry.")
+
+
+@register_attribute_builder("LoadCacheModifierKind")
+def _loadcachemodifierkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MMAB1Op(IntEnum):
+    """MMA binary operations"""
+
+    none = 0
+    xor_popc = 1
+    and_popc = 2
+
+    def __str__(self):
+        if self is MMAB1Op.none:
+            return "none"
+        if self is MMAB1Op.xor_popc:
+            return "xor_popc"
+        if self is MMAB1Op.and_popc:
+            return "and_popc"
+        raise ValueError("Unknown MMAB1Op enum entry.")
+
+
+@register_attribute_builder("MMAB1Op")
+def _mmab1op(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MMAFrag(IntEnum):
+    """NVVM MMA frag type"""
+
+    a = 0
+    b = 1
+    c = 2
+
+    def __str__(self):
+        if self is MMAFrag.a:
+            return "a"
+        if self is MMAFrag.b:
+            return "b"
+        if self is MMAFrag.c:
+            return "c"
+        raise ValueError("Unknown MMAFrag enum entry.")
+
+
+@register_attribute_builder("MMAFrag")
+def _mmafrag(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MMAIntOverflow(IntEnum):
+    """MMA overflow options"""
+
+    satfinite = 1
+    wrapped = 0
+
+    def __str__(self):
+        if self is MMAIntOverflow.satfinite:
+            return "satfinite"
+        if self is MMAIntOverflow.wrapped:
+            return "wrapped"
+        raise ValueError("Unknown MMAIntOverflow enum entry.")
+
+
+@register_attribute_builder("MMAIntOverflow")
+def _mmaintoverflow(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MMALayout(IntEnum):
+    """NVVM MMA layout"""
+
+    row = 0
+    col = 1
+
+    def __str__(self):
+        if self is MMALayout.row:
+            return "row"
+        if self is MMALayout.col:
+            return "col"
+        raise ValueError("Unknown MMALayout enum entry.")
+
+
+@register_attribute_builder("MMALayout")
+def _mmalayout(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MMATypes(IntEnum):
+    """NVVM MMA types"""
+
+    f16 = 0
+    f32 = 1
+    tf32 = 2
+    bf16 = 9
+    s8 = 4
+    u8 = 3
+    s32 = 5
+    s4 = 8
+    u4 = 7
+    b1 = 6
+    f64 = 10
+
+    def __str__(self):
+        if self is MMATypes.f16:
+            return "f16"
+        if self is MMATypes.f32:
+            return "f32"
+        if self is MMATypes.tf32:
+            return "tf32"
+        if self is MMATypes.bf16:
+            return "bf16"
+        if self is MMATypes.s8:
+            return "s8"
+        if self is MMATypes.u8:
+            return "u8"
+        if self is MMATypes.s32:
+            return "s32"
+        if self is MMATypes.s4:
+            return "s4"
+        if self is MMATypes.u4:
+            return "u4"
+        if self is MMATypes.b1:
+            return "b1"
+        if self is MMATypes.f64:
+            return "f64"
+        raise ValueError("Unknown MMATypes enum entry.")
+
+
+@register_attribute_builder("MMATypes")
+def _mmatypes(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MemScopeKind(IntEnum):
+    """NVVM Memory Scope kind"""
+
+    CTA = 0
+    CLUSTER = 1
+    GPU = 2
+    SYS = 3
+
+    def __str__(self):
+        if self is MemScopeKind.CTA:
+            return "cta"
+        if self is MemScopeKind.CLUSTER:
+            return "cluster"
+        if self is MemScopeKind.GPU:
+            return "gpu"
+        if self is MemScopeKind.SYS:
+            return "sys"
+        raise ValueError("Unknown MemScopeKind enum entry.")
+
+
+@register_attribute_builder("MemScopeKind")
+def _memscopekind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class ProxyKind(IntEnum):
+    """Proxy kind"""
+
+    alias = 0
+    async_ = 1
+    async_global = 2
+    async_shared = 3
+    TENSORMAP = 4
+    GENERIC = 5
+
+    def __str__(self):
+        if self is ProxyKind.alias:
+            return "alias"
+        if self is ProxyKind.async_:
+            return "async"
+        if self is ProxyKind.async_global:
+            return "async.global"
+        if self is ProxyKind.async_shared:
+            return "async.shared"
+        if self is ProxyKind.TENSORMAP:
+            return "tensormap"
+        if self is ProxyKind.GENERIC:
+            return "generic"
+        raise ValueError("Unknown ProxyKind enum entry.")
+
+
+@register_attribute_builder("ProxyKind")
+def _proxykind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class ReduxKind(IntEnum):
+    """NVVM redux kind"""
+
+    ADD = 1
+    AND = 2
+    MAX = 3
+    MIN = 4
+    OR = 5
+    UMAX = 6
+    UMIN = 7
+    XOR = 8
+
+    def __str__(self):
+        if self is ReduxKind.ADD:
+            return "add"
+        if self is ReduxKind.AND:
+            return "and"
+        if self is ReduxKind.MAX:
+            return "max"
+        if self is ReduxKind.MIN:
+            return "min"
+        if self is ReduxKind.OR:
+            return "or"
+        if self is ReduxKind.UMAX:
+            return "umax"
+        if self is ReduxKind.UMIN:
+            return "umin"
+        if self is ReduxKind.XOR:
+            return "xor"
+        raise ValueError("Unknown ReduxKind enum entry.")
+
+
+@register_attribute_builder("ReduxKind")
+def _reduxkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class SetMaxRegisterAction(IntEnum):
+    """NVVM set max register action"""
+
+    decrease = 1
+    increase = 0
+
+    def __str__(self):
+        if self is SetMaxRegisterAction.decrease:
+            return "decrease"
+        if self is SetMaxRegisterAction.increase:
+            return "increase"
+        raise ValueError("Unknown SetMaxRegisterAction enum entry.")
+
+
+@register_attribute_builder("SetMaxRegisterAction")
+def _setmaxregisteraction(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class SharedSpace(IntEnum):
+    """Shared memory space"""
+
+    shared_cta = 0
+    shared_cluster = 1
+
+    def __str__(self):
+        if self is SharedSpace.shared_cta:
+            return "cta"
+        if self is SharedSpace.shared_cluster:
+            return "cluster"
+        raise ValueError("Unknown SharedSpace enum entry.")
+
+
+@register_attribute_builder("SharedSpace")
+def _sharedspace(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class ShflKind(IntEnum):
+    """NVVM shuffle kind"""
+
+    bfly = 0
+    up = 1
+    down = 2
+    idx = 3
+
+    def __str__(self):
+        if self is ShflKind.bfly:
+            return "bfly"
+        if self is ShflKind.up:
+            return "up"
+        if self is ShflKind.down:
+            return "down"
+        if self is ShflKind.idx:
+            return "idx"
+        raise ValueError("Unknown ShflKind enum entry.")
+
+
+@register_attribute_builder("ShflKind")
+def _shflkind(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class WGMMAScaleIn(IntEnum):
+    """WGMMA input scale: one (1) or neg (-1)"""
+
+    one = 1
+    neg = -1  # negated input scale; auto() would yield 2
+
+    def __str__(self):
+        if self is WGMMAScaleIn.one:
+            return "one"
+        if self is WGMMAScaleIn.neg:
+            return "neg"
+        raise ValueError("Unknown WGMMAScaleIn enum entry.")
+
+
+@register_attribute_builder("WGMMAScaleIn")
+def _wgmmascalein(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class WGMMAScaleOut(IntEnum):
+    """WGMMA output scale: zero (0) or one (1)"""
+
+    zero = 0
+    one = 1
+
+    def __str__(self):
+        if self is WGMMAScaleOut.zero:
+            return "zero"
+        if self is WGMMAScaleOut.one:
+            return "one"
+        raise ValueError("Unknown WGMMAScaleOut enum entry.")
+
+
+@register_attribute_builder("WGMMAScaleOut")
+def _wgmmascaleout(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class WGMMATypes(IntEnum):
+    """NVVM WGMMA types"""
+
+    f16 = 0
+    tf32 = 1
+    u8 = 2
+    s8 = 3
+    b1 = 4
+    bf16 = 5
+    e4m3 = 6
+    e5m2 = 7
+    f32 = 8
+    s32 = 9
+
+    def __str__(self):
+        if self is WGMMATypes.f16:
+            return "f16"
+        if self is WGMMATypes.tf32:
+            return "tf32"
+        if self is WGMMATypes.u8:
+            return "u8"
+        if self is WGMMATypes.s8:
+            return "s8"
+        if self is WGMMATypes.b1:
+            return "b1"
+        if self is WGMMATypes.bf16:
+            return "bf16"
+        if self is WGMMATypes.e4m3:
+            return "e4m3"
+        if self is WGMMATypes.e5m2:
+            return "e5m2"
+        if self is WGMMATypes.f32:
+            return "f32"
+        if self is WGMMATypes.s32:
+            return "s32"
+        raise ValueError("Unknown WGMMATypes enum entry.")
+
+
+@register_attribute_builder("WGMMATypes")
+def _wgmmatypes(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py
index 209ecc95fa8fc8..6a92e4e08cc52b 100644
--- a/mlir/python/mlir/dialects/sparse_tensor.py
+++ b/mlir/python/mlir/dialects/sparse_tensor.py
@@ -1,8 +1,84 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._sparse_tensor_ops_gen import *
from ._sparse_tensor_enum_gen import *
from .._mlir_libs._mlirDialectsSparseTensor import *
-from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses
+
+
+class CrdTransDirectionKind(IntEnum):
+ """sparse tensor coordinate translation direction"""
+
+ dim2lvl = 0
+ lvl2dim = 1
+
+ def __str__(self):
+ if self is CrdTransDirectionKind.dim2lvl:
+ return "dim_to_lvl"
+ if self is CrdTransDirectionKind.lvl2dim:
+ return "lvl_to_dim"
+ raise ValueError("Unknown CrdTransDirectionKind enum entry.")
+
+
+@register_attribute_builder("SparseTensorCrdTransDirectionEnum")
+def _sparsetensorcrdtransdirectionenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class SparseTensorSortKind(IntEnum):
+ """sparse tensor sort algorithm"""
+
+ HybridQuickSort = 0
+ InsertionSortStable = 1
+ QuickSort = 2
+ HeapSort = 3
+
+ def __str__(self):
+ if self is SparseTensorSortKind.HybridQuickSort:
+ return "hybrid_quick_sort"
+ if self is SparseTensorSortKind.InsertionSortStable:
+ return "insertion_sort_stable"
+ if self is SparseTensorSortKind.QuickSort:
+ return "quick_sort"
+ if self is SparseTensorSortKind.HeapSort:
+ return "heap_sort"
+ raise ValueError("Unknown SparseTensorSortKind enum entry.")
+
+
+@register_attribute_builder("SparseTensorSortKindEnum")
+def _sparsetensorsortkindenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class StorageSpecifierKind(IntEnum):
+ """sparse tensor storage specifier kind"""
+
+ LvlSize = 0
+ PosMemSize = 1
+ CrdMemSize = 2
+ ValMemSize = 3
+ DimOffset = 4
+ DimStride = 5
+
+ def __str__(self):
+ if self is StorageSpecifierKind.LvlSize:
+ return "lvl_sz"
+ if self is StorageSpecifierKind.PosMemSize:
+ return "pos_mem_sz"
+ if self is StorageSpecifierKind.CrdMemSize:
+ return "crd_mem_sz"
+ if self is StorageSpecifierKind.ValMemSize:
+ return "val_mem_sz"
+ if self is StorageSpecifierKind.DimOffset:
+ return "dim_offset"
+ if self is StorageSpecifierKind.DimStride:
+ return "dim_stride"
+ raise ValueError("Unknown StorageSpecifierKind enum entry.")
+
+
+@register_attribute_builder("SparseTensorStorageSpecifierKindEnum")
+def _sparsetensorstoragespecifierkindenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fdd..fbe5dcd03403f8 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -1,8 +1,8 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
-from .._transform_enum_gen import *
from .._transform_ops_gen import *
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
@@ -219,3 +219,53 @@ def __init__(
def any_op_t() -> AnyOpTypeT:
return AnyOpTypeT(AnyOpType.get())
+
+
+class FailurePropagationMode(IntEnum):
+ """Silenceable error propagation policy"""
+
+ Propagate = 1
+ Suppress = 2
+
+ def __str__(self):
+ if self is FailurePropagationMode.Propagate:
+ return "propagate"
+ if self is FailurePropagationMode.Suppress:
+ return "suppress"
+ raise ValueError("Unknown FailurePropagationMode enum entry.")
+
+
+@register_attribute_builder("FailurePropagationMode")
+def _failurepropagationmode(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class MatchCmpIPredicate(IntEnum):
+ """allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5"""
+
+ eq = 0
+ ne = 1
+ lt = 2
+ le = 3
+ gt = 4
+ ge = 5
+
+ def __str__(self):
+ if self is MatchCmpIPredicate.eq:
+ return "eq"
+ if self is MatchCmpIPredicate.ne:
+ return "ne"
+ if self is MatchCmpIPredicate.lt:
+ return "lt"
+ if self is MatchCmpIPredicate.le:
+ return "le"
+ if self is MatchCmpIPredicate.gt:
+ return "gt"
+ if self is MatchCmpIPredicate.ge:
+ return "ge"
+ raise ValueError("Unknown MatchCmpIPredicate enum entry.")
+
+
+@register_attribute_builder("MatchCmpIPredicateAttr")
+def _matchcmpipredicateattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 9121aa8e40237b..cd2dc6e2b114ab 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -1,10 +1,10 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
from .._structured_transform_ops_gen import *
from .._structured_transform_ops_gen import _Dialect
-from .._structured_transform_enum_gen import *
try:
from ...ir import *
@@ -648,3 +648,44 @@ def __init__(
loc=loc,
ip=ip,
)
+
+
+class MatchInterfaceEnum(IntEnum):
+ """An interface to match"""
+
+ LinalgOp = 0
+ TilingInterface = 1
+ LoopLikeInterface = 2
+
+ def __str__(self):
+ if self is MatchInterfaceEnum.LinalgOp:
+ return "LinalgOp"
+ if self is MatchInterfaceEnum.TilingInterface:
+ return "TilingInterface"
+ if self is MatchInterfaceEnum.LoopLikeInterface:
+ return "LoopLikeInterface"
+ raise ValueError("Unknown MatchInterfaceEnum enum entry.")
+
+
+@register_attribute_builder("MatchInterfaceEnum")
+def _matchinterfaceenum(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class TransposeMatmulInput(IntEnum):
+ """Input to transpose when converting matmul ops to transposed variants"""
+
+ lhs = 0
+ rhs = 1
+
+ def __str__(self):
+ if self is TransposeMatmulInput.lhs:
+ return "lhs"
+ if self is TransposeMatmulInput.rhs:
+ return "rhs"
+ raise ValueError("Unknown TransposeMatmulInput enum entry.")
+
+
+@register_attribute_builder("TransposeMatmulInput")
+def _transposematmulinput(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/transform/vector.py b/mlir/python/mlir/dialects/transform/vector.py
index af2435cb26cc4b..35c7d847506923 100644
--- a/mlir/python/mlir/dialects/transform/vector.py
+++ b/mlir/python/mlir/dialects/transform/vector.py
@@ -1,6 +1,102 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
+from ...ir import IntegerAttr, IntegerType, register_attribute_builder
from .._vector_transform_enum_gen import *
from .._vector_transform_ops_gen import *
+
+
+class VectorContractLowering(IntEnum):
+ """control the lowering of `vector.contract` operations."""
+
+ Dot = 0
+ Matmul = 1
+ OuterProduct = 2
+ ParallelArith = 3
+
+ def __str__(self):
+ if self is VectorContractLowering.Dot:
+ return "dot"
+ if self is VectorContractLowering.Matmul:
+ return "matmulintrinsics"
+ if self is VectorContractLowering.OuterProduct:
+ return "outerproduct"
+ if self is VectorContractLowering.ParallelArith:
+ return "parallelarith"
+ raise ValueError("Unknown VectorContractLowering enum entry.")
+
+
+@register_attribute_builder("VectorContractLoweringAttr")
+def _vectorcontractloweringattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class VectorMultiReductionLowering(IntEnum):
+ """control the lowering of `vector.multi_reduction`."""
+
+ InnerParallel = 0
+ InnerReduction = 1
+
+ def __str__(self):
+ if self is VectorMultiReductionLowering.InnerParallel:
+ return "innerparallel"
+ if self is VectorMultiReductionLowering.InnerReduction:
+ return "innerreduction"
+ raise ValueError("Unknown VectorMultiReductionLowering enum entry.")
+
+
+@register_attribute_builder("VectorMultiReductionLoweringAttr")
+def _vectormultireductionloweringattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class VectorTransferSplit(IntEnum):
+ """control the splitting of `vector.transfer` operations into in-bounds and out-of-bounds variants."""
+
+ None_ = 0
+ VectorTransfer = 1
+ LinalgCopy = 2
+ ForceInBounds = 3
+
+ def __str__(self):
+ if self is VectorTransferSplit.None_:
+ return "none"
+ if self is VectorTransferSplit.VectorTransfer:
+ return "vector-transfer"
+ if self is VectorTransferSplit.LinalgCopy:
+ return "linalg-copy"
+ if self is VectorTransferSplit.ForceInBounds:
+ return "force-in-bounds"
+ raise ValueError("Unknown VectorTransferSplit enum entry.")
+
+
+@register_attribute_builder("VectorTransferSplitAttr")
+def _vectortransfersplitattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class VectorTransposeLowering(IntEnum):
+ """control the lowering of `vector.transpose` operations."""
+
+ EltWise = 0
+ Flat = 1
+ Shuffle1D = 2
+ Shuffle16x16 = 3
+
+ def __str__(self):
+ if self is VectorTransposeLowering.EltWise:
+ return "eltwise"
+ if self is VectorTransposeLowering.Flat:
+ return "flat_transpose"
+ if self is VectorTransposeLowering.Shuffle1D:
+ return "shuffle_1d"
+ if self is VectorTransposeLowering.Shuffle16x16:
+ return "shuffle_16x16"
+ raise ValueError("Unknown VectorTransposeLowering enum entry.")
+
+
+@register_attribute_builder("VectorTransposeLoweringAttr")
+def _vectortransposeloweringattr(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py
index 7384e9a5aeef29..603cce6b9daf4f 100644
--- a/mlir/python/mlir/dialects/vector.py
+++ b/mlir/python/mlir/dialects/vector.py
@@ -1,6 +1,107 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from enum import IntEnum
+from ..ir import IntegerAttr, IntegerType, register_attribute_builder
from ._vector_ops_gen import *
from ._vector_enum_gen import *
+
+
+class CombiningKind(IntEnum):
+ """Kind of combining function for contractions and reductions"""
+
+ ADD = 0
+ MUL = 1
+ MINUI = 2
+ MINSI = 3
+ MINNUMF = 4
+ MAXUI = 5
+ MAXSI = 6
+ MAXNUMF = 7
+ AND = 8
+ OR = 9
+ XOR = 10
+ MAXIMUMF = 12
+ MINIMUMF = 11
+
+ def __str__(self):
+ if self is CombiningKind.ADD:
+ return "add"
+ if self is CombiningKind.MUL:
+ return "mul"
+ if self is CombiningKind.MINUI:
+ return "minui"
+ if self is CombiningKind.MINSI:
+ return "minsi"
+ if self is CombiningKind.MINNUMF:
+ return "minnumf"
+ if self is CombiningKind.MAXUI:
+ return "maxui"
+ if self is CombiningKind.MAXSI:
+ return "maxsi"
+ if self is CombiningKind.MAXNUMF:
+ return "maxnumf"
+ if self is CombiningKind.AND:
+ return "and"
+ if self is CombiningKind.OR:
+ return "or"
+ if self is CombiningKind.XOR:
+ return "xor"
+ if self is CombiningKind.MAXIMUMF:
+ return "maximumf"
+ if self is CombiningKind.MINIMUMF:
+ return "minimumf"
+ raise ValueError("Unknown CombiningKind enum entry.")
+
+
+@register_attribute_builder("CombiningKind")
+def _combiningkind(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class PrintPunctuation(IntEnum):
+ """Punctuation for separating vectors or vector elements"""
+
+ NoPunctuation = 0
+ NewLine = 1
+ Comma = 2
+ Open = 3
+ Close = 4
+
+ def __str__(self):
+ if self is PrintPunctuation.NoPunctuation:
+ return "no_punctuation"
+ if self is PrintPunctuation.NewLine:
+ return "newline"
+ if self is PrintPunctuation.Comma:
+ return "comma"
+ if self is PrintPunctuation.Open:
+ return "open"
+ if self is PrintPunctuation.Close:
+ return "close"
+ raise ValueError("Unknown PrintPunctuation enum entry.")
+
+
+@register_attribute_builder("PrintPunctuation")
+def _printpunctuation(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
+
+
+class IteratorType(IntEnum):
+ """Iterator type"""
+
+ parallel = 0
+ reduction = 1
+
+ def __str__(self):
+ if self is IteratorType.parallel:
+ return "parallel"
+ if self is IteratorType.reduction:
+ return "reduction"
+ raise ValueError("Unknown IteratorType enum entry.")
+
+
+@register_attribute_builder("Vector_IteratorType")
+def _vector_iteratortype(x, context):
+ return IntegerAttr.get(IntegerType.get_signless(32, context=context), int(x))
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 3f660ae151c749..3628b924b93167 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -16,6 +16,8 @@
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/GenInfo.h"
+
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
@@ -29,13 +31,14 @@ using llvm::RecordKeeper;
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
-from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir
)Py";
+extern llvm::cl::opt<std::string> clDialectName;
+
/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
static std::string makePythonEnumCaseName(StringRef name) {
if (isPythonReserved(name.str()))
@@ -134,15 +137,20 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
/// `false` on success, `true` on failure.
static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
os << fileHeader;
- for (const Record *it :
- records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
- EnumAttr enumAttr(*it);
- emitEnumClass(enumAttr, os);
- emitAttributeBuilder(enumAttr, os);
+ if (clDialectName.empty()) {
+ for (const Record *it :
+ records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
+ EnumAttr enumAttr(*it);
+ emitEnumClass(enumAttr, os);
+ emitAttributeBuilder(enumAttr, os);
+ }
}
for (const Record *it :
records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it);
+ StringRef dialect = attr.getDialect().getName();
+ if (!clDialectName.empty() && dialect != clDialectName)
+ continue;
if (!attr.getMnemonic()) {
llvm::errs() << "enum case " << attr
<< " needs mnemonic for python enum bindings generation";
@@ -150,7 +158,6 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
}
StringRef mnemonic = attr.getMnemonic().value();
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
- StringRef dialect = attr.getDialect().getName();
if (assemblyFormat == "`<` $value `>`") {
emitDialectEnumAttributeBuilder(
attr.getName(),
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 5019b69d91127e..fb6e87ee325b58 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -283,7 +283,7 @@ def {0}({2}) -> {4}:
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
-static llvm::cl::opt<std::string>
+llvm::cl::opt<std::string>
clDialectName("bind-dialect",
llvm::cl::desc("The dialect to run the generator for"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
More information about the Mlir-commits
mailing list