[llvm] [NVPTX] Improve lowering of v4i8 (PR #67866)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 14 11:13:30 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Artem Belevich (Artem-B)

<details>
<summary>Changes</summary>

Make v4i8 a legal type and plumb it through the lowering of the relevant instructions.

---

Patch is 122.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67866.diff


15 Files Affected:

- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+31) 
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+2) 
- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+12) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+14-7) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+228-47) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+165-35) 
- (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+54-1) 
- (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+1-1) 
- (added) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+1272) 
- (modified) llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+12-14) 
- (modified) llvm/test/CodeGen/NVPTX/unfold-masked-merge-vector-variablemask.ll (+96-429) 
- (modified) llvm/test/CodeGen/NVPTX/vec8.ll (+2-3) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 5d27accdc198c..b7a20c351f5ff 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -309,3 +309,34 @@ void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
   const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
   O << Sym.getName();
 }
+
+void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
+                                     raw_ostream &O, const char *Modifier) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  int64_t Imm = MO.getImm();
+
+  switch (Imm) {
+  default:
+    return;
+  case NVPTX::PTXPrmtMode::NONE:
+    break;
+  case NVPTX::PTXPrmtMode::F4E:
+    O << ".f4e";
+    break;
+  case NVPTX::PTXPrmtMode::B4E:
+    O << ".b4e";
+    break;
+  case NVPTX::PTXPrmtMode::RC8:
+    O << ".rc8";
+    break;
+  case NVPTX::PTXPrmtMode::ECL:
+    O << ".ecl";
+    break;
+  case NVPTX::PTXPrmtMode::ECR:
+    O << ".ecr";
+    break;
+  case NVPTX::PTXPrmtMode::RC16:
+    O << ".rc16";
+    break;
+  }
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 49ad3f269229d..e6954f861cd10 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                        raw_ostream &O, const char *Modifier = nullptr);
   void printProtoIdent(const MCInst *MI, int OpNum,
                        raw_ostream &O, const char *Modifier = nullptr);
+  void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
+                     const char *Modifier = nullptr);
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 8dc68911fff0c..07ee34968b023 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -181,6 +181,18 @@ enum CmpMode {
   FTZ_FLAG = 0x100
 };
 }
+
+namespace PTXPrmtMode {
+enum PrmtMode {
+  NONE,
+  F4E,
+  B4E,
+  RC8,
+  ECL,
+  ECR,
+  RC16,
+};
+}
 }
 void initializeNVPTXDAGToDAGISelPass(PassRegistry &);
 } // namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 0aef2591c6e23..68391cdb6ff17 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -14,6 +14,7 @@
 #include "MCTargetDesc/NVPTXBaseInfo.h"
 #include "NVPTXUtilities.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
@@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
   case MVT::v2f16:
   case MVT::v2bf16:
   case MVT::v2i16:
+  case MVT::v4i8:
     return Opcode_i32;
   case MVT::f32:
     return Opcode_f32;
@@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   // Vector Setting
   unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
   if (SimpleVT.isVector()) {
-    assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
+    assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
+           "Unexpected vector type");
     // v2f16/v2bf16/v2i16 is loaded using ld.b32
     fromTypeWidth = 32;
   }
@@ -1254,6 +1257,7 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   SDLoc DL(N);
   SDNode *LD;
   SDValue Base, Offset, Addr;
+  EVT OrigType = N->getValueType(0);
 
   EVT EltVT = Mem->getMemoryVT();
   unsigned NumElts = 1;
@@ -1261,12 +1265,15 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
     NumElts = EltVT.getVectorNumElements();
     EltVT = EltVT.getVectorElementType();
     // vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
-    if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
-        (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
-        (EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
+    if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
+        (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
+        (EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
       assert(NumElts % 2 == 0 && "Vector must have even number of elements");
-      EltVT = N->getValueType(0);
+      EltVT = OrigType;
       NumElts /= 2;
+    } else if (OrigType == MVT::v4i8) {
+      EltVT = OrigType;
+      NumElts = 1;
     }
   }
 
@@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   // concept of sign-/zero-extension, so emulate it here by adding an explicit
   // CVT instruction. Ptxas should clean up any redundancies here.
 
-  EVT OrigType = N->getValueType(0);
   LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
 
   if (OrigType != EltVT &&
@@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
   MVT ScalarVT = SimpleVT.getScalarType();
   unsigned toTypeWidth = ScalarVT.getSizeInBits();
   if (SimpleVT.isVector()) {
-    assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
+    assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
+           "Unexpected vector type");
     // v2x16 is stored using st.b32
     toTypeWidth = 32;
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b24aae4792ce6..36da2e7b40efa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -221,6 +221,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
           llvm_unreachable("Unexpected type");
         }
         NumElts /= 2;
+      } else if (EltVT.getSimpleVT() == MVT::i8 &&
+                 (NumElts % 4 == 0 || NumElts == 3)) {
+        // v*i8 are formally lowered as v4i8
+        EltVT = MVT::v4i8;
+        NumElts = (NumElts + 3) / 4;
       }
       for (unsigned j = 0; j != NumElts; ++j) {
         ValueVTs.push_back(EltVT);
@@ -458,6 +463,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
   addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
+  addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
@@ -491,10 +497,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
 
+  setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
+  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
+  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
+  // Only logical ops can be done on v4i8 directly, others must be done
+  // elementwise.
+  setOperationAction(
+      {ISD::ADD,       ISD::MUL,        ISD::ABS,        ISD::SMIN,
+       ISD::SMAX,      ISD::UMIN,       ISD::UMAX,       ISD::CTPOP,
+       ISD::CTLZ,      ISD::ADD,        ISD::SUB,        ISD::MUL,
+       ISD::SHL,       ISD::SREM,       ISD::UREM,       ISD::SDIV,
+       ISD::UDIV,      ISD::SRA,        ISD::SRL,        ISD::MULHS,
+       ISD::MULHU,     ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP,
+       ISD::UINT_TO_FP},
+      MVT::v4i8, Expand);
+
   // Operations not directly supported by NVPTX.
-  for (MVT VT :
-       {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, MVT::f64,
-        MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}) {
+  for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
+                 MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
+                 MVT::i32, MVT::i64}) {
     setOperationAction(ISD::SELECT_CC, VT, Expand);
     setOperationAction(ISD::BR_CC, VT, Expand);
   }
@@ -672,7 +694,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
-                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT});
+                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
+                       ISD::VSELECT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -881,6 +904,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     return "NVPTXISD::FUN_SHFR_CLAMP";
   case NVPTXISD::IMAD:
     return "NVPTXISD::IMAD";
+  case NVPTXISD::BFE:
+    return "NVPTXISD::BFE";
+  case NVPTXISD::BFI:
+    return "NVPTXISD::BFI";
+  case NVPTXISD::PRMT:
+    return "NVPTXISD::PRMT";
   case NVPTXISD::SETP_F16X2:
     return "NVPTXISD::SETP_F16X2";
   case NVPTXISD::Dummy:
@@ -2150,58 +2179,98 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
-// We can init constant f16x2 with a single .b32 move.  Normally it
+// We can init constant f16x2/v2i16/v4i8 with a single .b32 move.  Normally it
 // would get lowered as two constant loads and vector-packing move.
-//        mov.b16         %h1, 0x4000;
-//        mov.b16         %h2, 0x3C00;
-//        mov.b32         %hh2, {%h2, %h1};
 // Instead we want just a constant move:
-//        mov.b32         %hh2, 0x40003C00
-//
-// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
-// generates good SASS in both cases.
+//        mov.b32         %r2, 0x40003C00
 SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                SelectionDAG &DAG) const {
   EVT VT = Op->getValueType(0);
-  if (!(Isv2x16VT(VT)))
+  if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
     return Op;
-  APInt E0;
-  APInt E1;
-  if (VT == MVT::v2f16 || VT == MVT::v2bf16) {
-    if (!(isa<ConstantFPSDNode>(Op->getOperand(0)) &&
-          isa<ConstantFPSDNode>(Op->getOperand(1))))
-      return Op;
-
-    E0 = cast<ConstantFPSDNode>(Op->getOperand(0))
-             ->getValueAPF()
-             .bitcastToAPInt();
-    E1 = cast<ConstantFPSDNode>(Op->getOperand(1))
-             ->getValueAPF()
-             .bitcastToAPInt();
-  } else {
-    assert(VT == MVT::v2i16);
-    if (!(isa<ConstantSDNode>(Op->getOperand(0)) &&
-          isa<ConstantSDNode>(Op->getOperand(1))))
-      return Op;
 
-    E0 = cast<ConstantSDNode>(Op->getOperand(0))->getAPIntValue();
-    E1 = cast<ConstantSDNode>(Op->getOperand(1))->getAPIntValue();
+  SDLoc DL(Op);
+
+  if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
+        return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
+               isa<ConstantFPSDNode>(Operand);
+      })) {
+    // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
+    // to optimize calculation of constant parts.
+    if (VT == MVT::v4i8) {
+      SDValue C8 = DAG.getConstant(8, DL, MVT::i32);
+      SDValue E01 = DAG.getNode(
+          NVPTXISD::BFI, DL, MVT::i32,
+          DAG.getAnyExtOrTrunc(Op->getOperand(1), DL, MVT::i32),
+          DAG.getAnyExtOrTrunc(Op->getOperand(0), DL, MVT::i32), C8, C8);
+      SDValue E012 =
+          DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+                      DAG.getAnyExtOrTrunc(Op->getOperand(2), DL, MVT::i32),
+                      E01, DAG.getConstant(16, DL, MVT::i32), C8);
+      SDValue E0123 =
+          DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+                      DAG.getAnyExtOrTrunc(Op->getOperand(3), DL, MVT::i32),
+                      E012, DAG.getConstant(24, DL, MVT::i32), C8);
+      return DAG.getNode(ISD::BITCAST, DL, VT, E0123);
+    }
+    return Op;
   }
-  SDValue Const =
-      DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
+
+  // Get value of the Nth operand as an APInt(32). Undef values treated as 0.
+  auto GetOperand = [](SDValue Op, int N) -> APInt {
+    const SDValue &Operand = Op->getOperand(N);
+    EVT VT = Op->getValueType(0);
+    if (Operand->isUndef())
+      return APInt(32, 0);
+    APInt Value;
+    if (VT == MVT::v2f16 || VT == MVT::v2bf16)
+      Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
+    else if (VT == MVT::v2i16 || VT == MVT::v4i8)
+      Value = cast<ConstantSDNode>(Operand)->getAPIntValue();
+    else
+      llvm_unreachable("Unsupported type");
+    // i8 values are carried around as i16, so we need to zero out upper bits,
+    // so they do not get in the way of combining individual byte values
+    if (VT == MVT::v4i8)
+      Value = Value.trunc(8);
+    return Value.zext(32);
+  };
+  APInt Value;
+  if (Isv2x16VT(VT)) {
+    Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
+  } else if (VT == MVT::v4i8) {
+    Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
+            GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
+  } else {
+    llvm_unreachable("Unsupported type");
+  }
+  SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
   return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
 }
 
 SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                                                      SelectionDAG &DAG) const {
   SDValue Index = Op->getOperand(1);
+  SDValue Vector = Op->getOperand(0);
+  SDLoc DL(Op);
+  EVT VectorVT = Vector.getValueType();
+
+  if (VectorVT == MVT::v4i8) {
+    SDValue BFE =
+        DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
+                    {Vector,
+                     DAG.getNode(ISD::MUL, DL, MVT::i32,
+                                 DAG.getZExtOrTrunc(Index, DL, MVT::i32),
+                                 DAG.getConstant(8, DL, MVT::i32)),
+                     DAG.getConstant(8, DL, MVT::i32)});
+    return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
+  }
+
   // Constant index will be matched by tablegen.
   if (isa<ConstantSDNode>(Index.getNode()))
     return Op;
 
   // Extract individual elements and select one of them.
-  SDValue Vector = Op->getOperand(0);
-  EVT VectorVT = Vector.getValueType();
   assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
   EVT EltVT = VectorVT.getVectorElementType();
 
@@ -2214,6 +2283,49 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                          ISD::CondCode::SETEQ);
 }
 
+SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDValue Vector = Op->getOperand(0);
+  EVT VectorVT = Vector.getValueType();
+
+  if (VectorVT != MVT::v4i8)
+    return Op;
+  SDLoc DL(Op);
+  SDValue Value = Op->getOperand(1);
+  if (Value->isUndef())
+    return Vector;
+
+  SDValue Index = Op->getOperand(2);
+
+  SDValue BFI =
+      DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+                  {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
+                   DAG.getNode(ISD::MUL, DL, MVT::i32,
+                               DAG.getZExtOrTrunc(Index, DL, MVT::i32),
+                               DAG.getConstant(8, DL, MVT::i32)),
+                   DAG.getConstant(8, DL, MVT::i32)});
+  return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
+}
+
+SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDValue V1 = Op.getOperand(0);
+  EVT VectorVT = V1.getValueType();
+  if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
+    return Op;
+
+  // Lower shuffle to PRMT instruction.
+  const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
+  SDValue V2 = Op.getOperand(1);
+  uint32_t Selector = 0;
+  for (auto I : llvm::enumerate(SVN->getMask()))
+    Selector |= (I.value() << (I.index() * 4));
+
+  SDLoc DL(Op);
+  return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
+                     DAG.getConstant(Selector, DL, MVT::i32),
+                     DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
+}
 /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
 /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
 ///    amount, or
@@ -2464,6 +2576,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return Op;
   case ISD::EXTRACT_VECTOR_ELT:
     return LowerEXTRACT_VECTOR_ELT(Op, DAG);
+  case ISD::INSERT_VECTOR_ELT:
+    return LowerINSERT_VECTOR_ELT(Op, DAG);
+  case ISD::VECTOR_SHUFFLE:
+    return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
   case ISD::STORE:
@@ -2578,9 +2694,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);
 
-  // v2f16/v2bf16/v2i16 are legal, so we can't rely on legalizer to handle
+  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
   // unaligned loads and have to handle it here.
-  if (Isv2x16VT(Op.getValueType())) {
+  EVT VT = Op.getValueType();
+  if (Isv2x16VT(VT) || VT == MVT::v4i8) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2625,13 +2742,13 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // stores and have to handle it here.
-  if (Isv2x16VT(VT) &&
+  if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
 
   // v2f16, v2bf16 and v2i16 don't need special handling.
-  if (Isv2x16VT(VT))
+  if (Isv2x16VT(VT) || VT == MVT::v4i8)
     return SDValue();
 
   if (VT.isVector())
@@ -2903,7 +3020,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           EVT LoadVT = EltVT;
           if (EltVT == MVT::i1)
             LoadVT = MVT::i8;
-          else if (Isv2x16VT(EltVT))
+          else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
             // getLoad needs a vector type, but it can't handle
             // vectors which contain v2f16 or v2bf16 elements. So we must load
             // using i32 here and then bitcast back.
@@ -2929,7 +3046,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
             if (EltVT == MVT::i1)
               Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
             // v2f16 was loaded as an i32. Now we must bitcast it back.
-            else if (Isv2x16VT(EltVT))
+            else if (EltVT != LoadVT)
               Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
 
             // If a promoted integer type is used, truncate down to the original
@@ -4975,6 +5092,32 @@ static SDValue PerformANDCombine(SDNode *N,
   }
 
   SDValue AExt;
+
+  // Convert BFE-> truncate i16 -> and 255
+  // To just BFE-> truncate i16, as the value already has all the bits in the
+  // right places.
+  if (Val.getOpcode() == ISD::TRUNCATE) {
+    SDValue BFE = Val.getOperand(0);
+    if (BFE.getOpcode() != NVPTXISD::BFE)
+      return SDValue();
+
+    ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0));
+    if (!BFEBits)
+      return SDValue();
+    uint64_t BFEBitsVal = BFEBits->getZExtValue();
+
+    ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+    if (!MaskCnst) {
+      // Not an AND with a constant
+      return SDValue();
+    }
+    uint64_t MaskVal = MaskCnst->getZExtValue();
+
+    if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
+      return SDValue();
+    // If we get here, the AND is unnecessary.  Just replace it with the trunc
+    DCI.CombineTo(N, Val, false);
+  }
   // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
   if (Val.getOpcode() == ISD::ANY_EXTEND) {
     AExt = Val;
@@ -5254,13 +5397,15 @@ static SDValue PerformSETCCCombine(SDNode *N,
 static SDValue PerformEXTRACTCombine(SDNode *N,
                                      Tar...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/67866


More information about the llvm-commits mailing list