[llvm] [NVPTX] Lower stmatrix intrinsics to PTX (PR #148561)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 14 18:56:17 PDT 2025
https://github.com/Pecco-314 updated https://github.com/llvm/llvm-project/pull/148561
>From 8944f9578ade64d14730a4afdd50e81cccbf28b6 Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Mon, 14 Jul 2025 11:06:20 +0800
Subject: [PATCH 1/2] Lower stmatrix intrinsics to PTX
Lower the stmatrix intrinsics defined in #148377 to PTX. See the [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix).
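
For context, a minimal sketch of the kind of IR the new patterns handle and the PTX they should produce (not taken from this patch; the `.p0` overload suffix and the i32 fragment operands are assumed by analogy with the existing ldmatrix tests, and the authoritative intrinsic definitions live in #148377):

```llvm
; Sketch only: signature assumed, see #148377 for the real declarations.
declare void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p0(ptr, i32, i32)

define void @store_x2(ptr %dst, i32 %r0, i32 %r1) {
  ; Expected lowering (generic address space, no .trans):
  ;   stmatrix.sync.aligned.m8n8.x2.b16 [%rdN], {%rA, %rB};
  call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p0(ptr %dst, i32 %r0, i32 %r1)
  ret void
}
```

The x1/x4 variants and the m16n8 `.trans.b8` form follow the same shape with one or four i32 operands; for shared memory the pointer overload would be `.p3` and the PTX instruction gains `.shared`.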
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 29 ++++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 ++++++-
llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py | 14 +++
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 4 +-
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 4 +-
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 4 +-
llvm/test/CodeGen/NVPTX/wmma.py | 125 +++++++++++++++++++
7 files changed, 213 insertions(+), 12 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 14f05250ad6b8..79424386bc8a4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
- case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+ case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = Align(4);
+ return true;
+ }
+
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::v4i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = Align(16);
+ return true;
+ }
+
case Intrinsic::nvvm_atomic_add_gen_f_cta:
case Intrinsic::nvvm_atomic_add_gen_f_sys:
case Intrinsic::nvvm_atomic_add_gen_i_cta:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 93827be5c2811..eca6cbabd65b9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!and(!eq(op, "ldmatrix"),
!eq(ptx_elt_type, "b8x16.b4x16_p64"),
- !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
+ !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+ !and(!eq(op, "stmatrix"), !eq(ptx_elt_type, "b16"),
+ !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
+
+ !and(!eq(op, "stmatrix"),
+ !eq(ptx_elt_type, "b8"),
+ !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4878,6 +4885,40 @@ defset list<WMMA_INSTR> LDMATRIXs = {
} // transposed
} // defset
+//
+// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
+//
+class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
+ : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+ Requires<Frag.Predicates> {
+ // Build PatFrag that only matches particular address space.
+ dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
+ PatFrag IntrFrag = PatFrag<PFOperands, !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
+ !cond(!eq(Space, ".shared"): AS_match.shared,
+ true: AS_match.generic)>;
+ // Build AS-constrained pattern.
+ let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+ let OutOperandList = (outs);
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ let AsmString = "stmatrix.sync.aligned."
+ # Frag.geom
+ # "." # Frag.frag
+ # !if(Transposed, ".trans", "")
+ # Space
+ # "." # Frag.ptx_elt_type
+ # " [$dst], " # Frag.regstring # ";";
+}
+
+// Create all stmatrix variants
+defset list<WMMA_INSTR> STMATRIXs = {
+ foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
+ foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
+ if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
+ def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
+ } // space
+ } // transposed
+} // defset
+
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
// the instruction record.
@@ -4888,7 +4929,7 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
new file mode 100644
index 0000000000000..8f502065345c1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
@@ -0,0 +1,14 @@
+# Check all variants of instructions supported by PTX78 on SM90
+# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
+# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
+# RUN: --check-prefixes=PTX78STMATRIX-DAG
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN: | FileCheck %t-ptx78-sm_90.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN: | %ptxas-verify -arch=sm_90 \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
index 6ad0a2a5865c4..5c14a54601ed9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM100a
# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_100a.ll
# RUN: %if ptxas-12.7 %{ \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
index 7d9953484da7d..a77f9adddff9c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM101a
# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_101a.ll
# RUN: %if ptxas-12.7 %{ \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
index 7bddf0b6fbb78..8126e64d6cc85 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM120a
# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_120a.ll
# RUN: %if ptxas-12.7 %{ \
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2ee489670e9e4..3888e9b6b1b8d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -10,6 +10,7 @@
from itertools import product
from string import Template
+
class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
"m8n16:x1:b8x16.b4x16_p64": 1,
"m8n16:x2:b8x16.b4x16_p64": 2,
"m8n16:x4:b8x16.b4x16_p64": 4,
+ # stmatrix
+ "m8n8:x1:b16": 1,
+ "m8n8:x2:b16": 2,
+ "m8n8:x4:b16": 4,
+ "m16n8:x1:b8": 1,
+ "m16n8:x2:b8": 2,
+ "m16n8:x4:b8": 4,
}.get(
"%s:%s:%s" % (geom, frag, ptx_elt_type),
{
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
]
+def make_stmatrix_ops(geoms, frags, types):
+ return [
+ MMAFrag(geom, frag, ptx_type)
+ for (geom, frag, ptx_type) in product(geoms, frags, types)
+ ]
+
+
def get_wmma_ops():
return (
make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
)
+def get_stmatrix_ops():
+ return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops(
+ ["m16n8"], ["x1", "x2", "x4"], ["b8"]
+ )
+
+
def is_wmma_geom_supported(geom):
# geometries for FP and ints.
if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
assert False # Unexpected geometry.
+def is_stmatrix_geom_supported(geom):
+ if geom in ["m8n8"]:
+ return ptx_version >= 78 and gpu_arch >= 90
+ elif geom in ["m16n8"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
+ assert False # Unexpected geometry.
+
+
def is_ldmatrix_trans_supported(geom, trans):
if geom in ["m8n8"]:
return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
return trans == ""
assert False # Unexpected geometry.
+
+def is_stmatrix_trans_supported(geom, trans):
+ if geom in ["m8n8"]:
+ return True
+ elif geom in ["m16n8"]:
+ return trans == ".trans"
+ assert False # Unexpected geometry.
+
+
def is_type_supported(ptx_type):
if ptx_type in ["s8", "u8", "s32"]:
return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
return frag.frag in ["x1", "x2", "x4"]
+def is_stmatrix_variant_supported(frag, trans):
+ if not (
+ is_type_supported(frag.mma_type.ptx_type)
+ and is_stmatrix_geom_supported(frag.geom)
+ and is_stmatrix_trans_supported(frag.geom, trans)
+ ):
+ return False
+ return frag.frag in ["x1", "x2", "x4"]
+
+
def make_wmma_slice_ty(frag):
return [frag.mma_type.llvm_type] * frag.nregs
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
return generated_items
+def gen_stmatrix_tests():
+ stmatrix_template = """
+declare void @${intrinsic}(i8 ${as}* %dst, ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}]
+; CHECK: {${check_args}}
+ call void @${intrinsic}(i8${as}* %dst, ${args});
+ ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128],
+; CHECK: {${check_args}}
+ %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
+ call void @${intrinsic}(i8 ${as}* %dst1, ${args});
+ ret void
+}
+"""
+ intrinsic_template = (
+ "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+ )
+ instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+ )
+ generated_items = []
+
+ for frag, space, trans in product(get_stmatrix_ops(),
+ ["", ".shared"],
+ ["", ".trans"],
+ ):
+ if not is_stmatrix_variant_supported(frag, trans):
+ continue
+
+ params = {
+ "frag": frag.frag,
+ "space": space,"trans": trans,
+ "itype": frag.mma_type.ptx_type,
+ "pspace": get_pspace(space),
+ "as": "addrspace(%d)" % get_aspace(space),
+ "geom": frag.geom,
+ }
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["args"] = make_wmma_slice_args(frag)
+ test_params["check_args"] = check_pattern(frag)
+
+ print(Template(stmatrix_template).substitute(test_params))
+ generated_items.append((test_params["intrinsic"], test_params["instruction"]))
+
+ return generated_items
def mma_signature(op):
if op.a.mma_type.ptx_type == "f16":
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; NOLDMATRIX-NOT: ldmatrix.sync.aligned
+; NOSTMATRIX-NOT: stmatrix.sync.aligned
; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
+
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
items = gen_wmma_load_tests()
items += gen_wmma_store_tests()
items += gen_ldmatrix_tests()
+ items += gen_stmatrix_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
gen_check_unsupported_ops(items)
>From 10708c25a185b4bb53a8c53d0b4d8c21258cd275 Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Tue, 15 Jul 2025 09:53:18 +0800
Subject: [PATCH 2/2] Format Python files
---
llvm/test/CodeGen/NVPTX/wmma.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 3888e9b6b1b8d..2eb3c3dbb4c39 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -764,6 +764,7 @@ def gen_ldmatrix_tests():
return generated_items
+
def gen_stmatrix_tests():
stmatrix_template = """
declare void @${intrinsic}(i8 ${as}* %dst, ${args});
@@ -788,11 +789,13 @@ def gen_stmatrix_tests():
intrinsic_template = (
"llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
)
- instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+ instruction_template = (
+ "stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
)
generated_items = []
- for frag, space, trans in product(get_stmatrix_ops(),
+ for frag, space, trans in product(
+ get_stmatrix_ops(),
["", ".shared"],
["", ".trans"],
):
@@ -801,7 +804,8 @@ def gen_stmatrix_tests():
params = {
"frag": frag.frag,
- "space": space,"trans": trans,
+ "space": space,
+ "trans": trans,
"itype": frag.mma_type.ptx_type,
"pspace": get_pspace(space),
"as": "addrspace(%d)" % get_aspace(space),