[llvm] 533cc9a - [NVPTX] Limit a sparsity selector in sparse MMA intrinsics. (#154984)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 22 12:27:34 PDT 2025
Author: Kirill Vedernikov
Date: 2025-08-23T00:57:31+05:30
New Revision: 533cc9a6ac5011f35002a43121fdbeb45dd9ab66
URL: https://github.com/llvm/llvm-project/commit/533cc9a6ac5011f35002a43121fdbeb45dd9ab66
DIFF: https://github.com/llvm/llvm-project/commit/533cc9a6ac5011f35002a43121fdbeb45dd9ab66.diff
LOG: [NVPTX] Limit a sparsity selector in sparse MMA intrinsics. (#154984)
This PR fixes NVPTX tests in LLVM testing by adding more limitations for the sparsity selector in sparse MMA intrinsics. The previous PR, which was merged to llvm:main, is [PR150950](https://github.com/llvm/llvm-project/pull/150950); it was merged
to llvm:main as d9c6b7b.
Added:
Modified:
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/test/CodeGen/NVPTX/wmma.py
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index cd7a0bc9c4b48..130fa27e4f870 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -2161,6 +2161,7 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
// The range [0;num_threads) is for the sparsity selector that indicates the threads
// which contribute metadata.
int num_threads = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")),
+ !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "f16")),
!and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")),
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")),
!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")),
@@ -2175,7 +2176,11 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
!eq(A.ptx_elt_type, "e3m2"),
!eq(A.ptx_elt_type, "e2m3"),
!eq(A.ptx_elt_type, "e2m1"))),
- 1, 4));
+ 1,
+ !if(!and(!eq(A.geom, "m16n8k128"),
+ !or(!eq(A.ptx_elt_type, "s4"),
+ !eq(A.ptx_elt_type, "u4"))),
+ 1, 4)));
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
Range<ArgIndex<pos>, 0, num_threads>];
}
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index f4f166c4018d0..6d73bce46da7c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1135,6 +1135,7 @@ def sp_selector_gen(op):
# (geom, type) -> allowed selector range
range_01 = {
("m16n8k32", "bf16"),
+ ("m16n8k32", "f16"),
("m16n8k16", "tf32"),
("m16n8k32", "u8"),
("m16n8k32", "s8"),
@@ -1154,6 +1155,11 @@ def sp_selector_gen(op):
"e2m1",
]:
return range(1)
+ if op.a.geom == "m16n8k128" and op.a.mma_type.ptx_type in [
+ "u4",
+ "s4",
+ ]:
+ return range(1)
return range(4)
More information about the llvm-commits
mailing list