[llvm] [NVPTX] improve lowering for common byte-extraction operations. (PR #66945)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 20 12:14:20 PDT 2023


https://github.com/Artem-B created https://github.com/llvm/llvm-project/pull/66945

Some of our critical code paths depend on efficient extraction of bytes from data loaded as integers.
By default, LLVM extracts such bytes by storing the value to the stack and loading the bytes back, which is very inefficient on GPUs.


>From 21fe3fcb459d5d8ecd68fc50c7ad67eb22ff26c6 Mon Sep 17 00:00:00 2001
From: Artem Belevich <tra at google.com>
Date: Mon, 18 Sep 2023 17:22:15 -0700
Subject: [PATCH] [NVPTX] improve lowering for common byte-extraction
 operations.

Some of our critical code paths depend on efficient extraction of bytes from
data loaded as 32-bit integers.
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 46 +++++++++-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td     | 42 +++++++---
 llvm/test/CodeGen/NVPTX/extractelement.ll   | 93 +++++++++++++++++++++
 3 files changed, 170 insertions(+), 11 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/extractelement.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e4d5e5c71b7e188..e9401d4b93c371e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/Analysis.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/MachineValueType.h"
@@ -672,7 +673,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
-                       ISD::SREM, ISD::UREM});
+                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5252,6 +5253,65 @@ static SDValue PerformSETCCCombine(SDNode *N,
                          CCNode.getValue(1));
 }
 
+/// Rewrite (extract_vector_elt V, C) with a constant, non-zero index C into
+/// integer bit manipulation on V: bitcast V to an integer of the same width,
+/// shift element C down to bit 0 and truncate it to the element width. This
+/// keeps byte extraction in registers instead of letting the legalizer spill
+/// the vector to the stack and reload a single element.
+static SDValue PerformEXTRACTCombine(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue Vector = N->getOperand(0);
+  EVT VectorVT = Vector.getValueType();
+  if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
+      IsPTXVectorType(VectorVT.getSimpleVT()))
+    return SDValue(); // Native vector loads already combine nicely w/
+                      // extract_vector_elt.
+  // Don't mess with singletons or v2*16 types, we already handle them OK.
+  if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT))
+    return SDValue();
+
+  uint64_t VectorBits = VectorVT.getSizeInBits();
+  // We only handle the types we can extract in-register.
+  if (!(VectorBits == 16 || VectorBits == 32 || VectorBits == 64))
+    return SDValue();
+
+  ConstantSDNode *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  // Index == 0 is handled by generic DAG combiner. An out-of-range index
+  // yields an undefined element; bail out rather than build a shift amount
+  // at or beyond the vector width.
+  if (!Index || Index->getZExtValue() == 0 ||
+      Index->getZExtValue() >= VectorVT.getVectorNumElements())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  MVT IVT = MVT::getIntegerVT(VectorBits);
+  EVT EltVT = VectorVT.getVectorElementType();
+  EVT EltIVT = EltVT.changeTypeToInteger();
+  uint64_t EltBits = EltVT.getScalarSizeInBits();
+
+  // Shift the element down to bit 0. Whatever is shifted in at the top is
+  // dropped by the truncate, so using SRA rather than SRL is immaterial.
+  SDValue Result = DCI.DAG.getNode(
+      ISD::TRUNCATE, DL, EltIVT,
+      DCI.DAG.getNode(
+          ISD::SRA, DL, IVT, DCI.DAG.getNode(ISD::BITCAST, DL, IVT, Vector),
+          DCI.DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT)));
+
+  // If element has non-integer type, bitcast it back to the expected type.
+  if (EltVT != EltIVT)
+    Result = DCI.DAG.getNode(ISD::BITCAST, DL, EltVT, Result);
+
+  // EXTRACT_VECTOR_ELT may produce an integer wider than the element type
+  // (e.g. i16 for an i8 element once i8 has been promoted), with the extra
+  // high bits undefined. The replacement must have exactly N's result type,
+  // so extend the element to it.
+  EVT ResultVT = N->getValueType(0);
+  if (ResultVT != EltVT)
+    Result = DCI.DAG.getNode(ISD::ANY_EXTEND, DL, ResultVT, Result);
+  return Result;
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5275,6 +5335,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case NVPTXISD::StoreRetvalV2:
     case NVPTXISD::StoreRetvalV4:
       return PerformStoreRetvalCombine(N);
+    case ISD::EXTRACT_VECTOR_ELT:
+      return PerformEXTRACTCombine(N, DCI);
   }
   return SDValue();
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 3e48c0f9d2c6ab0..ad10d7938ef12e4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1713,34 +1713,81 @@ def FUNSHFRCLAMP :
 // BFE - bit-field extract
 //
 
 // Template for BFE instructions.  Takes four args,
 //   [dest (reg), src (reg), start (reg or imm), end (reg or imm)].
 // Start may be an imm only if end is also an imm.  FIXME: Is this a
 // restriction in PTX?
 //
 // dest and src may be int32 or int64, but start and end are always int32.
-multiclass BFE<string TyStr, RegisterClass RC> {
+multiclass BFX<string Instr, RegisterClass RC> {
   def rrr
     : NVPTXInst<(outs RC:$d),
                 (ins RC:$a, Int32Regs:$b, Int32Regs:$c),
-                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+                !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
   def rri
     : NVPTXInst<(outs RC:$d),
                 (ins RC:$a, Int32Regs:$b, i32imm:$c),
-                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+                !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
   def rii
     : NVPTXInst<(outs RC:$d),
                 (ins RC:$a, i32imm:$b, i32imm:$c),
-                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+                !strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
 }
 
+// Template for BFI (bit-field insert) instructions. PTX bfi takes four
+// sources and only exists in .b32 and .b64 flavors:
+//   [dest (reg), a (reg), b (reg), start (reg or imm), len (reg or imm)].
+// Inserts the low len bits of a into b at bit position start; start and len
+// are always int32.
+multiclass BFI<string Instr, RegisterClass RC> {
+  def rrrr
+    : NVPTXInst<(outs RC:$f),
+                (ins RC:$a, RC:$b, Int32Regs:$c, Int32Regs:$d),
+                !strconcat(Instr, " \t$f, $a, $b, $c, $d;"), []>;
+  def rrri
+    : NVPTXInst<(outs RC:$f),
+                (ins RC:$a, RC:$b, Int32Regs:$c, i32imm:$d),
+                !strconcat(Instr, " \t$f, $a, $b, $c, $d;"), []>;
+  def rrii
+    : NVPTXInst<(outs RC:$f),
+                (ins RC:$a, RC:$b, i32imm:$c, i32imm:$d),
+                !strconcat(Instr, " \t$f, $a, $b, $c, $d;"), []>;
+}
+
 let hasSideEffects = false in {
-  defm BFE_S32 : BFE<"s32", Int32Regs>;
-  defm BFE_U32 : BFE<"u32", Int32Regs>;
-  defm BFE_S64 : BFE<"s64", Int64Regs>;
-  defm BFE_U64 : BFE<"u64", Int64Regs>;
+  defm BFE_S32 : BFX<"bfe.s32", Int32Regs>;
+  defm BFE_U32 : BFX<"bfe.u32", Int32Regs>;
+  defm BFE_S64 : BFX<"bfe.s64", Int64Regs>;
+  defm BFE_U64 : BFX<"bfe.u64", Int64Regs>;
+
+  defm BFI_B32 : BFI<"bfi.b32", Int32Regs>;
+  defm BFI_B64 : BFI<"bfi.b64", Int64Regs>;
 }
 
+// Bit offsets at which bfe can extract an 8-bit field in place of
+// (sext_inreg (srl x, offset), i8). The field must lie entirely within the
+// source register: past the register's top bit bfe replicates the source's
+// top bit, whereas srl shifts in zeros.
+def bfe_byte_pos32 : ImmLeaf<i32, [{ return Imm >= 0 && Imm <= 24; }]>;
+def bfe_byte_pos64 : ImmLeaf<i32, [{ return Imm >= 0 && Imm <= 56; }]>;
+
+// Common byte extraction patterns
+def : Pat<(i16 (sext_inreg (trunc Int32Regs:$s), i8)),
+          (CVT_s8_s32 Int32Regs:$s, CvtNONE)>;
+def : Pat<(i16 (sext_inreg (trunc (srl (i32 Int32Regs:$s), bfe_byte_pos32:$o)), i8)),
+          (CVT_s8_s32 (BFE_S32rii Int32Regs:$s, imm:$o, 8), CvtNONE)>;
+def : Pat<(sext_inreg (srl (i32 Int32Regs:$s), bfe_byte_pos32:$o), i8),
+          (BFE_S32rii Int32Regs:$s, imm:$o, 8)>;
+def : Pat<(i16 (sra (i16 (trunc Int32Regs:$s)), (i32 8))),
+          (CVT_s8_s32 (BFE_S32rii Int32Regs:$s, 8, 8), CvtNONE)>;
+
+def : Pat<(sext_inreg (srl (i64 Int64Regs:$s), bfe_byte_pos64:$o), i8),
+          (BFE_S64rii Int64Regs:$s, imm:$o, 8)>;
+def : Pat<(i16 (sext_inreg (trunc Int64Regs:$s), i8)),
+          (CVT_s8_s64 Int64Regs:$s, CvtNONE)>;
+def : Pat<(i16 (sext_inreg (trunc (srl (i64 Int64Regs:$s), bfe_byte_pos64:$o)), i8)),
+          (CVT_s8_s64 (BFE_S64rii Int64Regs:$s, imm:$o, 8), CvtNONE)>;
+
 //-----------------------------------
 // Comparison instructions (setp, set)
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/extractelement.ll b/llvm/test/CodeGen/NVPTX/extractelement.ll
new file mode 100644
index 000000000000000..bd1310410c7f5df
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/extractelement.ll
@@ -0,0 +1,93 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_35 -verify-machineinstrs | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_35 | %ptxas-verify %}
+
+
+; CHECK-LABEL: test_v2i8
+; CHECK-DAG:        ld.param.u16    [[A:%rs[0-9]+]], [test_v2i8_param_0];
+; CHECK-DAG:        cvt.s16.s8      [[E0:%rs[0-9]+]], [[A]];
+; CHECK-DAG:        shr.s16         [[E1:%rs[0-9]+]], [[A]], 8;
+define i16  @test_v2i8(i16 %a) #0 {
+  %v = bitcast i16 %a to <2 x i8>
+  %r0 = extractelement <2 x i8> %v, i64 0
+  %r1 = extractelement <2 x i8> %v, i64 1
+  %r0i = sext i8 %r0 to i16
+  %r1i = sext i8 %r1 to i16
+  %r01 = add i16 %r0i, %r1i
+  ret i16 %r01
+}
+
+; CHECK-LABEL: test_v4i8
+; CHECK:            ld.param.u32    [[R:%r[0-9]+]], [test_v4i8_param_0];
+; CHECK-DAG:        cvt.s8.s32      [[E0:%rs[0-9]+]], [[R]];
+; CHECK-DAG:        bfe.s32         [[R1:%r[0-9]+]], [[R]], 8, 8;
+; CHECK-DAG:        cvt.s8.s32      [[E1:%rs[0-9]+]], [[R1]];
+; CHECK-DAG:        bfe.s32         [[R2:%r[0-9]+]], [[R]], 16, 8;
+; CHECK-DAG:        cvt.s8.s32      [[E2:%rs[0-9]+]], [[R2]];
+; CHECK-DAG:        bfe.s32         [[R3:%r[0-9]+]], [[R]], 24, 8;
+; CHECK-DAG:        cvt.s8.s32      [[E3:%rs[0-9]+]], [[R3]];
+define i16  @test_v4i8(i32 %a) #0 {
+  %v = bitcast i32 %a to <4 x i8>
+  %r0 = extractelement <4 x i8> %v, i64 0
+  %r1 = extractelement <4 x i8> %v, i64 1
+  %r2 = extractelement <4 x i8> %v, i64 2
+  %r3 = extractelement <4 x i8> %v, i64 3
+  %r0i = sext i8 %r0 to i16
+  %r1i = sext i8 %r1 to i16
+  %r2i = sext i8 %r2 to i16
+  %r3i = sext i8 %r3 to i16
+  %r01 = add i16 %r0i, %r1i
+  %r23 = add i16 %r2i, %r3i
+  %r = add i16 %r01, %r23
+  ret i16 %r
+}
+
+; CHECK-LABEL: test_v8i8
+; CHECK:       ld.param.u64    [[R:%rd[0-9]+]], [test_v8i8_param_0];
+; CHECK-DAG:        cvt.s8.s64      [[E0:%rs[0-9]+]], [[R]];
+; TODO: element 1 is still extracted via trunc + shr 8 rather than bfe; not yet clear why.
+; CHECK-DAG:        cvt.u16.u64     [[R01:%rs[0-9]+]], [[R]];
+; CHECK-DAG:        shr.s16         [[E1:%rs[0-9]+]], [[R01]], 8;
+; CHECK-DAG:        bfe.s64         [[RD2:%rd[0-9]+]], [[R]], 16, 8;
+; CHECK-DAG:        cvt.s8.s64      [[E2:%rs[0-9]+]], [[RD2]];
+; CHECK-DAG:        bfe.s64         [[RD3:%rd[0-9]+]], [[R]], 24, 8;
+; CHECK-DAG:        cvt.s8.s64      [[E3:%rs[0-9]+]], [[RD3]];
+; CHECK-DAG:        bfe.s64         [[RD4:%rd[0-9]+]], [[R]], 32, 8;
+; CHECK-DAG:        cvt.s8.s64      [[E4:%rs[0-9]+]], [[RD4]];
+; CHECK-DAG:        bfe.s64         [[RD5:%rd[0-9]+]], [[R]], 40, 8;
+; CHECK-DAG:        cvt.s8.s64      [[E5:%rs[0-9]+]], [[RD5]];
+; CHECK-DAG:        bfe.s64         [[RD6:%rd[0-9]+]], [[R]], 48, 8;
+; CHECK-DAG:        cvt.s8.s64      [[E6:%rs[0-9]+]], [[RD6]];
+; CHECK-DAG:        bfe.s64         [[RD7:%rd[0-9]+]], [[R]], 56, 8;
+; CHECK-DAG:        cvt.s8.s64      [[E7:%rs[0-9]+]], [[RD7]];
+
+define i16  @test_v8i8(i64 %a) #0 {
+  %v = bitcast i64 %a to <8 x i8>
+  %r0 = extractelement <8 x i8> %v, i64 0
+  %r1 = extractelement <8 x i8> %v, i64 1
+  %r2 = extractelement <8 x i8> %v, i64 2
+  %r3 = extractelement <8 x i8> %v, i64 3
+  %r4 = extractelement <8 x i8> %v, i64 4
+  %r5 = extractelement <8 x i8> %v, i64 5
+  %r6 = extractelement <8 x i8> %v, i64 6
+  %r7 = extractelement <8 x i8> %v, i64 7
+  %r0i = sext i8 %r0 to i16
+  %r1i = sext i8 %r1 to i16
+  %r2i = sext i8 %r2 to i16
+  %r3i = sext i8 %r3 to i16
+  %r4i = sext i8 %r4 to i16
+  %r5i = sext i8 %r5 to i16
+  %r6i = sext i8 %r6 to i16
+  %r7i = sext i8 %r7 to i16
+  %r01 = add i16 %r0i, %r1i
+  %r23 = add i16 %r2i, %r3i
+  %r45 = add i16 %r4i, %r5i
+  %r67 = add i16 %r6i, %r7i
+  %r0123 = add i16 %r01, %r23
+  %r4567 = add i16 %r45, %r67
+  %r = add i16 %r0123, %r4567
+  ret i16 %r
+}
+
+
+
+!0 = !{}



More information about the llvm-commits mailing list