[Mlir-commits] [llvm] [mlir] [NVPTX] Added more MMA intrinsics for F8F6F4 and FP64 types. (PR #156040)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 29 07:55:22 PDT 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {darker}-->


:warning: The Python code formatter, darker, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
darker --check --diff -r origin/main...HEAD llvm/test/CodeGen/NVPTX/wmma.py
``````````

:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:

</details>

<details>
<summary>
View the diff from darker here.
</summary>

``````````diff
--- wmma.py	2025-08-29 14:44:59.000000 +0000
+++ wmma.py	2025-08-29 14:54:56.005225 +0000
@@ -573,30 +573,34 @@
     ):
         return False
 
     if (
         op.a.geom in ["m16n8k16", "m16n8k32"]
-        and any(x in ["e4m3", "e5m2"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+        and any(
+            x in ["e4m3", "e5m2"]
+            for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+        )
         and ptx_version < 87
     ):
         return False
 
     if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
         return False
 
+    if kind != "" and (
+        op.a.geom != "m16n8k32"
+        or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
+    ):
+        return False
+
     if (
-        kind != ""
-        and (
-            op.a.geom != "m16n8k32"
-            or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
-        )
-    ):
-        return False
-
-    if (kind == ""
+        kind == ""
         and op.a.geom in ["m16n8k16", "m16n8k32"]
-        and any(x in ["e3m2", "e2m3", "e2m1"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+        and any(
+            x in ["e3m2", "e2m3", "e2m1"]
+            for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+        )
     ):
         return False
 
     # Require row/col layout for all MMA except m8n8k4 on FP16
     if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
@@ -1084,16 +1088,12 @@
 
     return generated_items
 
 
 def gen_mma_tests():
-    mma_intrinsic_template = (
-        "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
-    )
-    mma_instruction_template = (
-        "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
-    )
+    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
+    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
 
     generated_items = []
 
     for op, alayout, blayout, kind, satf in product(
         get_mma_ops(),

``````````

</details>


https://github.com/llvm/llvm-project/pull/156040


More information about the Mlir-commits mailing list