[llvm] r328006 - [NVPTX] Make tensor load/store intrinsics overloaded.

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 20 10:18:59 PDT 2018


Author: tra
Date: Tue Mar 20 10:18:59 2018
New Revision: 328006

URL: http://llvm.org/viewvc/llvm-project?rev=328006&view=rev
Log:
[NVPTX] Make tensor load/store intrinsics overloaded.

This way we can support address-space-specific variants without explicitly
encoding the space in the name of the intrinsic. Fewer intrinsics to deal with ->
less boilerplate.

Added a bit of TableGen magic to match/replace an intrinsic with a pointer
argument in a particular address space with the space-specific instruction
variant.

Updated tests to use non-default address spaces.

Differential Revision: https://reviews.llvm.org/D43268

Modified:
    llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
    llvm/trunk/test/CodeGen/NVPTX/wmma.py

Modified: llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td?rev=328006&r1=328005&r2=328006&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td (original)
+++ llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td Tue Mar 20 10:18:59 2018
@@ -3884,30 +3884,22 @@ def int_nvvm_match_all_sync_i64p :
 //
 
 // WMMA.LOAD
-class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
-                         string Type, LLVMType regty, int WithStride>
+class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
+                         LLVMType regty, int WithStride>
   : Intrinsic<!if(!eq(Abc#Type,"cf16"),
                   [regty, regty, regty, regty],
                   [regty, regty, regty, regty,
                    regty, regty, regty, regty]),
-              !if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
-              [], // Properties must be set during instantiation.
+              !if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
+              [IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
               "llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
-                #Space
                 #!if(WithStride,".stride","")
                 #"."#Type>;
 
-multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
-                           string Type, LLVMType regty> {
-  def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
-  def NAME   : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
-}
-
-multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
-                        string Type, LLVMType regty> {
-  defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
-  defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
-  defm NAME:    NVVM_WMMA_LD_ALST<Abc, Layout,        "", Type, regty>;
+multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
+                            LLVMType regty> {
+  def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
+  def NAME   : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
 }
 
 multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
@@ -3915,47 +3907,33 @@ multiclass NVVM_WMMA_LD_AT<string Abc, s
   defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
 }
 
-// For some reason ReadOnly<N> and NoCapture<N> confuses tblgen if they are
-// passed to Intrinsic<> form inside of a multiclass. Setting them globally
-// outside of the multiclass works.
-let IntrProperties = [IntrReadMem, IntrArgMemOnly,
-                      ReadOnly<0>, NoCapture<0>] in {
-  defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
-  defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
-  defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
-  defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
-}
+defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
+defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
+defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
+defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
 
 // WMMA.STORE.D
-class NVVM_WMMA_STD_LSTS<string Layout, string Space,
-                         string Type, LLVMType regty, int WithStride,
+class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
                          // This is only used to create a typed empty array we
                          // need to pass to !if below.
                          list<LLVMType>Empty=[]>
   : Intrinsic<[],
               !listconcat(
-                [llvm_ptr_ty],
+                [llvm_anyptr_ty],
                 !if(!eq(Type,"f16"),
                     [regty, regty, regty, regty],
                     [regty, regty, regty, regty,
                      regty, regty, regty, regty]),
                 !if(WithStride, [llvm_i32_ty], Empty)),
-              [], // Properties must be set during instantiation.
+              [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
               "llvm.nvvm.wmma.store.d.sync."#Layout
-                   #".m16n16k16"#Space
+                   #".m16n16k16"
                    #!if(WithStride,".stride","")
                    #"."#Type>;
 
-multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
-                            string Type, LLVMType regty> {
-  def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
-  def NAME:    NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
-}
-
 multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
-  defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
-  defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
-  defm    NAME: NVVM_WMMA_STD_LST<Layout,        "", Type, regty>;
+  def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
+  def NAME:    NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
 }
 
 multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
@@ -3963,11 +3941,8 @@ multiclass NVVM_WMMA_STD_T<string Type,
   defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
 }
 
-let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
-                      WriteOnly<0>, NoCapture<0>] in {
-  defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
-  defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
-}
+defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
+defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
 
 // WMMA.MMA
 class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=328006&r1=328005&r2=328006&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Tue Mar 20 10:18:59 2018
@@ -3327,26 +3327,10 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_load_a_f16_row:
   case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
   case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_global:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_global:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
   case Intrinsic::nvvm_wmma_load_b_f16_col:
   case Intrinsic::nvvm_wmma_load_b_f16_row:
   case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_global:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_global:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
+  case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3359,15 +3343,7 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_load_c_f16_col:
   case Intrinsic::nvvm_wmma_load_c_f16_row:
   case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_global:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_global:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
+  case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3380,15 +3356,7 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_load_c_f32_col:
   case Intrinsic::nvvm_wmma_load_c_f32_row:
   case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_global:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_global:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
+  case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3401,15 +3369,7 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_store_d_f16_col:
   case Intrinsic::nvvm_wmma_store_d_f16_row:
   case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_global:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_global:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
+  case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3422,15 +3382,7 @@ bool NVPTXTargetLowering::getTgtMemIntri
   case Intrinsic::nvvm_wmma_store_d_f32_col:
   case Intrinsic::nvvm_wmma_store_d_f32_row:
   case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_global:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_global:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
+  case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td?rev=328006&r1=328005&r2=328006&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td Tue Mar 20 10:18:59 2018
@@ -7379,13 +7379,16 @@ class WMMA_LOAD_ALSTOS<string Abc, strin
                            string Type, NVPTXRegClass regclass,
                            DAGOperand SrcOp, bit WithStride>
   : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
-  // Intrinsic that matches this instruction.
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
-                                    # Abc
-                                    # "_" # Type
-                                    # "_" # Layout
-                                    # !subst(".","_",Space)
-                                    # !if(WithStride,"_stride", ""));
+  // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
+  // for this function.
+  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_LOAD_"
+                                       # !subst("a", "A",
+                                         !subst("b", "B",
+                                         !subst("c", "C_" # Type, Abc)))
+                                       # "_" # Layout
+                                       # !subst(".", "_", Space)
+                                       # !if(WithStride,"_stride", "")
+                                       # "_Intr");
   dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
   dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
   dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
@@ -7410,7 +7413,7 @@ class WMMA_LOAD_ALSTOS<string Abc, strin
                               !subst(imem, ADDRvar,
                               !subst(MEMri64, ADDRri64,
                               !subst(MEMri, ADDRri,
-                              !subst(ins, Intr, tmp)))));
+                              !subst(ins, IntrMatcher, tmp)))));
   // Finally, consatenate both parts together. !con() requires both dags to have
   // the same operator, so we wrap PatArgs in a (set ...) dag.
   let Pattern = [!con(PatOuts, (set PatArgs))];
@@ -7425,20 +7428,52 @@ class WMMA_LOAD_ALSTOS<string Abc, strin
                  #";";
 }
 
-multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
-                           string Type, NVPTXRegClass regclass,
-                           DAGOperand SrcOp> {
-  def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 1>;
-  def NAME:    WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 0>;
+class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
+                           string Type, bit WithStride>
+                           : PatFrag <(ops),(ops)> {
+  // Intrinsic that matches this instruction.
+  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
+                                    # Abc
+                                    # "_" # Type
+                                    # "_" # Layout
+                                    # !if(WithStride,"_stride", ""));
+  code match_generic = [{
+   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
+  }];
+  code match_shared = [{
+   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
+  }];
+  code match_global = [{
+   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
+  }];
+
+  let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
+  let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
+  let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
+                      !if(!eq(Space, ".global"), match_global, match_generic));
+}
+
+multiclass WMMA_LOAD_ALSTS<string Abc, string Layout, string Space,
+                          string Type, NVPTXRegClass regclass, bit WithStride> {
+  def _avar:  WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, imem, WithStride>;
+  def _areg: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int32Regs, WithStride>;
+  def _areg64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int64Regs, WithStride>;
+  def _ari: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri, WithStride>;
+  def _ari64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri64, WithStride>;
+}
+
+multiclass WMMA_LOAD_ALSTSh<string Abc, string Layout, string Space,
+                            string Type, NVPTXRegClass regclass, bit WithStride> {
+  // Define a PatFrag that matches appropriate intrinsic that loads from the
+  // given address space.
+  def _Intr : WMMA_LOAD_INTR_HELPER<Abc, Layout, Space, Type, WithStride>;
+  defm NAME:  WMMA_LOAD_ALSTS<Abc, Layout, Space, Type, regclass, WithStride>;
 }
 
 multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
-                          string Type, NVPTXRegClass regclass> {
-  defm _avar:  WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imem>;
-  defm _areg: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int32Regs>;
-  defm _areg64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int64Regs>;
-  defm _ari: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri>;
-  defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri64>;
+                           string Type, NVPTXRegClass regclass> {
+  defm _stride: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 1>;
+  defm NAME:    WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 0>;
 }
 
 multiclass WMMA_LOAD_ALT<string Abc, string Layout,
@@ -7461,15 +7496,16 @@ defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"
 //
 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
 //
-class WMMA_STORE_D_LSTOS<string Layout, string Space,
+class WMMA_STORE_D_LSTSO<string Layout, string Space,
                          string Type, NVPTXRegClass regclass,
-                         DAGOperand DstOp, bit WithStride>
+                         bit WithStride, DAGOperand DstOp>
   : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d_"
-                                    # Type
-                                    # "_" # Layout
-                                    # !subst(".","_",Space)
-                                    # !if(WithStride,"_stride", ""));
+  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_STORE_D"
+                                       # "_" # Type
+                                       # "_" # Layout
+                                       # !subst(".", "_", Space)
+                                       # !if(WithStride,"_stride", "")
+                                       # "_Intr");
 
   dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
   dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
@@ -7483,7 +7519,7 @@ class WMMA_STORE_D_LSTOS<string Layout,
                               !subst(imem, ADDRvar,
                               !subst(MEMri64, ADDRri64,
                               !subst(MEMri, ADDRri,
-                              !subst(ins, Intr, tmp)))));
+                              !subst(ins, IntrMatcher, tmp)))));
   let Pattern = [PatArgs];
   let OutOperandList = (outs);
   let InOperandList = Ins;
@@ -7501,20 +7537,56 @@ class WMMA_STORE_D_LSTOS<string Layout,
 
 }
 
-multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
-                             string Type, NVPTXRegClass regclass,
-                             DAGOperand DstOp> {
-  def _stride:  WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 1>;
-  def NAME:     WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 0>;
+class WMMA_STORE_INTR_HELPER<string Layout, string Space,
+                             string Type, bit WithStride>
+                            : PatFrag <(ops),(ops)> {
+  // Intrinsic that matches this instruction.
+  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d"
+                                    # "_" # Type
+                                    # "_" # Layout
+                                    # !if(WithStride, "_stride", ""));
+  code match_generic = [{
+   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
+  }];
+  code match_shared = [{
+   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
+  }];
+  code match_global = [{
+   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
+  }];
+
+  dag Args = !if(!eq(Type,"f16"),
+                 (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
+                 (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
+                                 node:$r4, node:$r5, node:$r6, node:$r7));
+  dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
+  let Operands = !con(Args, StrideArg);
+  let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
+  let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
+                      !if(!eq(Space, ".global"), match_global, match_generic));
+}
+
+multiclass WMMA_STORE_D_LSTS<string Layout, string Space,
+                            string Type, NVPTXRegClass regclass, bit WithStride> {
+  def _avar:   WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, imem>;
+  def _areg:   WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int32Regs>;
+  def _areg64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int64Regs>;
+  def _ari:    WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri>;
+  def _ari64:  WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri64>;
+}
+
+multiclass WMMA_STORE_D_LSTSh<string Layout, string Space,
+                              string Type, NVPTXRegClass regclass, bit WithStride> {
+  // Define a PatFrag that matches appropriate intrinsic that loads from the
+  // given address space.
+  def _Intr:    WMMA_STORE_INTR_HELPER<Layout, Space, Type, WithStride>;
+  defm NAME:    WMMA_STORE_D_LSTS<Layout, Space, Type, regclass, WithStride>;
 }
 
 multiclass WMMA_STORE_D_LST<string Layout, string Space,
-                            string Type, NVPTXRegClass regclass> {
-  defm _avar:   WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imem>;
-  defm _areg:   WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int32Regs>;
-  defm _areg64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int64Regs>;
-  defm _ari:    WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri>;
-  defm _ari64:  WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri64>;
+                             string Type, NVPTXRegClass regclass > {
+  defm _stride: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 1>;
+  defm NAME:    WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 0>;
 }
 
 multiclass WMMA_STORE_D_LT<string Layout,

Modified: llvm/trunk/test/CodeGen/NVPTX/wmma.py
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/wmma.py?rev=328006&r1=328005&r2=328006&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/wmma.py (original)
+++ llvm/trunk/test/CodeGen/NVPTX/wmma.py Tue Mar 20 10:18:59 2018
@@ -15,6 +15,22 @@ def make_wmma_slice_ty(abcd, itype):
 def make_wmma_ld_ret_ty(abc, itype):
   return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
 
+# returns address space
+def get_aspace(space):
+  space_map = {
+      ".global" : 1,
+      ".shared" : 3,
+      ".const"  : 4,
+      ".local"  : 5,
+      ".param"  : 101,
+      ""        : 0,
+      ".generic": 0
+  }
+  return space_map[space];
+
+def get_pspace(space):
+  return "p%di8" % get_aspace(space);
+
 # Convenient test patterns.
 check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
 check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
@@ -22,28 +38,28 @@ check_f32_8 = "{{%s}}" % ", *".join(["%f
 
 def gen_wmma_load_tests():
   load_template = """
-declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});
+declare ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
 
 ; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
-define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
+define ${ret_ty} @test_wmma_load_${function_suffix}(i8 ${as}* %src ${extra_args}) {
 ; CHECK wmma.load.${intrinsic_suffix}
 ; CHECK: {${check_result}}
 ; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
-  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
+  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
   ret ${ret_ty} %v0;
 }
 
 ; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
-define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
+define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_args}) {
 ; CHECK wmma.load.${intrinsic_suffix}
 ; CHECK: {${check_result}}
 ; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
-  %src1 = getelementptr i8, i8* %src, i32 128;
-  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
+  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
+  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src1 ${extra_args});
   ret ${ret_ty} %v0;
 }
 """
-  suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
+  suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
   instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
 
   for abc, layout, space, stride, itype in product(
@@ -58,7 +74,9 @@ define ${ret_ty} @test_wmma_load_${funct
         "layout" : layout,
         "space" : space,
         "stride" : stride,
-        "itype" : itype
+        "itype" : itype,
+        "pspace" : get_pspace(space),
+        "as"     : "addrspace(%d)" % get_aspace(space)
     }
 
     if itype == "f32" and abc != "c":
@@ -89,28 +107,28 @@ def make_wmma_slice_args(itype, abcd, pr
 
 def gen_wmma_store_tests():
   store_template = """
-declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});
+declare void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args}${extra_args});
 
 ; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
-define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
+define void @test_wmma_store_${function_suffix}(i8 ${as}* %src, ${args}${extra_args}) {
 ; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
 ; CHECK: {${check_args}}
 ; CHECK: ${stride_pattern}
-  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
+  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args} ${extra_args});
   ret void
 }
 
 ; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
-define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
+define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra_args}) {
 ; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
 ; CHECK: ${check_args}
 ; CHECK: ${stride_pattern}
-  %src1 = getelementptr i8, i8* %src, i32 128;
-  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
+  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
+  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src1, ${args}${extra_args});
   ret void
 }
 """
-  suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
+  suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
   instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
 
   for abc, layout, space, stride, itype in product(
@@ -125,7 +143,9 @@ define void @test_wmma_store_${function_
         "layout" : layout,
         "space" : space,
         "stride" : stride,
-        "itype" : itype
+        "itype" : itype,
+        "pspace" : get_pspace(space),
+        "as"     : "addrspace(%d)" % get_aspace(space)
     }
 
     test_params = params




More information about the llvm-commits mailing list