[Mlir-commits] [mlir] [mlir][nvgpu] Simplify TMA IR generation (PR #87153)

Maksim Levental llvmlistbot at llvm.org
Fri Apr 12 12:10:23 PDT 2024


================
@@ -28,6 +28,43 @@
 DEBUG = False
 
 
+class TmaDescriptorBuilder:
+    """A class that builds a TMA descriptor."""
+
+    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.tma_box_shape = tma_box_shape
+        self.memref_ty = memref_ty  # MemRefType
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        memref_str = f"memref<{self.tma_box_shape[0]}x{self.tma_box_shape[1]}x{self.memref_ty.element_type}, 3>"
+        parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
+                                              swizzle = {self.swizzle},\
+                                              l2promo = {self.l2promo},\
+                                              oob = {self.oob},\
+                                              interleave = {self.interleave}>"
+        return ir.Type.parse(parse_str)
----------------
makslevental wrote:

Extra :+1: from me on using [mlir_type_subclass](https://github.com/llvm/llvm-project/blob/91f11611337dde9a8e0a5e19240f6bb4671922c6/mlir/include/mlir/Bindings/Python/PybindAdaptors.h#L440) for this type instead of building a string and calling `ir.Type.parse` (example use [here](https://github.com/llvm/llvm-project/blob/681eacc1b670fd7137d8677fef6fc76c6e37dca9/mlir/lib/Bindings/Python/DialectTransform.cpp#L45)).
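For reference, here is the kind of type string that the `tensormap_descriptor_ty` property in the diff passes to `ir.Type.parse`. This is a minimal pure-Python sketch that runs without MLIR. `tensormap_descriptor_str` is a hypothetical helper, not part of the PR or the MLIR bindings. It also assumes the enum values are passed as their textual MLIR spellings (e.g. `swizzle_128b`) rather than as `mlir.nvgpu` enum objects. Unlike the backslash continuation inside one f-string in the diff, it joins the fields explicitly, so source indentation does not end up in the parsed string.

```python
from typing import Sequence


def tensormap_descriptor_str(
    box_shape: Sequence[int],
    element_type: str,
    swizzle: str,
    l2promo: str,
    oob: str,
    interleave: str,
    memory_space: int = 3,
) -> str:
    """Build the textual form of an !nvgpu.tensormap.descriptor type.

    Hypothetical helper: enum values are assumed to be given as their
    textual MLIR spellings (e.g. "swizzle_128b"), not as enum objects.
    """
    # The TMA box is described as a memref in shared memory (address space 3).
    # Any rank is accepted here, whereas the diff hard-codes two dimensions.
    shape = "x".join(str(d) for d in box_shape)
    memref_str = f"memref<{shape}x{element_type}, {memory_space}>"
    # Join the fields explicitly so no source indentation leaks into the string.
    fields = [
        f"tensor = {memref_str}",
        f"swizzle = {swizzle}",
        f"l2promo = {l2promo}",
        f"oob = {oob}",
        f"interleave = {interleave}",
    ]
    return "!nvgpu.tensormap.descriptor<" + ", ".join(fields) + ">"


if __name__ == "__main__":
    print(tensormap_descriptor_str([128, 64], "f16", "swizzle_128b", "none", "zero", "none"))
```

With the MLIR Python bindings available, the resulting string would still be handed to `ir.Type.parse` inside a context with the nvgpu dialect loaded, as in the diff. Binding the type through `mlir_type_subclass`, as suggested above, would avoid the string round-trip entirely.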

https://github.com/llvm/llvm-project/pull/87153

