[llvm] a1f5fe8 - [NVPTX] Optimize v2x16 BUILD_VECTORs to PRMT (#116675)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 17 02:22:23 PST 2024


Author: Fraser Cormack
Date: 2024-12-17T10:22:19Z
New Revision: a1f5fe8c851ba6a0070e4cab9e7436e962677ac6

URL: https://github.com/llvm/llvm-project/commit/a1f5fe8c851ba6a0070e4cab9e7436e962677ac6
DIFF: https://github.com/llvm/llvm-project/commit/a1f5fe8c851ba6a0070e4cab9e7436e962677ac6.diff

LOG: [NVPTX] Optimize v2x16 BUILD_VECTORs to PRMT (#116675)

When two 16-bit values are combined into a v2x16 vector, and those
values are truncated from 32-bit values, a PRMT instruction can
save registers by selecting bytes directly from the original 32-bit
values. We do this during a post-legalize DAG combine, as these
opportunities are typically only exposed after the BUILD_VECTOR's
operands have been legalized.

Additionally, if the 32-bit values are right-shifted, we can fold in the
shift by selecting higher bytes with PRMT. Only logical right-shifts by
16 are supported (for now) since those are the only situations seen in
practice. Right shifts by 16 often come up during the legalization of
EXTRACT_VECTOR_ELT.

This idea was brought up in a PR comment by @Artem-B.

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/test/CodeGen/NVPTX/bf16-instructions.ll
    llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
    llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
    llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
    llvm/test/CodeGen/NVPTX/i16x2-instructions.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index a02607efb7fc28..7aa8b6ff55e8a7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -767,7 +767,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
-                       ISD::VSELECT});
+                       ISD::VSELECT, ISD::BUILD_VECTOR});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5120,6 +5120,66 @@ static SDValue PerformLOADCombine(SDNode *N,
       DL);
 }
 
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  auto VT = N->getValueType(0);
+  if (!DCI.isAfterLegalizeDAG() || !Isv2x16VT(VT))
+    return SDValue(); // Only v2x16 results, and only after DAG legalization.
+
+  auto Op0 = N->getOperand(0);
+  auto Op1 = N->getOperand(1);
+
+  // PRMT byte selectors, one per operand. Start out by assuming we want the
+  // lower 2 bytes of each i32 operand: 0x10 for Op0 and 0x54 for Op1.
+  uint64_t Op0Bytes = 0x10;
+  uint64_t Op1Bytes = 0x54;
+
+  std::pair<SDValue *, uint64_t *> OpData[2] = {{&Op0, &Op0Bytes},
+                                                {&Op1, &Op1Bytes}};
+
+  // Check that each operand is an i16, truncated from an i32 operand. We'll
+  // select individual bytes from those original operands. Optionally, fold in a
+  // shift right of that original operand. Op0/Op1 are rewritten in place.
+  for (auto &[Op, OpBytes] : OpData) {
+    // Look through a single bitcast on the operand, if present
+    if (Op->getOpcode() == ISD::BITCAST)
+      *Op = Op->getOperand(0);
+
+    if (!(Op->getValueType() == MVT::i16 && Op->getOpcode() == ISD::TRUNCATE &&
+          Op->getOperand(0).getValueType() == MVT::i32))
+      return SDValue(); // Any non-matching operand abandons the combine.
+
+    // If the truncate has multiple uses, this optimization can increase
+    // register pressure, so bail out rather than add a PRMT alongside it
+    if (!Op->hasOneUse())
+      return SDValue();
+
+    *Op = Op->getOperand(0); // Step from the truncate to its i32 source.
+
+    // Optionally, fold in a shift-right of the original operand and let permute
+    // pick the two higher bytes of the original value directly.
+    if (Op->getOpcode() == ISD::SRL && isa<ConstantSDNode>(Op->getOperand(1))) {
+      if (cast<ConstantSDNode>(Op->getOperand(1))->getZExtValue() == 16) {
+        // Shift the PRMT byte selector to pick upper bytes from each respective
+        // value, instead of the lower ones: 0x10 -> 0x32, 0x54 -> 0x76
+        assert((*OpBytes == 0x10 || *OpBytes == 0x54) &&
+               "PRMT selector values out of range");
+        *OpBytes += 0x22;
+        *Op = Op->getOperand(0); // Use the unshifted value as PRMT input.
+      }
+    }
+  }
+
+  SDLoc DL(N);
+  auto &DAG = DCI.DAG;
+
+  auto PRMT = DAG.getNode(
+      NVPTXISD::PRMT, DL, MVT::v4i8,
+      {Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
+       DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
+  return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5154,6 +5214,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformEXTRACTCombine(N, DCI);
     case ISD::VSELECT:
       return PerformVSELECTCombine(N, DCI);
+    case ISD::BUILD_VECTOR:
+      return PerformBUILD_VECTORCombine(N, DCI);
   }
   return SDValue();
 }

diff  --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 6bf834ca832c27..08ed317ef93007 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -159,8 +159,8 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-LABEL: test_faddx2(
 ; SM70:       {
 ; SM70-NEXT:    .reg .pred %p<3>;
-; SM70-NEXT:    .reg .b16 %rs<13>;
-; SM70-NEXT:    .reg .b32 %r<24>;
+; SM70-NEXT:    .reg .b16 %rs<9>;
+; SM70-NEXT:    .reg .b32 %r<25>;
 ; SM70-NEXT:    .reg .f32 %f<7>;
 ; SM70-EMPTY:
 ; SM70-NEXT:  // %bb.0:
@@ -182,7 +182,6 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-NEXT:    setp.nan.f32 %p1, %f3, %f3;
 ; SM70-NEXT:    or.b32 %r11, %r7, 4194304;
 ; SM70-NEXT:    selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
 ; SM70-NEXT:    cvt.u32.u16 %r13, %rs1;
 ; SM70-NEXT:    shl.b32 %r14, %r13, 16;
 ; SM70-NEXT:    mov.b32 %f4, %r14;
@@ -197,8 +196,7 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-NEXT:    setp.nan.f32 %p2, %f6, %f6;
 ; SM70-NEXT:    or.b32 %r21, %r17, 4194304;
 ; SM70-NEXT:    selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT:    mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT:    prmt.b32 %r23, %r22, %r12, 0x7632U;
 ; SM70-NEXT:    st.param.b32 [func_retval0], %r23;
 ; SM70-NEXT:    ret;
 ;
@@ -262,8 +260,8 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-LABEL: test_fsubx2(
 ; SM70:       {
 ; SM70-NEXT:    .reg .pred %p<3>;
-; SM70-NEXT:    .reg .b16 %rs<13>;
-; SM70-NEXT:    .reg .b32 %r<24>;
+; SM70-NEXT:    .reg .b16 %rs<9>;
+; SM70-NEXT:    .reg .b32 %r<25>;
 ; SM70-NEXT:    .reg .f32 %f<7>;
 ; SM70-EMPTY:
 ; SM70-NEXT:  // %bb.0:
@@ -285,7 +283,6 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-NEXT:    setp.nan.f32 %p1, %f3, %f3;
 ; SM70-NEXT:    or.b32 %r11, %r7, 4194304;
 ; SM70-NEXT:    selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
 ; SM70-NEXT:    cvt.u32.u16 %r13, %rs1;
 ; SM70-NEXT:    shl.b32 %r14, %r13, 16;
 ; SM70-NEXT:    mov.b32 %f4, %r14;
@@ -300,8 +297,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-NEXT:    setp.nan.f32 %p2, %f6, %f6;
 ; SM70-NEXT:    or.b32 %r21, %r17, 4194304;
 ; SM70-NEXT:    selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT:    mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT:    prmt.b32 %r23, %r22, %r12, 0x7632U;
 ; SM70-NEXT:    st.param.b32 [func_retval0], %r23;
 ; SM70-NEXT:    ret;
 ;
@@ -365,8 +361,8 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-LABEL: test_fmulx2(
 ; SM70:       {
 ; SM70-NEXT:    .reg .pred %p<3>;
-; SM70-NEXT:    .reg .b16 %rs<13>;
-; SM70-NEXT:    .reg .b32 %r<24>;
+; SM70-NEXT:    .reg .b16 %rs<9>;
+; SM70-NEXT:    .reg .b32 %r<25>;
 ; SM70-NEXT:    .reg .f32 %f<7>;
 ; SM70-EMPTY:
 ; SM70-NEXT:  // %bb.0:
@@ -388,7 +384,6 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-NEXT:    setp.nan.f32 %p1, %f3, %f3;
 ; SM70-NEXT:    or.b32 %r11, %r7, 4194304;
 ; SM70-NEXT:    selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
 ; SM70-NEXT:    cvt.u32.u16 %r13, %rs1;
 ; SM70-NEXT:    shl.b32 %r14, %r13, 16;
 ; SM70-NEXT:    mov.b32 %f4, %r14;
@@ -403,8 +398,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-NEXT:    setp.nan.f32 %p2, %f6, %f6;
 ; SM70-NEXT:    or.b32 %r21, %r17, 4194304;
 ; SM70-NEXT:    selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT:    mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT:    prmt.b32 %r23, %r22, %r12, 0x7632U;
 ; SM70-NEXT:    st.param.b32 [func_retval0], %r23;
 ; SM70-NEXT:    ret;
 ;
@@ -468,8 +462,8 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-LABEL: test_fdiv(
 ; SM70:       {
 ; SM70-NEXT:    .reg .pred %p<3>;
-; SM70-NEXT:    .reg .b16 %rs<13>;
-; SM70-NEXT:    .reg .b32 %r<24>;
+; SM70-NEXT:    .reg .b16 %rs<9>;
+; SM70-NEXT:    .reg .b32 %r<25>;
 ; SM70-NEXT:    .reg .f32 %f<7>;
 ; SM70-EMPTY:
 ; SM70-NEXT:  // %bb.0:
@@ -491,7 +485,6 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-NEXT:    setp.nan.f32 %p1, %f3, %f3;
 ; SM70-NEXT:    or.b32 %r11, %r7, 4194304;
 ; SM70-NEXT:    selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
 ; SM70-NEXT:    cvt.u32.u16 %r13, %rs1;
 ; SM70-NEXT:    shl.b32 %r14, %r13, 16;
 ; SM70-NEXT:    mov.b32 %f4, %r14;
@@ -506,8 +499,7 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM70-NEXT:    setp.nan.f32 %p2, %f6, %f6;
 ; SM70-NEXT:    or.b32 %r21, %r17, 4194304;
 ; SM70-NEXT:    selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT:    mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT:    prmt.b32 %r23, %r22, %r12, 0x7632U;
 ; SM70-NEXT:    st.param.b32 [func_retval0], %r23;
 ; SM70-NEXT:    ret;
 ;
@@ -1706,8 +1698,8 @@ define <2 x bfloat> @test_maxnum_v2(<2 x bfloat> %a, <2 x bfloat> %b) {
 ; SM70-LABEL: test_maxnum_v2(
 ; SM70:       {
 ; SM70-NEXT:    .reg .pred %p<3>;
-; SM70-NEXT:    .reg .b16 %rs<13>;
-; SM70-NEXT:    .reg .b32 %r<24>;
+; SM70-NEXT:    .reg .b16 %rs<9>;
+; SM70-NEXT:    .reg .b32 %r<25>;
 ; SM70-NEXT:    .reg .f32 %f<7>;
 ; SM70-EMPTY:
 ; SM70-NEXT:  // %bb.0:
@@ -1729,7 +1721,6 @@ define <2 x bfloat> @test_maxnum_v2(<2 x bfloat> %a, <2 x bfloat> %b) {
 ; SM70-NEXT:    setp.nan.f32 %p1, %f3, %f3;
 ; SM70-NEXT:    or.b32 %r11, %r7, 4194304;
 ; SM70-NEXT:    selp.b32 %r12, %r11, %r10, %p1;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs7}, %r12; }
 ; SM70-NEXT:    cvt.u32.u16 %r13, %rs1;
 ; SM70-NEXT:    shl.b32 %r14, %r13, 16;
 ; SM70-NEXT:    mov.b32 %f4, %r14;
@@ -1744,8 +1735,7 @@ define <2 x bfloat> @test_maxnum_v2(<2 x bfloat> %a, <2 x bfloat> %b) {
 ; SM70-NEXT:    setp.nan.f32 %p2, %f6, %f6;
 ; SM70-NEXT:    or.b32 %r21, %r17, 4194304;
 ; SM70-NEXT:    selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs11}, %r22; }
-; SM70-NEXT:    mov.b32 %r23, {%rs11, %rs7};
+; SM70-NEXT:    prmt.b32 %r23, %r22, %r12, 0x7632U;
 ; SM70-NEXT:    st.param.b32 [func_retval0], %r23;
 ; SM70-NEXT:    ret;
 ;

diff  --git a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
index 498821163686b9..fde2255a783434 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-contract.ll
@@ -1046,8 +1046,8 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
 ; CHECK-SM70-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
 ; CHECK-SM70:       {
 ; CHECK-SM70-NEXT:    .reg .pred %p<9>;
-; CHECK-SM70-NEXT:    .reg .b16 %rs<25>;
-; CHECK-SM70-NEXT:    .reg .b32 %r<61>;
+; CHECK-SM70-NEXT:    .reg .b16 %rs<21>;
+; CHECK-SM70-NEXT:    .reg .b32 %r<62>;
 ; CHECK-SM70-NEXT:    .reg .f32 %f<19>;
 ; CHECK-SM70-EMPTY:
 ; CHECK-SM70-NEXT:  // %bb.0:
@@ -1130,7 +1130,6 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p7, %f15, %f15;
 ; CHECK-SM70-NEXT:    or.b32 %r49, %r45, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r50, %r49, %r48, %p7;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs20}, %r50; }
 ; CHECK-SM70-NEXT:    cvt.u32.u16 %r51, %rs17;
 ; CHECK-SM70-NEXT:    shl.b32 %r52, %r51, 16;
 ; CHECK-SM70-NEXT:    mov.b32 %f16, %r52;
@@ -1144,8 +1143,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p8, %f18, %f18;
 ; CHECK-SM70-NEXT:    or.b32 %r58, %r54, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r59, %r58, %r57, %p8;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs23}, %r59; }
-; CHECK-SM70-NEXT:    mov.b32 %r60, {%rs23, %rs20};
+; CHECK-SM70-NEXT:    prmt.b32 %r60, %r59, %r50, 0x7632U;
 ; CHECK-SM70-NEXT:    st.param.b32 [func_retval0], %r60;
 ; CHECK-SM70-NEXT:    ret;
   %1 = fmul <2 x bfloat> %a, %b
@@ -1185,8 +1183,8 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
 ; CHECK-SM70-LABEL: fma_bf16x2_expanded_maxnum_no_nans(
 ; CHECK-SM70:       {
 ; CHECK-SM70-NEXT:    .reg .pred %p<5>;
-; CHECK-SM70-NEXT:    .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT:    .reg .b32 %r<43>;
+; CHECK-SM70-NEXT:    .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT:    .reg .b32 %r<44>;
 ; CHECK-SM70-NEXT:    .reg .f32 %f<13>;
 ; CHECK-SM70-EMPTY:
 ; CHECK-SM70-NEXT:  // %bb.0:
@@ -1240,7 +1238,6 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p3, %f10, %f10;
 ; CHECK-SM70-NEXT:    or.b32 %r33, %r29, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r34, %r33, %r32, %p3;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r34; }
 ; CHECK-SM70-NEXT:    and.b32 %r35, %r15, -65536;
 ; CHECK-SM70-NEXT:    mov.b32 %f11, %r35;
 ; CHECK-SM70-NEXT:    max.f32 %f12, %f11, 0f00000000;
@@ -1251,8 +1248,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p4, %f12, %f12;
 ; CHECK-SM70-NEXT:    or.b32 %r40, %r36, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r41, %r40, %r39, %p4;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r41; }
-; CHECK-SM70-NEXT:    mov.b32 %r42, {%rs15, %rs13};
+; CHECK-SM70-NEXT:    prmt.b32 %r42, %r41, %r34, 0x7632U;
 ; CHECK-SM70-NEXT:    st.param.b32 [func_retval0], %r42;
 ; CHECK-SM70-NEXT:    ret;
   %1 = fmul <2 x bfloat> %a, %b

diff  --git a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
index 64e037c1527fc7..723c72165f7858 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-fma-intrinsic.ll
@@ -711,8 +711,8 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
 ; CHECK-SM70-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
 ; CHECK-SM70:       {
 ; CHECK-SM70-NEXT:    .reg .pred %p<7>;
-; CHECK-SM70-NEXT:    .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT:    .reg .b32 %r<57>;
+; CHECK-SM70-NEXT:    .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT:    .reg .b32 %r<58>;
 ; CHECK-SM70-NEXT:    .reg .f32 %f<17>;
 ; CHECK-SM70-EMPTY:
 ; CHECK-SM70-NEXT:  // %bb.0:
@@ -786,7 +786,6 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p5, %f14, %f14;
 ; CHECK-SM70-NEXT:    or.b32 %r47, %r43, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r48, %r47, %r46, %p5;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r48; }
 ; CHECK-SM70-NEXT:    and.b32 %r49, %r34, -65536;
 ; CHECK-SM70-NEXT:    mov.b32 %f15, %r49;
 ; CHECK-SM70-NEXT:    add.f32 %f16, %f15, %f9;
@@ -797,8 +796,7 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p6, %f16, %f16;
 ; CHECK-SM70-NEXT:    or.b32 %r54, %r50, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r55, %r54, %r53, %p6;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r55; }
-; CHECK-SM70-NEXT:    mov.b32 %r56, {%rs15, %rs13};
+; CHECK-SM70-NEXT:    prmt.b32 %r56, %r55, %r48, 0x7632U;
 ; CHECK-SM70-NEXT:    st.param.b32 [func_retval0], %r56;
 ; CHECK-SM70-NEXT:    ret;
   %1 = call <2 x bfloat> @llvm.fma.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
@@ -837,8 +835,8 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
 ; CHECK-SM70-LABEL: fma_bf16x2_maxnum_no_nans(
 ; CHECK-SM70:       {
 ; CHECK-SM70-NEXT:    .reg .pred %p<5>;
-; CHECK-SM70-NEXT:    .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT:    .reg .b32 %r<43>;
+; CHECK-SM70-NEXT:    .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT:    .reg .b32 %r<44>;
 ; CHECK-SM70-NEXT:    .reg .f32 %f<13>;
 ; CHECK-SM70-EMPTY:
 ; CHECK-SM70-NEXT:  // %bb.0:
@@ -892,7 +890,6 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p3, %f10, %f10;
 ; CHECK-SM70-NEXT:    or.b32 %r33, %r29, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r34, %r33, %r32, %p3;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r34; }
 ; CHECK-SM70-NEXT:    and.b32 %r35, %r15, -65536;
 ; CHECK-SM70-NEXT:    mov.b32 %f11, %r35;
 ; CHECK-SM70-NEXT:    max.f32 %f12, %f11, 0f00000000;
@@ -903,8 +900,7 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p4, %f12, %f12;
 ; CHECK-SM70-NEXT:    or.b32 %r40, %r36, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r41, %r40, %r39, %p4;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r41; }
-; CHECK-SM70-NEXT:    mov.b32 %r42, {%rs15, %rs13};
+; CHECK-SM70-NEXT:    prmt.b32 %r42, %r41, %r34, 0x7632U;
 ; CHECK-SM70-NEXT:    st.param.b32 [func_retval0], %r42;
 ; CHECK-SM70-NEXT:    ret;
   %1 = call <2 x bfloat> @llvm.fma.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)

diff  --git a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
index eb009247a5e092..ce35d9066a4759 100644
--- a/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
+++ b/llvm/test/CodeGen/NVPTX/fma-relu-instruction-flag.ll
@@ -781,8 +781,8 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
 ; CHECK-SM70-LABEL: fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(
 ; CHECK-SM70:       {
 ; CHECK-SM70-NEXT:    .reg .pred %p<9>;
-; CHECK-SM70-NEXT:    .reg .b16 %rs<25>;
-; CHECK-SM70-NEXT:    .reg .b32 %r<61>;
+; CHECK-SM70-NEXT:    .reg .b16 %rs<21>;
+; CHECK-SM70-NEXT:    .reg .b32 %r<62>;
 ; CHECK-SM70-NEXT:    .reg .f32 %f<19>;
 ; CHECK-SM70-EMPTY:
 ; CHECK-SM70-NEXT:  // %bb.0:
@@ -865,7 +865,6 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p7, %f15, %f15;
 ; CHECK-SM70-NEXT:    or.b32 %r49, %r45, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r50, %r49, %r48, %p7;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs20}, %r50; }
 ; CHECK-SM70-NEXT:    cvt.u32.u16 %r51, %rs17;
 ; CHECK-SM70-NEXT:    shl.b32 %r52, %r51, 16;
 ; CHECK-SM70-NEXT:    mov.b32 %f16, %r52;
@@ -879,8 +878,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_no_nans_multiple_uses_of_fma(<2 x bfloa
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p8, %f18, %f18;
 ; CHECK-SM70-NEXT:    or.b32 %r58, %r54, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r59, %r58, %r57, %p8;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs23}, %r59; }
-; CHECK-SM70-NEXT:    mov.b32 %r60, {%rs23, %rs20};
+; CHECK-SM70-NEXT:    prmt.b32 %r60, %r59, %r50, 0x7632U;
 ; CHECK-SM70-NEXT:    st.param.b32 [func_retval0], %r60;
 ; CHECK-SM70-NEXT:    ret;
   %1 = fmul fast <2 x bfloat> %a, %b
@@ -920,8 +918,8 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
 ; CHECK-SM70-LABEL: fma_bf16x2_expanded_maxnum_no_nans(
 ; CHECK-SM70:       {
 ; CHECK-SM70-NEXT:    .reg .pred %p<5>;
-; CHECK-SM70-NEXT:    .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT:    .reg .b32 %r<43>;
+; CHECK-SM70-NEXT:    .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT:    .reg .b32 %r<44>;
 ; CHECK-SM70-NEXT:    .reg .f32 %f<13>;
 ; CHECK-SM70-EMPTY:
 ; CHECK-SM70-NEXT:  // %bb.0:
@@ -975,7 +973,6 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p3, %f10, %f10;
 ; CHECK-SM70-NEXT:    or.b32 %r33, %r29, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r34, %r33, %r32, %p3;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r34; }
 ; CHECK-SM70-NEXT:    and.b32 %r35, %r15, -65536;
 ; CHECK-SM70-NEXT:    mov.b32 %f11, %r35;
 ; CHECK-SM70-NEXT:    max.f32 %f12, %f11, 0f00000000;
@@ -986,8 +983,7 @@ define <2 x bfloat> @fma_bf16x2_expanded_maxnum_no_nans(<2 x bfloat> %a, <2 x bf
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p4, %f12, %f12;
 ; CHECK-SM70-NEXT:    or.b32 %r40, %r36, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r41, %r40, %r39, %p4;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r41; }
-; CHECK-SM70-NEXT:    mov.b32 %r42, {%rs15, %rs13};
+; CHECK-SM70-NEXT:    prmt.b32 %r42, %r41, %r34, 0x7632U;
 ; CHECK-SM70-NEXT:    st.param.b32 [func_retval0], %r42;
 ; CHECK-SM70-NEXT:    ret;
   %1 = fmul fast <2 x bfloat> %a, %b
@@ -1702,8 +1698,8 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
 ; CHECK-SM70-LABEL: fma_bf16x2_no_nans_multiple_uses_of_fma(
 ; CHECK-SM70:       {
 ; CHECK-SM70-NEXT:    .reg .pred %p<7>;
-; CHECK-SM70-NEXT:    .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT:    .reg .b32 %r<57>;
+; CHECK-SM70-NEXT:    .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT:    .reg .b32 %r<58>;
 ; CHECK-SM70-NEXT:    .reg .f32 %f<17>;
 ; CHECK-SM70-EMPTY:
 ; CHECK-SM70-NEXT:  // %bb.0:
@@ -1777,7 +1773,6 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p5, %f14, %f14;
 ; CHECK-SM70-NEXT:    or.b32 %r47, %r43, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r48, %r47, %r46, %p5;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r48; }
 ; CHECK-SM70-NEXT:    and.b32 %r49, %r34, -65536;
 ; CHECK-SM70-NEXT:    mov.b32 %f15, %r49;
 ; CHECK-SM70-NEXT:    add.rn.f32 %f16, %f15, %f9;
@@ -1788,8 +1783,7 @@ define <2 x bfloat> @fma_bf16x2_no_nans_multiple_uses_of_fma(<2 x bfloat> %a, <2
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p6, %f16, %f16;
 ; CHECK-SM70-NEXT:    or.b32 %r54, %r50, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r55, %r54, %r53, %p6;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r55; }
-; CHECK-SM70-NEXT:    mov.b32 %r56, {%rs15, %rs13};
+; CHECK-SM70-NEXT:    prmt.b32 %r56, %r55, %r48, 0x7632U;
 ; CHECK-SM70-NEXT:    st.param.b32 [func_retval0], %r56;
 ; CHECK-SM70-NEXT:    ret;
   %1 = call nnan <2 x bfloat> @llvm.fma.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
@@ -1828,8 +1822,8 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
 ; CHECK-SM70-LABEL: fma_bf16x2_maxnum_no_nans(
 ; CHECK-SM70:       {
 ; CHECK-SM70-NEXT:    .reg .pred %p<5>;
-; CHECK-SM70-NEXT:    .reg .b16 %rs<17>;
-; CHECK-SM70-NEXT:    .reg .b32 %r<43>;
+; CHECK-SM70-NEXT:    .reg .b16 %rs<13>;
+; CHECK-SM70-NEXT:    .reg .b32 %r<44>;
 ; CHECK-SM70-NEXT:    .reg .f32 %f<13>;
 ; CHECK-SM70-EMPTY:
 ; CHECK-SM70-NEXT:  // %bb.0:
@@ -1883,7 +1877,6 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p3, %f10, %f10;
 ; CHECK-SM70-NEXT:    or.b32 %r33, %r29, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r34, %r33, %r32, %p3;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs13}, %r34; }
 ; CHECK-SM70-NEXT:    and.b32 %r35, %r15, -65536;
 ; CHECK-SM70-NEXT:    mov.b32 %f11, %r35;
 ; CHECK-SM70-NEXT:    max.f32 %f12, %f11, 0f00000000;
@@ -1894,8 +1887,7 @@ define <2 x bfloat> @fma_bf16x2_maxnum_no_nans(<2 x bfloat> %a, <2 x bfloat> %b,
 ; CHECK-SM70-NEXT:    setp.nan.f32 %p4, %f12, %f12;
 ; CHECK-SM70-NEXT:    or.b32 %r40, %r36, 4194304;
 ; CHECK-SM70-NEXT:    selp.b32 %r41, %r40, %r39, %p4;
-; CHECK-SM70-NEXT:    { .reg .b16 tmp; mov.b32 {tmp, %rs15}, %r41; }
-; CHECK-SM70-NEXT:    mov.b32 %r42, {%rs15, %rs13};
+; CHECK-SM70-NEXT:    prmt.b32 %r42, %r41, %r34, 0x7632U;
 ; CHECK-SM70-NEXT:    st.param.b32 [func_retval0], %r42;
 ; CHECK-SM70-NEXT:    ret;
   %1 = call nnan <2 x bfloat> @llvm.fma.bf16x2(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)

diff  --git a/llvm/test/CodeGen/NVPTX/i16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i16x2-instructions.ll
index 388bd314801fc7..5d849517096dca 100644
--- a/llvm/test/CodeGen/NVPTX/i16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i16x2-instructions.ll
@@ -807,17 +807,79 @@ define <2 x i16> @test_select_cc_i16_i32(<2 x i16> %a, <2 x i16> %b,
 define <2 x i16> @test_trunc_2xi32(<2 x i32> %a) #0 {
 ; COMMON-LABEL: test_trunc_2xi32(
 ; COMMON:       {
-; COMMON-NEXT:    .reg .b16 %rs<3>;
-; COMMON-NEXT:    .reg .b32 %r<4>;
+; COMMON-NEXT:    .reg .b32 %r<5>;
 ; COMMON-EMPTY:
 ; COMMON-NEXT:  // %bb.0:
 ; COMMON-NEXT:    ld.param.v2.u32 {%r1, %r2}, [test_trunc_2xi32_param_0];
-; COMMON-NEXT:    cvt.u16.u32 %rs1, %r2;
-; COMMON-NEXT:    cvt.u16.u32 %rs2, %r1;
-; COMMON-NEXT:    mov.b32 %r3, {%rs2, %rs1};
+; COMMON-NEXT:    prmt.b32 %r3, %r1, %r2, 0x5410U;
+; COMMON-NEXT:    st.param.b32 [func_retval0], %r3;
+; COMMON-NEXT:    ret;
+  %r = trunc <2 x i32> %a to <2 x i16>
+  ret <2 x i16> %r
+}
+
+define <2 x i16> @test_trunc_2xi32_muliple_use0(<2 x i32> %a, ptr %p) #0 {
+; I16x2-LABEL: test_trunc_2xi32_muliple_use0(
+; I16x2:       {
+; I16x2-NEXT:    .reg .b32 %r<7>;
+; I16x2-NEXT:    .reg .b64 %rd<2>;
+; I16x2-EMPTY:
+; I16x2-NEXT:  // %bb.0:
+; I16x2-NEXT:    ld.param.v2.u32 {%r1, %r2}, [test_trunc_2xi32_muliple_use0_param_0];
+; I16x2-NEXT:    ld.param.u64 %rd1, [test_trunc_2xi32_muliple_use0_param_1];
+; I16x2-NEXT:    prmt.b32 %r3, %r1, %r2, 0x5410U;
+; I16x2-NEXT:    mov.b32 %r5, 65537;
+; I16x2-NEXT:    add.s16x2 %r6, %r3, %r5;
+; I16x2-NEXT:    st.u32 [%rd1], %r6;
+; I16x2-NEXT:    st.param.b32 [func_retval0], %r3;
+; I16x2-NEXT:    ret;
+;
+; NO-I16x2-LABEL: test_trunc_2xi32_muliple_use0(
+; NO-I16x2:       {
+; NO-I16x2-NEXT:    .reg .b16 %rs<5>;
+; NO-I16x2-NEXT:    .reg .b32 %r<5>;
+; NO-I16x2-NEXT:    .reg .b64 %rd<2>;
+; NO-I16x2-EMPTY:
+; NO-I16x2-NEXT:  // %bb.0:
+; NO-I16x2-NEXT:    ld.param.v2.u32 {%r1, %r2}, [test_trunc_2xi32_muliple_use0_param_0];
+; NO-I16x2-NEXT:    ld.param.u64 %rd1, [test_trunc_2xi32_muliple_use0_param_1];
+; NO-I16x2-NEXT:    cvt.u16.u32 %rs1, %r2;
+; NO-I16x2-NEXT:    cvt.u16.u32 %rs2, %r1;
+; NO-I16x2-NEXT:    mov.b32 %r3, {%rs2, %rs1};
+; NO-I16x2-NEXT:    add.s16 %rs3, %rs1, 1;
+; NO-I16x2-NEXT:    add.s16 %rs4, %rs2, 1;
+; NO-I16x2-NEXT:    mov.b32 %r4, {%rs4, %rs3};
+; NO-I16x2-NEXT:    st.u32 [%rd1], %r4;
+; NO-I16x2-NEXT:    st.param.b32 [func_retval0], %r3;
+; NO-I16x2-NEXT:    ret;
+  %r = trunc <2 x i32> %a to <2 x i16>
+  ; Reuse the truncate - optimizing to PRMT when we don't have i16x2 vectors
+  ; would increase register pressure
+  %s = add <2 x i16> %r, splat (i16 1)
+  store <2 x i16> %s, ptr %p
+  ret <2 x i16> %r
+}
+
+define <2 x i16> @test_trunc_2xi32_muliple_use1(<2 x i32> %a, ptr %p) #0 {
+; COMMON-LABEL: test_trunc_2xi32_muliple_use1(
+; COMMON:       {
+; COMMON-NEXT:    .reg .b32 %r<7>;
+; COMMON-NEXT:    .reg .b64 %rd<2>;
+; COMMON-EMPTY:
+; COMMON-NEXT:  // %bb.0:
+; COMMON-NEXT:    ld.param.v2.u32 {%r1, %r2}, [test_trunc_2xi32_muliple_use1_param_0];
+; COMMON-NEXT:    ld.param.u64 %rd1, [test_trunc_2xi32_muliple_use1_param_1];
+; COMMON-NEXT:    prmt.b32 %r3, %r1, %r2, 0x5410U;
+; COMMON-NEXT:    add.s32 %r5, %r2, 1;
+; COMMON-NEXT:    add.s32 %r6, %r1, 1;
+; COMMON-NEXT:    st.v2.u32 [%rd1], {%r6, %r5};
 ; COMMON-NEXT:    st.param.b32 [func_retval0], %r3;
 ; COMMON-NEXT:    ret;
   %r = trunc <2 x i32> %a to <2 x i16>
+  ; Reuse the original value - optimizing to PRMT does not increase register
+  ; pressure
+  %s = add <2 x i32> %a, splat (i32 1)
+  store <2 x i32> %s, ptr %p
   ret <2 x i16> %r
 }
 


        


More information about the llvm-commits mailing list