[llvm] r327672 - [NVPTX] TblGen-ized lowering of WMMA intrinsics.
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 15 14:40:56 PDT 2018
Author: tra
Date: Thu Mar 15 14:40:56 2018
New Revision: 327672
URL: http://llvm.org/viewvc/llvm-project?rev=327672&view=rev
Log:
[NVPTX] TblGen-ized lowering of WMMA intrinsics.
NFC.
Differential Revision: https://reviews.llvm.org/D43151
Modified:
llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp?rev=327672&r1=327671&r2=327672&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp Thu Mar 15 14:40:56 2018
@@ -496,318 +496,8 @@ void NVPTXDAGToDAGISel::Select(SDNode *N
SelectCode(N);
}
-// Each instruction has four addressing variants. WMMA_VARIANTS() macro below
-// constructs an array indexed by WmmaVariant which getWmmaLdVariant() uses to
-// look up the intrinsic ID of particular variant.
-enum WmmaVariant {
- WMMA_VARIANT_ARI64,
- WMMA_VARIANT_ARI64_STRIDE,
- WMMA_VARIANT_AVAR,
- WMMA_VARIANT_AVAR_STRIDE,
-};
-
-// clang-format off
-#define WMMA_VARIANTS(base) \
- {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }}
-// clang-format on
-
-static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride,
- const std::array<unsigned, 4> Variants) {
- if (Stride) {
- if (Variant == WMMA_VARIANT_ARI64)
- Variant = WMMA_VARIANT_ARI64_STRIDE;
- else if (Variant == WMMA_VARIANT_AVAR)
- Variant = WMMA_VARIANT_AVAR_STRIDE;
- }
- return Variants[Variant];
-}
-
-static Optional<unsigned>
-getWmmaLdStOpcode(unsigned IntrinsicID,
- WmmaVariant Variant = WMMA_VARIANT_ARI64) {
- switch (IntrinsicID) {
- default:
- return None;
- //
- // WMMA_LOAD_A f16
- //
- case Intrinsic::nvvm_wmma_load_a_f16_col:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
- case Intrinsic::nvvm_wmma_load_a_f16_row:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
- case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
- case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
- case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
- case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
- case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
- case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
- case Intrinsic::nvvm_wmma_load_a_f16_col_global:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
- case Intrinsic::nvvm_wmma_load_a_f16_row_global:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
- case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
- case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
-
- //
- // WMMA_LOAD_B f16
- //
- case Intrinsic::nvvm_wmma_load_b_f16_col:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
- case Intrinsic::nvvm_wmma_load_b_f16_row:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
- case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
- case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
- case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
- case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
- case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
- case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
- case Intrinsic::nvvm_wmma_load_b_f16_col_global:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
- case Intrinsic::nvvm_wmma_load_b_f16_row_global:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
- case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
- case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
-
- //
- // WMMA_LOAD_C f16
- //
- case Intrinsic::nvvm_wmma_load_c_f16_col:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
- case Intrinsic::nvvm_wmma_load_c_f16_row:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
- case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
- case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
- case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
- case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
- case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
- case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
- case Intrinsic::nvvm_wmma_load_c_f16_col_global:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
- case Intrinsic::nvvm_wmma_load_c_f16_row_global:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
- case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
- case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
-
- //
- // WMMA_LOAD_C f32
- //
- case Intrinsic::nvvm_wmma_load_c_f32_col:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
- case Intrinsic::nvvm_wmma_load_c_f32_row:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
- case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
- case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
- case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
- case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
- case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
- case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
- case Intrinsic::nvvm_wmma_load_c_f32_col_global:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
- case Intrinsic::nvvm_wmma_load_c_f32_row_global:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
- case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
- case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
-
- //
- // WMMA_STORE_D f16
- //
- case Intrinsic::nvvm_wmma_store_d_f16_col:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
- case Intrinsic::nvvm_wmma_store_d_f16_row:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
- case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
- case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
- case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
- case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
- case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
- case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
- case Intrinsic::nvvm_wmma_store_d_f16_col_global:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
- case Intrinsic::nvvm_wmma_store_d_f16_row_global:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
- case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
- case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
-
- //
- // WMMA_STORE_D f32
- //
- case Intrinsic::nvvm_wmma_store_d_f32_col:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
- case Intrinsic::nvvm_wmma_store_d_f32_row:
- return getWmmaLdVariant(Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
- case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
- case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
- return getWmmaLdVariant(Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
- case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
- case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
- case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
- case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
- case Intrinsic::nvvm_wmma_store_d_f32_col_global:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
- case Intrinsic::nvvm_wmma_store_d_f32_row_global:
- return getWmmaLdVariant(
- Variant, /*Stride=*/false,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
- case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
- case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride:
- return getWmmaLdVariant(
- Variant, /*Stride=*/true,
- WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
- }
-}
-#undef WMMA_VARIANTS
-
bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
- if (getWmmaLdStOpcode(IID))
- return tryWMMA_LDST(N);
-
switch (IID) {
default:
return false;
@@ -1026,39 +716,6 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoCh
case Intrinsic::nvvm_texsurf_handle_internal:
SelectTexSurfHandle(N);
return true;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
- return tryWMMA_MMA(N);
}
}
@@ -3946,6 +3603,12 @@ bool NVPTXDAGToDAGISel::SelectADDRri64(S
return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64);
}
+// symbol
+bool NVPTXDAGToDAGISel::SelectADDRvar(SDNode *OpNode, SDValue Addr,
+ SDValue &Value) {
+ return SelectDirectAddr(Addr, Value);
+}
+
bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N,
unsigned int spN) const {
const Value *Src = nullptr;
@@ -4038,172 +3701,3 @@ unsigned NVPTXDAGToDAGISel::GetConvertOp
}
}
}
-
-bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) {
- SDValue Chain = N->getOperand(0);
- unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
- SDValue Op1 = N->getOperand(2);
- SDValue Addr, Offset, Base;
- Optional<unsigned> Opcode;
- SDLoc DL(N);
- MemSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
- WmmaVariant Variant;
- SmallVector<SDValue, 12> Ops;
- bool isStore = N->getNumValues() == 1; // Store ops only return a chain.
-
- if (SelectDirectAddr(Op1, Addr)) {
- Variant = WMMA_VARIANT_AVAR;
- Ops.push_back(Addr);
- } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) ||
- SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) {
- Variant = WMMA_VARIANT_ARI64;
- Ops.push_back(Base);
- Ops.push_back(Offset);
- } else {
- Variant = WMMA_VARIANT_AVAR;
- Ops.push_back(Op1);
- }
- unsigned NumOps = N->getNumOperands();
- // Pass through the rest of the operands to the machine node.
- for (unsigned i = 3; i < NumOps; ++i)
- Ops.push_back(N->getOperand(i));
- Ops.push_back(Chain);
-
- Opcode = getWmmaLdStOpcode(IID, Variant);
- if (!Opcode) {
- llvm::errs() << "tryWMMALD - no Opcode.\n";
- return false;
- }
-
- EVT MemVT = MemSD->getMemoryVT();
- assert(MemVT.isVector() && "Expected vector return type.");
-
- SDNode *MN;
- if (isStore) {
- MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops);
- } else {
- SmallVector<EVT, 9> InstVTs(MemVT.getVectorNumElements(),
- MemSD->getValueType(0));
- InstVTs.push_back(MVT::Other);
- MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops);
- }
-
- ReplaceNode(N, MN);
- return true;
-}
-
-bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) {
- unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
- SDLoc DL(N);
- unsigned Opc;
-
- switch (IID) {
- default:
- return false;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
- Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
- Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
- Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
- Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
- Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
- Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
- Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
- Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
- Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
- Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
- Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
- Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
- Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
- Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
- Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
- Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32;
- break;
- case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
- Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite;
- break;
- }
-
- SmallVector<SDValue, 24> Ops;
- // Pass through operands and return value types to the machine node.
- for (unsigned i = 1; i < N->getNumOperands(); ++i)
- Ops.push_back(N->getOperand(i));
- SmallVector<EVT, 8> InstVTs(N->getNumValues(), N->getValueType(0));
- SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops);
- ReplaceNode(N, MN);
- return true;
-}
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h?rev=327672&r1=327671&r2=327672&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h Thu Mar 15 14:40:56 2018
@@ -89,13 +89,13 @@ private:
SDValue &Offset);
bool SelectADDRri64(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset);
-
bool SelectADDRsi_imp(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset, MVT mvt);
bool SelectADDRsi(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset);
bool SelectADDRsi64(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset);
+ bool SelectADDRvar(SDNode *OpNode, SDValue Addr, SDValue &Value);
bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=327672&r1=327671&r2=327672&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Thu Mar 15 14:40:56 2018
@@ -3410,7 +3410,7 @@ bool NVPTXTargetLowering::getTgtMemIntri
case Intrinsic::nvvm_wmma_store_d_f16_row_global:
case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
- Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
@@ -3431,7 +3431,7 @@ bool NVPTXTargetLowering::getTgtMemIntri
case Intrinsic::nvvm_wmma_store_d_f32_row_global:
case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
- Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td?rev=327672&r1=327671&r2=327672&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td Thu Mar 15 14:40:56 2018
@@ -1527,6 +1527,7 @@ def ADDRri : ComplexPattern<i32, 2, "Sel
[SDNPWantRoot]>;
def ADDRri64 : ComplexPattern<i64, 2, "SelectADDRri64", [frameindex],
[SDNPWantRoot]>;
+def ADDRvar : ComplexPattern<iPTR, 1, "SelectDirectAddr", [], []>;
def MEMri : Operand<i32> {
let PrintMethod = "printMemOperand";
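The ADDRvar ComplexPattern added above is a one-operand address pattern whose operand is produced by the existing SelectDirectAddr() selector in NVPTXDAGToDAGISel, so TableGen patterns can match a direct symbol address without hand-written C++ glue. A minimal sketch of how such a pattern might look follows; the instruction name EXAMPLE_LD_avar, the intrinsic int_example_ld and the asm string are placeholders for illustration only, not definitions from this patch.

  // Hypothetical use of ADDRvar (illustration only): the address operand is
  // matched by calling SelectDirectAddr() and bound to $src.
  def EXAMPLE_LD_avar :
    NVPTXInst<(outs Int32Regs:$dst), (ins imem:$src),
              "example.ld \t$dst, [$src];",
              [(set Int32Regs:$dst, (int_example_ld ADDRvar:$src))]>;

The WMMA load/store definitions below build equivalent patterns automatically by substituting ADDRvar, ADDRri and ADDRri64 for the corresponding memory operands.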
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td?rev=327672&r1=327671&r2=327672&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td Thu Mar 15 14:40:56 2018
@@ -7372,44 +7372,73 @@ def INT_PTX_SREG_WARPSIZE :
//
// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
+
+class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
+
class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass,
- Operand SrcOp, int WithOffset, int WithStride>
- : NVPTXInst<!if(!eq(Abc#Type,"cf16"),
- (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
- (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
- regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)),
- !if(WithStride,
- !if(WithOffset,
- (ins SrcOp:$src, i32imm:$offset, Int32Regs:$ldm),
- (ins SrcOp:$src, Int32Regs:$ldm)),
- !if(WithOffset,
- (ins SrcOp:$src, i32imm:$offset),
- (ins SrcOp:$src))),
- "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
+ DAGOperand SrcOp, bit WithStride>
+ : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ // Intrinsic that matches this instruction.
+ Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
+ # Abc
+ # "_" # Type
+ # "_" # Layout
+ # !subst(".","_",Space)
+ # !if(WithStride,"_stride", ""));
+ dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
+ dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
+ dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
+
+ dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
+ dag Ins = !con((ins SrcOp:$src), StrideArg);
+
+ // Build a dag pattern that matches the intrinsic call.
+ // We want a dag that looks like this:
+ // (set <output args>, (intrinsic <input arguments>)) where input and
+ // output arguments are named patterns that would match corresponding
+ // input/output arguments of the instruction.
+ //
+ // First we construct (set <output arguments>) from instruction's outs dag by
+ // replacing dag operator 'outs' with 'set'.
+ dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
+ // Similarly, construct (intrinsic <input arguments>) sub-dag from
+ // instruction's input arguments, only now we also need to replace operands
+ // with patterns that would match them and the operator 'ins' with the
+ // intrinsic.
+ dag PatArgs = !foreach(tmp, Ins,
+ !subst(imem, ADDRvar,
+ !subst(MEMri64, ADDRri64,
+ !subst(MEMri, ADDRri,
+ !subst(ins, Intr, tmp)))));
+  // Finally, concatenate both parts together. !con() requires both dags to have
+ // the same operator, so we wrap PatArgs in a (set ...) dag.
+ let Pattern = [!con(PatOuts, (set PatArgs))];
+ let OutOperandList = Outs;
+ let InOperandList = Ins;
+ let AsmString = "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
#!if(!eq(Abc#Type,"cf16"),
"{{$r0, $r1, $r2, $r3}}",
"{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
- #", "
- #!if(WithOffset,"[$src+$offset]", "[$src]")
+ #", [$src]"
#!if(WithStride, ", $ldm", "")
- #";",
- []>,
- Requires<[hasPTX60, hasSM70]>;
+ #";";
+}
multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass,
- Operand SrcOp, int WithOffset = 0> {
- def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
- WithOffset, 1>;
- def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
- WithOffset, 0>;
+ DAGOperand SrcOp> {
+ def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 1>;
+ def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 0>;
}
multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass> {
- defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 0>;
- defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 1>;
+ defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imem>;
+ defm _areg: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int32Regs>;
+ defm _areg64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int64Regs>;
+ defm _ari: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri>;
+ defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri64>;
}
multiclass WMMA_LOAD_ALT<string Abc, string Layout,
@@ -7434,62 +7463,58 @@ defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"
//
class WMMA_STORE_D_LSTOS<string Layout, string Space,
string Type, NVPTXRegClass regclass,
- Operand DstOp, int WithOffset, int WithStride>
- : NVPTXInst<(outs),
- !if(!eq(Type,"f16"),
- !if(WithStride,
- !if(WithOffset,
- (ins DstOp:$src, i32imm:$offset,
- regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
- Int32Regs:$ldm),
- (ins DstOp:$src,
- regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
- Int32Regs:$ldm)),
- !if(WithOffset,
- (ins DstOp:$src, i32imm:$offset,
- regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
- (ins DstOp:$src,
- regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))),
- !if(WithStride,
- !if(WithOffset,
- (ins DstOp:$src, i32imm:$offset,
- regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
- regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
- Int32Regs:$ldm),
- (ins DstOp:$src,
- regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
- regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
- Int32Regs:$ldm)),
- !if(WithOffset,
- (ins DstOp:$src, i32imm:$offset,
- regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
- regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7),
- (ins DstOp:$src,
- regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
- regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))),
- "wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
- #!if(WithOffset,"[$src+$offset], ", "[$src], ")
- #!if(!eq(Type,"f16"),
- "{{$r0, $r1, $r2, $r3}}",
- "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
- #!if(WithStride, ", $ldm", "")
- #";",
- []>,
- Requires<[hasPTX60, hasSM70]>;
+ DAGOperand DstOp, bit WithStride>
+ : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d_"
+ # Type
+ # "_" # Layout
+ # !subst(".","_",Space)
+ # !if(WithStride,"_stride", ""));
+
+ dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
+ dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
+ dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47));
+ dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
+ dag Ins = !con(InsR, StrideArg);
+
+ // Construct the pattern to match corresponding intrinsic call. See the
+ // details in the comments in WMMA_LOAD_ALSTOS.
+ dag PatArgs = !foreach(tmp, Ins,
+ !subst(imem, ADDRvar,
+ !subst(MEMri64, ADDRri64,
+ !subst(MEMri, ADDRri,
+ !subst(ins, Intr, tmp)))));
+ let Pattern = [PatArgs];
+ let OutOperandList = (outs);
+ let InOperandList = Ins;
+ let AsmString = "wmma.store.d.sync."
+ # Layout
+ # ".m16n16k16"
+ # Space
+ # "." # Type
+ # " \t[$src],"
+ # !if(!eq(Type,"f16"),
+ "{{$r0, $r1, $r2, $r3}}",
+ "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+ # !if(WithStride, ", $ldm", "")
+ # ";";
+
+}
multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
string Type, NVPTXRegClass regclass,
- Operand DstOp, int WithOffset = 0> {
- def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
- WithOffset, 1>;
- def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
- WithOffset, 0>;
+ DAGOperand DstOp> {
+ def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 1>;
+ def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 0>;
}
multiclass WMMA_STORE_D_LST<string Layout, string Space,
string Type, NVPTXRegClass regclass> {
- defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 0>;
- defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 1>;
+ defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imem>;
+ defm _areg: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int32Regs>;
+ defm _areg64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int64Regs>;
+ defm _ari: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri>;
+ defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri64>;
}
multiclass WMMA_STORE_D_LT<string Layout,
@@ -7500,8 +7525,8 @@ multiclass WMMA_STORE_D_LT<string Layout
}
multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> {
- defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
- defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
+ defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
+ defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
}
defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
@@ -7513,35 +7538,50 @@ class WMMA_MMA_ABDCS<string ALayout, str
string CType, NVPTXRegClass c_reg,
NVPTXRegClass ab_reg,
string Satfinite = "">
- : NVPTXInst<!if(!eq(DType,"f16"),
- (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
- (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
- d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)),
- !if(!eq(CType,"f16"),
- (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
- ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
- ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
- ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
- c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
- (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
- ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
- ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
- ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
- c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3,
- c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)),
- "wmma.mma.sync."#ALayout#"."#BLayout#".m16n16k16."#
- #DType#"."#CType#Satfinite
- #"\n\t\t"
- #!if(!eq(DType,"f16"),
- "{{$d0, $d1, $d2, $d3}}, \n\t\t",
- "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
- #"{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
- #"{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
- #!if(!eq(CType,"f16"),
- "{{$c0, $c1, $c2, $c3}};",
- "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"),
- []>,
- Requires<[hasPTX60, hasSM70]>;
+ : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_mma_sync_"
+ # ALayout
+ # "_" # BLayout
+ # "_" # DType
+ # "_" # CType
+ # !subst(".","_",Satfinite));
+ dag Outs = !if(!eq(DType,"f16"),
+ (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
+ (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
+ d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7));
+ dag InsExtraCArgs = !if(!eq(CType,"f16"),
+ (ins),
+ (ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7));
+ dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
+ ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
+ ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
+ ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
+ c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
+ InsExtraCArgs);
+
+ // Construct the pattern to match corresponding intrinsic call. See the
+ // details in the comments in WMMA_LOAD_ALSTOS.
+ dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
+ dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp));
+ let Pattern = [!con(PatOuts, (set PatArgs))];
+ let OutOperandList = Outs;
+ let InOperandList = Ins;
+ let AsmString = "wmma.mma.sync."
+ # ALayout
+ # "." # BLayout
+ # ".m16n16k16"
+ # "." # DType
+ # "." # CType
+ # Satfinite # "\n\t\t"
+ # !if(!eq(DType,"f16"),
+ "{{$d0, $d1, $d2, $d3}}, \n\t\t",
+ "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
+ # "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
+ # "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
+ # !if(!eq(CType,"f16"),
+ "{{$c0, $c1, $c2, $c3}};",
+ "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};");
+}
multiclass WMMA_MMA_ABDC<string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg,
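The comments inside WMMA_LOAD_ALSTOS above describe how each instruction's selection pattern is derived from its own operand lists. A stripped-down restatement of that technique, assuming the same operators and ComplexPatterns used in this patch, is sketched below; the class name PatternFromOperands is a placeholder invented for illustration and does not appear in the patch.

  // Illustration only: given an instruction's outs/ins dags and the intrinsic
  // it implements, build the pattern (set <outs>, (Intr <ins>)).
  class PatternFromOperands<Intrinsic Intr, dag Outs, dag Ins> {
    // (outs r0, r1, ...)  -->  (set r0, r1, ...)
    dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
    // (ins src, ldm, ...) -->  (Intr src', ldm, ...), with memory operands
    // replaced by the ComplexPatterns that match them.
    dag PatArgs = !foreach(tmp, Ins,
                    !subst(imem, ADDRvar,
                    !subst(MEMri64, ADDRri64,
                    !subst(MEMri, ADDRri,
                    !subst(ins, Intr, tmp)))));
    // !con() requires both dags to share an operator, so PatArgs is wrapped
    // in a (set ...) dag before concatenation.
    list<dag> Result = [!con(PatOuts, (set PatArgs))];
  }

In WMMA_LOAD_ALSTOS the same three steps are performed inline and the result is assigned to the instruction's Pattern, with Outs and Ins reused as OutOperandList and InOperandList.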