[llvm] [NVPTX] Fix LLVM tests by further restricting the sparsity selector of sparse MMA intrinsics (PR #154984)
Kirill Vedernikov via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 22 09:51:16 PDT 2025
https://github.com/kvederni created https://github.com/llvm/llvm-project/pull/154984
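This PR fixes the NVPTX tests by further restricting the allowed range of the sparsity selector in sparse MMA intrinsics. For m16n8k32 with f16 operands, the selector is now limited to 0-1, the same range as bf16. For m16n8k128 with s4/u4 operands, it is limited to 0. Both restrictions are applied in IntrinsicsNVVM.td and in the wmma.py test generator.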
>From 7952c8d7316792b52370c0a29b950151614dc5e3 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Fri, 22 Aug 2025 18:46:03 +0200
Subject: [PATCH] [NVPTX] A fix for LLVM testing. More limitations were added
for a sparsity selector in sparse MMA intrinsics.
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 7 ++++++-
llvm/test/CodeGen/NVPTX/wmma.py | 6 ++++++
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index cd7a0bc9c4b48..130fa27e4f870 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -2161,6 +2161,7 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
// The range [0;num_threads) is for the sparsity selector that indicates the threads
// which contribute metadata.
int num_threads = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")),
+ !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "f16")),
!and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")),
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")),
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")),
@@ -2175,7 +2176,11 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
!eq(A.ptx_elt_type, "e3m2"),
!eq(A.ptx_elt_type, "e2m3"),
!eq(A.ptx_elt_type, "e2m1"))),
- 1, 4));
+ 1,
+ !if(!and(!eq(A.geom, "m16n8k128"),
+ !or(!eq(A.ptx_elt_type, "s4"),
+ !eq(A.ptx_elt_type, "u4"))),
+ 1, 4)));
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
Range<ArgIndex<pos>, 0, num_threads>];
}
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index f4f166c4018d0..6d73bce46da7c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1135,6 +1135,7 @@ def sp_selector_gen(op):
# (geom, type) -> allowed selector range
range_01 = {
("m16n8k32", "bf16"),
+ ("m16n8k32", "f16"),
("m16n8k16", "tf32"),
("m16n8k32", "u8"),
("m16n8k32", "s8"),
@@ -1154,6 +1155,11 @@ def sp_selector_gen(op):
"e2m1",
]:
return range(1)
+ if op.a.geom == "m16n8k128" and op.a.mma_type.ptx_type in [
+ "u4",
+ "s4",
+ ]:
+ return range(1)
return range(4)
More information about the llvm-commits mailing list