[llvm] r330296 - [NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 18 14:51:48 PDT 2018
Author: tra
Date: Wed Apr 18 14:51:48 2018
New Revision: 330296
URL: http://llvm.org/viewvc/llvm-project?rev=330296&view=rev
Log:
[NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.
The new instructions were added for sm_70+ GPUs in CUDA 9.1.
Differential Revision: https://reviews.llvm.org/D45068
Modified:
llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
llvm/trunk/lib/Target/NVPTX/NVPTX.td
llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/trunk/test/CodeGen/NVPTX/wmma.py
Modified: llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td (original)
+++ llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td Wed Apr 18 14:51:48 2018
@@ -3920,7 +3920,9 @@ multiclass NVVM_WMMA_LD_G<string Geometr
}
multiclass NVVM_WMMA_LD {
+ defm _m32n8k16_load: NVVM_WMMA_LD_G<"m32n8k16">;
defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
+ defm _m8n32k16_load: NVVM_WMMA_LD_G<"m8n32k16">;
}
defm int_nvvm_wmma: NVVM_WMMA_LD;
@@ -3947,7 +3949,7 @@ class NVVM_WMMA_STD_GLSTS<string Geometr
# !if(WithStride, ".stride", "")
# "." # Type>;
-multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout,
+multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 0>;
@@ -3963,7 +3965,9 @@ multiclass NVVM_WMMA_STD_G<string Geomet
}
multiclass NVVM_WMMA_STD {
+ defm _m32n8k16_store: NVVM_WMMA_STD_G<"m32n8k16">;
defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
+ defm _m8n32k16_store: NVVM_WMMA_STD_G<"m8n32k16">;
}
defm int_nvvm_wmma: NVVM_WMMA_STD;
@@ -4033,7 +4037,9 @@ multiclass NVVM_WMMA_MMA_G<string Geomet
}
multiclass NVVM_WMMA_MMA {
+ defm _m32n8k16_mma : NVVM_WMMA_MMA_G<"m32n8k16">;
defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
+ defm _m8n32k16_mma : NVVM_WMMA_MMA_G<"m8n32k16">;
}
defm int_nvvm_wmma : NVVM_WMMA_MMA;
Modified: llvm/trunk/lib/Target/NVPTX/NVPTX.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTX.td?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTX.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTX.td Wed Apr 18 14:51:48 2018
@@ -52,6 +52,8 @@ def SM62 : SubtargetFeature<"sm_62", "Sm
"Target SM 6.2">;
def SM70 : SubtargetFeature<"sm_70", "SmVersion", "70",
"Target SM 7.0">;
+def SM72 : SubtargetFeature<"sm_72", "SmVersion", "72",
+ "Target SM 7.2">;
// PTX Versions
def PTX32 : SubtargetFeature<"ptx32", "PTXVersion", "32",
@@ -68,6 +70,8 @@ def PTX50 : SubtargetFeature<"ptx50", "P
"Use PTX version 5.0">;
def PTX60 : SubtargetFeature<"ptx60", "PTXVersion", "60",
"Use PTX version 6.0">;
+def PTX61 : SubtargetFeature<"ptx61", "PTXVersion", "61",
+ "Use PTX version 6.1">;
//===----------------------------------------------------------------------===//
// NVPTX supported processors.
@@ -89,6 +93,7 @@ def : Proc<"sm_60", [SM60, PTX50]>;
def : Proc<"sm_61", [SM61, PTX50]>;
def : Proc<"sm_62", [SM62, PTX50]>;
def : Proc<"sm_70", [SM70, PTX60]>;
+def : Proc<"sm_72", [SM72, PTX61]>;
def NVPTXInstrInfo : InstrInfo {
}
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Wed Apr 18 14:51:48 2018
@@ -3329,7 +3329,23 @@ bool NVPTXTargetLowering::getTgtMemIntri
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f16;
Info.ptrVal = I.getArgOperand(0);
@@ -3342,7 +3358,15 @@ bool NVPTXTargetLowering::getTgtMemIntri
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
@@ -3355,7 +3379,15 @@ bool NVPTXTargetLowering::getTgtMemIntri
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
@@ -3368,7 +3400,15 @@ bool NVPTXTargetLowering::getTgtMemIntri
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
@@ -3381,7 +3421,15 @@ bool NVPTXTargetLowering::getTgtMemIntri
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td Wed Apr 18 14:51:48 2018
@@ -142,6 +142,7 @@ def true : Predicate<"true">;
def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
+def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td Wed Apr 18 14:51:48 2018
@@ -7378,7 +7378,11 @@ class EmptyNVPTXInst : NVPTXInst<(outs),
class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass,
DAGOperand SrcOp, bit WithStride>
- : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ : EmptyNVPTXInst,
+ Requires<[!if(!eq(Geometry, "m16n16k16"),
+ hasPTX60,
+ hasPTX61),
+ hasSM70]> {
// Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
// for this function.
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
@@ -7420,10 +7424,10 @@ class WMMA_LOAD_GALSTOS<string Geometry,
let InOperandList = Ins;
let AsmString = "wmma.load."
# Abc
- # ".sync."
- # Layout
- # ".m16n16k16"
- # Space
+ # ".sync"
+ # "." # Layout
+ # "." # Geometry
+ # Space
# "." # Type # " \t"
# !if(!eq(Abc#Type, "cf16"),
"{{$r0, $r1, $r2, $r3}}",
@@ -7512,7 +7516,9 @@ multiclass WMMA_LOAD_G<string Geometry>
defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
}
+defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">;
defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
+defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">;
//
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
@@ -7520,7 +7526,11 @@ defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m1
class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride, DAGOperand DstOp>
- : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ : EmptyNVPTXInst,
+ Requires<[!if(!eq(Geometry, "m16n16k16"),
+ hasPTX60,
+ hasPTX61),
+ hasSM70]> {
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
# "_" # Geometry # "_store_d"
# "_" # Type
@@ -7641,11 +7651,9 @@ multiclass WMMA_STORE_D_G<string Geometr
defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
}
-// multiclass WMMA_STORE_D {
-// defm _m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
-// }
-
+defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">;
defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
+defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">;
// WMMA.MMA
class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
@@ -7653,7 +7661,11 @@ class WMMA_MMA_GABDCS<string Geometry, s
string CType, NVPTXRegClass c_reg,
NVPTXRegClass ab_reg,
string Satfinite = "">
- : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ : EmptyNVPTXInst,
+ Requires<[!if(!eq(Geometry, "m16n16k16"),
+ hasPTX60,
+ hasPTX61),
+ hasSM70]> {
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
# Geometry
# "_mma"
@@ -7686,7 +7698,7 @@ class WMMA_MMA_GABDCS<string Geometry, s
let AsmString = "wmma.mma.sync."
# ALayout
# "." # BLayout
- # ".m16n16k16"
+ # "." # Geometry
# "." # DType
# "." # CType
# Satfinite # "\n\t\t"
@@ -7734,4 +7746,6 @@ multiclass WMMA_MMA_G<string Geometry> {
defm _row: WMMA_MMA_GA<Geometry, "row">;
}
+defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">;
defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
+defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">;
Modified: llvm/trunk/test/CodeGen/NVPTX/wmma.py
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/wmma.py?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/wmma.py (original)
+++ llvm/trunk/test/CodeGen/NVPTX/wmma.py Wed Apr 18 14:51:48 2018
@@ -2,7 +2,7 @@
# generates correct instructions for them.
# RUN: python %s > %t.ll
-# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 | FileCheck %t.ll
+# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
from itertools import product
from string import Template
@@ -36,13 +36,15 @@ check_f16_8 = "{{%s}}" % ", *".join(["%h
check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
+known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
+
def gen_wmma_load_tests():
load_template = """
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
-; CHECK ${instruction}
+; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
%v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
@@ -51,7 +53,7 @@ define ${ret_ty} @test_${function}(i8 ${
; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
-; CHECK ${instruction}
+; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
@@ -60,9 +62,10 @@ define ${ret_ty} @test_${function}_o(i8
}
"""
intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
- instruction_template = "wmma.load.${abc}.sync.${geom}.${layout}${space}.${itype}"
+ instruction_template = "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}"
- for abc, layout, space, stride, itype in product(
+ for geom, abc, layout, space, stride, itype in product(
+ known_geoms,
"abc",
["row","col"],
["",".shared",".global"],
@@ -77,7 +80,7 @@ define ${ret_ty} @test_${function}_o(i8
"itype" : itype,
"pspace" : get_pspace(space),
"as" : "addrspace(%d)" % get_aspace(space),
- "geom" : "m16n16k16",
+ "geom" : geom,
}
if itype == "f32" and abc != "c":
@@ -112,7 +115,7 @@ declare void @${intrinsic}(i8 ${as}* %sr
; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
-; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
@@ -121,7 +124,7 @@ define void @test_${function}(i8 ${as}*
; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
-; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
@@ -130,9 +133,10 @@ define void @test_${function}_o(i8 ${as}
}
"""
intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
- instruction_template = "wmma.store.${abc}.sync.${geom}.${layout}${space}.${itype}"
+ instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"
- for abc, layout, space, stride, itype in product(
+ for geom, abc, layout, space, stride, itype in product(
+ known_geoms,
"d",
["row","col"],
["",".shared",".global"],
@@ -147,7 +151,7 @@ define void @test_${function}_o(i8 ${as}
"itype" : itype,
"pspace" : get_pspace(space),
"as" : "addrspace(%d)" % get_aspace(space),
- "geom" : "m16n16k16",
+ "geom" : geom,
}
test_params = params
@@ -174,11 +178,11 @@ declare ${ret_ty} @${intrinsic}(
; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
${args}) {
-; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
-; CHECK ${check_d}
-; CHECK ${check_ab}
-; CHECK ${check_ab}
-; CHECK ${check_c}
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_ab}
+; CHECK-NEXT: ${check_ab}
+; CHECK-NEXT: ${check_c}
%r = call ${ret_ty} @${intrinsic}(
${args});
ret ${ret_ty} %r;
@@ -187,7 +191,8 @@ define ${ret_ty} @test_${function}(
intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
- for alayout, blayout, ctype, dtype, satf in product(
+ for geom, alayout, blayout, ctype, dtype, satf in product(
+ known_geoms,
["row","col"],
["row","col"],
["f16", "f32"],
@@ -200,7 +205,7 @@ define ${ret_ty} @test_${function}(
"ctype" : ctype,
"dtype" : dtype,
"satf" : satf,
- "geom" : "m16n16k16",
+ "geom" : geom,
}
test_params = params
More information about the llvm-commits mailing list