[llvm] [NVPTX] Improve lowering of v4i8 (PR #67866)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 14 11:13:30 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Artem Belevich (Artem-B)
Make v4i8 a legal type and plumb the lowering of the relevant instructions through the backend.
---
Patch is 122.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67866.diff
15 Files Affected:
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+31)
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+2)
- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+12)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+14-7)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+228-47)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+5)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+165-35)
- (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+1-1)
- (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+54-1)
- (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+1-1)
- (added) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+1272)
- (modified) llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll (+2-2)
- (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+12-14)
- (modified) llvm/test/CodeGen/NVPTX/unfold-masked-merge-vector-variablemask.ll (+96-429)
- (modified) llvm/test/CodeGen/NVPTX/vec8.ll (+2-3)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 5d27accdc198c..b7a20c351f5ff 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -309,3 +309,34 @@ void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
O << Sym.getName();
}
+
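+// Print the mode suffix of a PTX `prmt.b32` instruction
+// (.f4e/.b4e/.rc8/.ecl/.ecr/.rc16); the generic mode (NONE) prints no suffix.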
+void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
+ raw_ostream &O, const char *Modifier) {
+ const MCOperand &MO = MI->getOperand(OpNum);
+ int64_t Imm = MO.getImm();
+
+ switch (Imm) {
+ default:
+ return;
+ case NVPTX::PTXPrmtMode::NONE:
+ break;
+ case NVPTX::PTXPrmtMode::F4E:
+ O << ".f4e";
+ break;
+ case NVPTX::PTXPrmtMode::B4E:
+ O << ".b4e";
+ break;
+ case NVPTX::PTXPrmtMode::RC8:
+ O << ".rc8";
+ break;
+ case NVPTX::PTXPrmtMode::ECL:
+ O << ".ecl";
+ break;
+ case NVPTX::PTXPrmtMode::ECR:
+ O << ".ecr";
+ break;
+ case NVPTX::PTXPrmtMode::RC16:
+ O << ".rc16";
+ break;
+ }
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 49ad3f269229d..e6954f861cd10 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
raw_ostream &O, const char *Modifier = nullptr);
void printProtoIdent(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier = nullptr);
+ void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 8dc68911fff0c..07ee34968b023 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -181,6 +181,18 @@ enum CmpMode {
FTZ_FLAG = 0x100
};
}
+
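+// Permute modes of the PTX prmt instruction; NONE is the generic form that
+// uses an explicit byte selector.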
+namespace PTXPrmtMode {
+enum PrmtMode {
+ NONE,
+ F4E,
+ B4E,
+ RC8,
+ ECL,
+ ECR,
+ RC16,
+};
+}
}
void initializeNVPTXDAGToDAGISelPass(PassRegistry &);
} // namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 0aef2591c6e23..68391cdb6ff17 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -14,6 +14,7 @@
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
@@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
+ case MVT::v4i8:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
@@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
- assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
+ assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
+ "Unexpected vector type");
// v2f16/v2bf16/v2i16 is loaded using ld.b32
fromTypeWidth = 32;
}
@@ -1254,6 +1257,7 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
SDLoc DL(N);
SDNode *LD;
SDValue Base, Offset, Addr;
+ EVT OrigType = N->getValueType(0);
EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;
@@ -1261,12 +1265,15 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
- if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
- (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
- (EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
+ if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
+ (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
+ (EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
- EltVT = N->getValueType(0);
+ EltVT = OrigType;
NumElts /= 2;
+ } else if (OrigType == MVT::v4i8) {
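+ // v4i8 is loaded/stored as a single 32-bit element.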
+ EltVT = OrigType;
+ NumElts = 1;
}
}
@@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// concept of sign-/zero-extension, so emulate it here by adding an explicit
// CVT instruction. Ptxas should clean up any redundancies here.
- EVT OrigType = N->getValueType(0);
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
if (OrigType != EltVT &&
@@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
- assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
+ assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
+ "Unexpected vector type");
// v2x16 is stored using st.b32
toTypeWidth = 32;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b24aae4792ce6..36da2e7b40efa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -221,6 +221,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
llvm_unreachable("Unexpected type");
}
NumElts /= 2;
+ } else if (EltVT.getSimpleVT() == MVT::i8 &&
+ (NumElts % 4 == 0 || NumElts == 3)) {
+ // v*i8 are formally lowered as v4i8
+ EltVT = MVT::v4i8;
+ NumElts = (NumElts + 3) / 4;
}
for (unsigned j = 0; j != NumElts; ++j) {
ValueVTs.push_back(EltVT);
@@ -458,6 +463,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
+ addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
@@ -491,10 +497,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
+ setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
+ setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
+ // Only logical ops can be done on v4i8 directly; everything else must be
+ // expanded elementwise.
+ setOperationAction(
+ {ISD::ABS, ISD::ADD, ISD::SUB, ISD::MUL, ISD::SHL,
+ ISD::SRA, ISD::SRL, ISD::SREM, ISD::UREM, ISD::SDIV,
+ ISD::UDIV, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
+ ISD::CTPOP, ISD::CTLZ, ISD::MULHS, ISD::MULHU, ISD::FP_TO_SINT,
+ ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::UINT_TO_FP},
+ MVT::v4i8, Expand);
+
// Operations not directly supported by NVPTX.
- for (MVT VT :
- {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, MVT::f64,
- MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}) {
+ for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
+ MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
+ MVT::i32, MVT::i64}) {
setOperationAction(ISD::SELECT_CC, VT, Expand);
setOperationAction(ISD::BR_CC, VT, Expand);
}
@@ -672,7 +694,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
- ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT});
+ ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
+ ISD::VSELECT});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -881,6 +904,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "NVPTXISD::FUN_SHFR_CLAMP";
case NVPTXISD::IMAD:
return "NVPTXISD::IMAD";
+ case NVPTXISD::BFE:
+ return "NVPTXISD::BFE";
+ case NVPTXISD::BFI:
+ return "NVPTXISD::BFI";
+ case NVPTXISD::PRMT:
+ return "NVPTXISD::PRMT";
case NVPTXISD::SETP_F16X2:
return "NVPTXISD::SETP_F16X2";
case NVPTXISD::Dummy:
@@ -2150,58 +2179,98 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}
-// We can init constant f16x2 with a single .b32 move. Normally it
+// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
// would get lowered as two constant loads and vector-packing move.
-// mov.b16 %h1, 0x4000;
-// mov.b16 %h2, 0x3C00;
-// mov.b32 %hh2, {%h2, %h1};
// Instead we want just a constant move:
-// mov.b32 %hh2, 0x40003C00
-//
-// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
-// generates good SASS in both cases.
+// mov.b32 %r2, 0x40003C00
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op->getValueType(0);
- if (!(Isv2x16VT(VT)))
+ if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
return Op;
- APInt E0;
- APInt E1;
- if (VT == MVT::v2f16 || VT == MVT::v2bf16) {
- if (!(isa<ConstantFPSDNode>(Op->getOperand(0)) &&
- isa<ConstantFPSDNode>(Op->getOperand(1))))
- return Op;
-
- E0 = cast<ConstantFPSDNode>(Op->getOperand(0))
- ->getValueAPF()
- .bitcastToAPInt();
- E1 = cast<ConstantFPSDNode>(Op->getOperand(1))
- ->getValueAPF()
- .bitcastToAPInt();
- } else {
- assert(VT == MVT::v2i16);
- if (!(isa<ConstantSDNode>(Op->getOperand(0)) &&
- isa<ConstantSDNode>(Op->getOperand(1))))
- return Op;
- E0 = cast<ConstantSDNode>(Op->getOperand(0))->getAPIntValue();
- E1 = cast<ConstantSDNode>(Op->getOperand(1))->getAPIntValue();
+ SDLoc DL(Op);
+
+ if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
+ return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
+ isa<ConstantFPSDNode>(Operand);
+ })) {
+ // Lower a non-constant v4i8 vector as a byte-wise constructed i32, which
+ // allows us to optimize the calculation of its constant parts.
+ if (VT == MVT::v4i8) {
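+ // NVPTXISD::BFI lowers to PTX `bfi.b32 d, a, b, pos, len`, which inserts
+ // the low `len` bits of `a` into `b` starting at bit `pos`; three 8-bit
+ // inserts pack the four bytes into one i32.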
+ SDValue C8 = DAG.getConstant(8, DL, MVT::i32);
+ SDValue E01 = DAG.getNode(
+ NVPTXISD::BFI, DL, MVT::i32,
+ DAG.getAnyExtOrTrunc(Op->getOperand(1), DL, MVT::i32),
+ DAG.getAnyExtOrTrunc(Op->getOperand(0), DL, MVT::i32), C8, C8);
+ SDValue E012 =
+ DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+ DAG.getAnyExtOrTrunc(Op->getOperand(2), DL, MVT::i32),
+ E01, DAG.getConstant(16, DL, MVT::i32), C8);
+ SDValue E0123 =
+ DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+ DAG.getAnyExtOrTrunc(Op->getOperand(3), DL, MVT::i32),
+ E012, DAG.getConstant(24, DL, MVT::i32), C8);
+ return DAG.getNode(ISD::BITCAST, DL, VT, E0123);
+ }
+ return Op;
}
- SDValue Const =
- DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
+
+ // Get the value of the Nth operand as an APInt(32). Undef values are treated as 0.
+ auto GetOperand = [](SDValue Op, int N) -> APInt {
+ const SDValue &Operand = Op->getOperand(N);
+ EVT VT = Op->getValueType(0);
+ if (Operand->isUndef())
+ return APInt(32, 0);
+ APInt Value;
+ if (VT == MVT::v2f16 || VT == MVT::v2bf16)
+ Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
+ else if (VT == MVT::v2i16 || VT == MVT::v4i8)
+ Value = cast<ConstantSDNode>(Operand)->getAPIntValue();
+ else
+ llvm_unreachable("Unsupported type");
+ // i8 values are carried around as i16, so we need to zero out the upper
+ // bits so that they do not get in the way of combining individual byte values.
+ if (VT == MVT::v4i8)
+ Value = Value.trunc(8);
+ return Value.zext(32);
+ };
+ APInt Value;
+ if (Isv2x16VT(VT)) {
+ Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
+ } else if (VT == MVT::v4i8) {
+ Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
+ GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
+ } else {
+ llvm_unreachable("Unsupported type");
+ }
+ SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
}
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
SelectionDAG &DAG) const {
SDValue Index = Op->getOperand(1);
+ SDValue Vector = Op->getOperand(0);
+ SDLoc DL(Op);
+ EVT VectorVT = Vector.getValueType();
+
+ if (VectorVT == MVT::v4i8) {
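+ // NVPTXISD::BFE lowers to PTX `bfe.u32/bfe.s32 d, a, pos, len`; extract
+ // the 8-bit field at bit offset Index * 8.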
+ SDValue BFE =
+ DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
+ {Vector,
+ DAG.getNode(ISD::MUL, DL, MVT::i32,
+ DAG.getZExtOrTrunc(Index, DL, MVT::i32),
+ DAG.getConstant(8, DL, MVT::i32)),
+ DAG.getConstant(8, DL, MVT::i32)});
+ return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
+ }
+
// Constant index will be matched by tablegen.
if (isa<ConstantSDNode>(Index.getNode()))
return Op;
// Extract individual elements and select one of them.
- SDValue Vector = Op->getOperand(0);
- EVT VectorVT = Vector.getValueType();
assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
EVT EltVT = VectorVT.getVectorElementType();
@@ -2214,6 +2283,49 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
ISD::CondCode::SETEQ);
}
+SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDValue Vector = Op->getOperand(0);
+ EVT VectorVT = Vector.getValueType();
+
+ if (VectorVT != MVT::v4i8)
+ return Op;
+ SDLoc DL(Op);
+ SDValue Value = Op->getOperand(1);
+ if (Value->isUndef())
+ return Vector;
+
+ SDValue Index = Op->getOperand(2);
+
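+ // Overwrite the selected byte in place: bfi.b32 inserts the low 8 bits of
+ // Value into Vector at bit offset Index * 8.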
+ SDValue BFI =
+ DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+ {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
+ DAG.getNode(ISD::MUL, DL, MVT::i32,
+ DAG.getZExtOrTrunc(Index, DL, MVT::i32),
+ DAG.getConstant(8, DL, MVT::i32)),
+ DAG.getConstant(8, DL, MVT::i32)});
+ return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
+}
+
+SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDValue V1 = Op.getOperand(0);
+ EVT VectorVT = V1.getValueType();
+ if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
+ return Op;
+
+ // Lower shuffle to PRMT instruction.
+ const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
+ SDValue V2 = Op.getOperand(1);
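+ // prmt.b32 treats {V1, V2} as an 8-byte pool and each 4-bit nibble of the
+ // selector picks the source byte for the corresponding result byte, so the
+ // shuffle mask indices (0-7) map directly onto selector nibbles.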
+ uint32_t Selector = 0;
+ for (auto I : llvm::enumerate(SVN->getMask()))
+ Selector |= (I.value() << (I.index() * 4));
+
+ SDLoc DL(Op);
+ return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
+ DAG.getConstant(Selector, DL, MVT::i32),
+ DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
+}
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
@@ -2464,6 +2576,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return Op;
case ISD::EXTRACT_VECTOR_ELT:
return LowerEXTRACT_VECTOR_ELT(Op, DAG);
+ case ISD::INSERT_VECTOR_ELT:
+ return LowerINSERT_VECTOR_ELT(Op, DAG);
+ case ISD::VECTOR_SHUFFLE:
+ return LowerVECTOR_SHUFFLE(Op, DAG);
case ISD::CONCAT_VECTORS:
return LowerCONCAT_VECTORS(Op, DAG);
case ISD::STORE:
@@ -2578,9 +2694,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
if (Op.getValueType() == MVT::i1)
return LowerLOADi1(Op, DAG);
- // v2f16/v2bf16/v2i16 are legal, so we can't rely on legalizer to handle
+ // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
// unaligned loads and have to handle it here.
- if (Isv2x16VT(Op.getValueType())) {
+ EVT VT = Op.getValueType();
+ if (Isv2x16VT(VT) || VT == MVT::v4i8) {
LoadSDNode *Load = cast<LoadSDNode>(Op);
EVT MemVT = Load->getMemoryVT();
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2625,13 +2742,13 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
// stores and have to handle it here.
- if (Isv2x16VT(VT) &&
+ if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
VT, *Store->getMemOperand()))
return expandUnalignedStore(Store, DAG);
// v2f16, v2bf16 and v2i16 don't need special handling.
- if (Isv2x16VT(VT))
+ if (Isv2x16VT(VT) || VT == MVT::v4i8)
return SDValue();
if (VT.isVector())
@@ -2903,7 +3020,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
LoadVT = MVT::i8;
- else if (Isv2x16VT(EltVT))
+ else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
@@ -2929,7 +3046,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (EltVT == MVT::i1)
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
// v2f16 was loaded as an i32. Now we must bitcast it back.
- else if (Isv2x16VT(EltVT))
+ else if (EltVT != LoadVT)
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
// If a promoted integer type is used, truncate down to the original
@@ -4975,6 +5092,32 @@ static SDValue PerformANDCombine(SDNode *N,
}
SDValue AExt;
+
+ // Convert "BFE -> truncate i16 -> and 255" into just "BFE -> truncate i16",
+ // as the extracted value already has all of its bits in the right places.
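+ // For example: (and (i16 (trunc (BFE x, pos, 8))), 255) -> (i16 (trunc (BFE x, pos, 8))).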
+ if (Val.getOpcode() == ISD::TRUNCATE) {
+ SDValue BFE = Val.getOperand(0);
+ if (BFE.getOpcode() != NVPTXISD::BFE)
+ return SDValue();
+
+ ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(2));
+ if (!BFEBits)
+ return SDValue();
+ uint64_t BFEBitsVal = BFEBits->getZExtValue();
+
+ ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+ if (!MaskCnst) {
+ // Not an AND with a constant
+ return SDValue();
+ }
+ uint64_t MaskVal = MaskCnst->getZExtValue();
+
+ if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
+ return SDValue();
+ // If we get here, the AND is unnecessary. Just replace it with the trunc.
+ DCI.CombineTo(N, Val, false);
+ }
// Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
if (Val.getOpcode() == ISD::ANY_EXTEND) {
AExt = Val;
@@ -5254,13 +5397,15 @@ static SDValue PerformSETCCCombine(SDNode *N,
static SDValue PerformEXTRACTCombine(SDNode *N,
Tar...
[truncated]
``````````
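
For readers unfamiliar with the PTX byte-manipulation instructions this patch leans on, here is a minimal host-side C++ model of their semantics, sketched from the PTX ISA description (the helper names `bfi`, `bfe`, and `prmt` are illustrative, and `prmt`'s sign-replication selector bit is deliberately left out):

```c++
#include <cstdint>
#include <cstdio>

// bfi.b32 d, a, b, pos, len: insert the low `len` bits of `a` into `b`
// starting at bit `pos`.
uint32_t bfi(uint32_t a, uint32_t b, uint32_t pos, uint32_t len) {
  uint32_t mask = (len >= 32 ? ~0u : ((1u << len) - 1u)) << pos;
  return (b & ~mask) | ((a << pos) & mask);
}

// bfe.u32 d, a, pos, len: extract `len` bits of `a` at bit `pos`,
// zero-extended (bfe.s32 would sign-extend instead).
uint32_t bfe(uint32_t a, uint32_t pos, uint32_t len) {
  return (a >> pos) & (len >= 32 ? ~0u : ((1u << len) - 1u));
}

// prmt.b32 d, a, b, sel (generic mode): {a, b} form an 8-byte pool; each
// 4-bit nibble of `sel` picks the source byte for one result byte. The
// 0x8 bit of a nibble (sign replication) is not modeled here.
uint32_t prmt(uint32_t a, uint32_t b, uint32_t sel) {
  uint64_t pool = ((uint64_t)b << 32) | a;
  uint32_t d = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t idx = (sel >> (4 * i)) & 0x7;
    d |= (uint32_t)((pool >> (8 * idx)) & 0xFF) << (8 * i);
  }
  return d;
}

int main() {
  // Pack four bytes the way LowerBUILD_VECTOR chains BFI nodes.
  uint32_t v = bfi(0x22, 0x11, 8, 8); // bytes {0x11, 0x22}
  v = bfi(0x33, v, 16, 8);            // bytes {0x11, 0x22, 0x33}
  v = bfi(0x44, v, 24, 8);            // bytes {0x11, 0x22, 0x33, 0x44}
  printf("%08x\n", v);                // 44332211
  // Extract element 2, as LowerEXTRACT_VECTOR_ELT does (pos = index * 8).
  printf("%02x\n", bfe(v, 16, 8));    // 33
  // Reverse the bytes with a single permute, as LowerVECTOR_SHUFFLE would
  // for the shuffle mask <3, 2, 1, 0> (selector nibbles 0x0123).
  printf("%08x\n", prmt(v, 0, 0x0123)); // 11223344
  return 0;
}
```

With this model the lowering above is easy to trace: a non-constant v4i8 BUILD_VECTOR becomes a chain of three BFIs, EXTRACT_VECTOR_ELT a single BFE at `pos = index * 8`, INSERT_VECTOR_ELT a single BFI, and a v4i8 shuffle one PRMT whose selector encodes the mask.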
https://github.com/llvm/llvm-project/pull/67866