[llvm] r330296 - [NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 18 14:51:48 PDT 2018


Author: tra
Date: Wed Apr 18 14:51:48 2018
New Revision: 330296

URL: http://llvm.org/viewvc/llvm-project?rev=330296&view=rev
Log:
[NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.

The new instructions were added for sm_70+ GPUs in CUDA-9.1.

Differential Revision: https://reviews.llvm.org/D45068

Modified:
    llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
    llvm/trunk/lib/Target/NVPTX/NVPTX.td
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
    llvm/trunk/test/CodeGen/NVPTX/wmma.py

Modified: llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td (original)
+++ llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td Wed Apr 18 14:51:48 2018
@@ -3920,7 +3920,9 @@ multiclass NVVM_WMMA_LD_G<string Geometr
 }
 
 multiclass NVVM_WMMA_LD {
+  defm _m32n8k16_load: NVVM_WMMA_LD_G<"m32n8k16">;
   defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
+  defm _m8n32k16_load: NVVM_WMMA_LD_G<"m8n32k16">;
 }
 
 defm int_nvvm_wmma: NVVM_WMMA_LD;
@@ -3947,7 +3949,7 @@ class NVVM_WMMA_STD_GLSTS<string Geometr
                    # !if(WithStride, ".stride", "")
                    # "." # Type>;
 
-multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout, 
+multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout,
                              string Type, LLVMType regty> {
   def _stride: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 1>;
   def NAME:    NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 0>;
@@ -3963,7 +3965,9 @@ multiclass NVVM_WMMA_STD_G<string Geomet
 }
 
 multiclass NVVM_WMMA_STD {
+  defm _m32n8k16_store:  NVVM_WMMA_STD_G<"m32n8k16">;
   defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
+  defm _m8n32k16_store:  NVVM_WMMA_STD_G<"m8n32k16">;
 }
 
 defm int_nvvm_wmma: NVVM_WMMA_STD;
@@ -4033,7 +4037,9 @@ multiclass NVVM_WMMA_MMA_G<string Geomet
 }
 
 multiclass NVVM_WMMA_MMA {
+  defm _m32n8k16_mma : NVVM_WMMA_MMA_G<"m32n8k16">;
   defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
+  defm _m8n32k16_mma : NVVM_WMMA_MMA_G<"m8n32k16">;
 }
 
 defm int_nvvm_wmma : NVVM_WMMA_MMA;

Modified: llvm/trunk/lib/Target/NVPTX/NVPTX.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTX.td?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTX.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTX.td Wed Apr 18 14:51:48 2018
@@ -52,6 +52,8 @@ def SM62 : SubtargetFeature<"sm_62", "Sm
                              "Target SM 6.2">;
 def SM70 : SubtargetFeature<"sm_70", "SmVersion", "70",
                              "Target SM 7.0">;
+def SM72 : SubtargetFeature<"sm_72", "SmVersion", "72",
+                             "Target SM 7.2">;
 
 // PTX Versions
 def PTX32 : SubtargetFeature<"ptx32", "PTXVersion", "32",
@@ -68,6 +70,8 @@ def PTX50 : SubtargetFeature<"ptx50", "P
                              "Use PTX version 5.0">;
 def PTX60 : SubtargetFeature<"ptx60", "PTXVersion", "60",
                              "Use PTX version 6.0">;
+def PTX61 : SubtargetFeature<"ptx61", "PTXVersion", "61",
+                             "Use PTX version 6.1">;
 
 //===----------------------------------------------------------------------===//
 // NVPTX supported processors.
@@ -89,6 +93,7 @@ def : Proc<"sm_60", [SM60, PTX50]>;
 def : Proc<"sm_61", [SM61, PTX50]>;
 def : Proc<"sm_62", [SM62, PTX50]>;
 def : Proc<"sm_70", [SM70, PTX60]>;
+def : Proc<"sm_72", [SM72, PTX61]>;
 
 def NVPTXInstrInfo : InstrInfo {
 }

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Wed Apr 18 14:51:48 2018
@@ -3329,7 +3329,23 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3342,7 +3358,15 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3355,7 +3379,15 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3368,7 +3400,15 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3381,7 +3421,15 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td Wed Apr 18 14:51:48 2018
@@ -142,6 +142,7 @@ def true : Predicate<"true">;
 
 def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
 def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
+def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
 
 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
 def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td Wed Apr 18 14:51:48 2018
@@ -7378,7 +7378,11 @@ class EmptyNVPTXInst : NVPTXInst<(outs),
 class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
                         string Space, string Type, NVPTXRegClass regclass,
                         DAGOperand SrcOp, bit WithStride>
-  : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+  : EmptyNVPTXInst,
+    Requires<[!if(!eq(Geometry, "m16n16k16"),
+                  hasPTX60,
+                  hasPTX61),
+              hasSM70]> {
   // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
   // for this function.
   PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
@@ -7420,10 +7424,10 @@ class WMMA_LOAD_GALSTOS<string Geometry,
   let InOperandList = Ins;
   let AsmString = "wmma.load."
                   # Abc
-                  # ".sync."
-                  # Layout
-                  # ".m16n16k16"
-                  # Space 
+                  # ".sync"
+                  # "." # Layout
+                  # "." # Geometry
+                  # Space
                   # "." # Type # " \t"
                   # !if(!eq(Abc#Type, "cf16"),
                         "{{$r0, $r1, $r2, $r3}}",
@@ -7512,7 +7516,9 @@ multiclass WMMA_LOAD_G<string Geometry>
   defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
 }
 
+defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">;
 defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
+defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">;
 
 //
 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
@@ -7520,7 +7526,11 @@ defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m1
 class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space,
                           string Type, NVPTXRegClass regclass,
                           bit WithStride, DAGOperand DstOp>
-  : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+  : EmptyNVPTXInst,
+    Requires<[!if(!eq(Geometry, "m16n16k16"),
+                  hasPTX60,
+                  hasPTX61),
+              hasSM70]> {
   PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
                                        # "_" # Geometry # "_store_d"
                                        # "_" # Type
@@ -7641,11 +7651,9 @@ multiclass WMMA_STORE_D_G<string Geometr
   defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
 }
 
-// multiclass WMMA_STORE_D {
-//   defm _m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
-// }
-
+defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">;
 defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
+defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">;
 
 // WMMA.MMA
 class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
@@ -7653,7 +7661,11 @@ class WMMA_MMA_GABDCS<string Geometry, s
                      string CType, NVPTXRegClass c_reg,
                      NVPTXRegClass ab_reg,
                      string Satfinite = "">
-  : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+  : EmptyNVPTXInst,
+    Requires<[!if(!eq(Geometry, "m16n16k16"),
+                  hasPTX60,
+                  hasPTX61),
+              hasSM70]> {
   Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
                                     # Geometry
                                     # "_mma"
@@ -7686,7 +7698,7 @@ class WMMA_MMA_GABDCS<string Geometry, s
   let AsmString = "wmma.mma.sync."
                   # ALayout
                   # "." # BLayout
-                  # ".m16n16k16"
+                  # "." # Geometry
                   # "." # DType
                   # "." # CType
                   # Satfinite # "\n\t\t"
@@ -7734,4 +7746,6 @@ multiclass WMMA_MMA_G<string Geometry> {
   defm _row: WMMA_MMA_GA<Geometry, "row">;
 }
 
+defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">;
 defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
+defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">;

Modified: llvm/trunk/test/CodeGen/NVPTX/wmma.py
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/wmma.py?rev=330296&r1=330295&r2=330296&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/wmma.py (original)
+++ llvm/trunk/test/CodeGen/NVPTX/wmma.py Wed Apr 18 14:51:48 2018
@@ -2,7 +2,7 @@
 # generates correct instructions for them.
 
 # RUN: python %s > %t.ll
-# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 | FileCheck %t.ll
+# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
 
 from itertools import product
 from string import Template
@@ -36,13 +36,15 @@ check_f16_8 = "{{%s}}" % ", *".join(["%h
 check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
 check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
 
+known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
+
 def gen_wmma_load_tests():
   load_template = """
 declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
 
 ; CHECK-LABEL: .func {{.*}}test_${function}(
 define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
-; CHECK ${instruction}
+; CHECK: ${instruction}
 ; CHECK: {${check_result}}
 ; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
   %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
@@ -51,7 +53,7 @@ define ${ret_ty} @test_${function}(i8 ${
 
 ; CHECK-LABEL: .func{{.*}}test_${function}_o(
 define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
-; CHECK ${instruction}
+; CHECK: ${instruction}
 ; CHECK: {${check_result}}
 ; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
   %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
@@ -60,9 +62,10 @@ define ${ret_ty} @test_${function}_o(i8
 }
 """
   intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
-  instruction_template = "wmma.load.${abc}.sync.${geom}.${layout}${space}.${itype}"
+  instruction_template = "wmma.load.${abc}.sync.${layout}.${geom}${space}.${itype}"
 
-  for abc, layout, space, stride, itype in product(
+  for geom, abc, layout, space, stride, itype in product(
+      known_geoms,
       "abc",
       ["row","col"],
       ["",".shared",".global"],
@@ -77,7 +80,7 @@ define ${ret_ty} @test_${function}_o(i8
         "itype" : itype,
         "pspace" : get_pspace(space),
         "as"     : "addrspace(%d)" % get_aspace(space),
-        "geom"   : "m16n16k16",
+        "geom"   : geom,
     }
 
     if itype == "f32" and abc != "c":
@@ -112,7 +115,7 @@ declare void @${intrinsic}(i8 ${as}* %sr
 
 ; CHECK-LABEL: .func {{.*}}test_${function}(
 define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
-; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}
 ; CHECK: {${check_args}}
 ; CHECK: ${stride_pattern}
   call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
@@ -121,7 +124,7 @@ define void @test_${function}(i8 ${as}*
 
 ; CHECK-LABEL: .func{{.*}}test_${function}_o(
 define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
-; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
 ; CHECK: ${check_args}
 ; CHECK: ${stride_pattern}
   %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
@@ -130,9 +133,10 @@ define void @test_${function}_o(i8 ${as}
 }
 """
   intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
-  instruction_template = "wmma.store.${abc}.sync.${geom}.${layout}${space}.${itype}"
+  instruction_template = "wmma.store.${abc}.sync.${layout}.${geom}${space}.${itype}"
 
-  for abc, layout, space, stride, itype in product(
+  for geom, abc, layout, space, stride, itype in product(
+      known_geoms,
       "d",
       ["row","col"],
       ["",".shared",".global"],
@@ -147,7 +151,7 @@ define void @test_${function}_o(i8 ${as}
         "itype" : itype,
         "pspace" : get_pspace(space),
         "as"     : "addrspace(%d)" % get_aspace(space),
-        "geom"   : "m16n16k16",
+        "geom"   : geom,
     }
 
     test_params = params
@@ -174,11 +178,11 @@ declare ${ret_ty} @${intrinsic}(
 ; CHECK-LABEL: .func {{.*}}test_${function}(
 define ${ret_ty} @test_${function}(
         ${args}) {
-; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
-; CHECK ${check_d}
-; CHECK ${check_ab}
-; CHECK ${check_ab}
-; CHECK ${check_c}
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_ab}
+; CHECK-NEXT: ${check_ab}
+; CHECK-NEXT: ${check_c}
   %r = call ${ret_ty} @${intrinsic}(
         ${args});
   ret ${ret_ty} %r;
@@ -187,7 +191,8 @@ define ${ret_ty} @test_${function}(
   intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
   instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
 
-  for alayout, blayout, ctype, dtype, satf in product(
+  for geom, alayout, blayout, ctype, dtype, satf in product(
+      known_geoms,
       ["row","col"],
       ["row","col"],
       ["f16", "f32"],
@@ -200,7 +205,7 @@ define ${ret_ty} @test_${function}(
         "ctype" : ctype,
         "dtype" : dtype,
         "satf"  : satf,
-        "geom"  : "m16n16k16",
+        "geom"  : geom,
     }
 
     test_params = params




More information about the llvm-commits mailing list