[llvm] [RISCV][GISEL] Legalize, regbank select, and instruction select G_ZEXT, G_SEXT, G_ANYEXT, G_SPLAT_VECTOR, and G_ICMP (PR #85938)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 1 17:33:33 PDT 2024
- Previous message: [llvm] [RISCV][GISEL] Legalize, regbank select, and instruction select G_ZEXT, G_SEXT, G_ANYEXT, G_SPLAT_VECTOR, and G_ICMP (PR #85938)
- Next message: [llvm] [RISCV][GISEL] Legalize, regbank select, and instruction select G_ZEXT, G_SEXT, G_ANYEXT, G_SPLAT_VECTOR, and G_ICMP (PR #85938)
- Messages sorted by:
[ date ]
[ thread ]
[ subject ]
[ author ]
================
@@ -570,7 +591,146 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3));
MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val));
}
+ MI.eraseFromParent();
+ return true;
+}
+
+// Custom-lower G_ZEXT / G_SEXT / G_ANYEXT from mask vectors by replacing the
+// extension with a select between two splats: 1 for zero/any-extension or -1 for sign-extension, and 0 otherwise:
+// (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
+// Note that any-extension is lowered identically to zero-extension. MI is erased and true is always returned.
+bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
+ MachineIRBuilder &MIB) const {
+
+ unsigned Opc = MI.getOpcode();
+ assert(Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
+ Opc == TargetOpcode::G_ANYEXT);
+
+ MachineRegisterInfo &MRI = *MIB.getMRI();
+ Register Dst = MI.getOperand(0).getReg();
+ Register Src = MI.getOperand(1).getReg(); // NOTE(review): assumed to be a vXi1 mask; not checked in this function.
+
+ LLT DstTy = MRI.getType(Dst);
+ int64_t ExtTrueVal = // Value each "true" lane takes: -1 only for G_SEXT.
+ Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_ANYEXT ? 1 : -1;
+ LLT DstEltTy = DstTy.getElementType();
+ auto SplatZero = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, 0)); // Lanes where the mask is false.
+ auto SplatTrue = // Lanes where the mask is true.
+ MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, ExtTrueVal));
+ MIB.buildSelect(Dst, Src, SplatTrue, SplatZero); // Defines Dst directly, using Src as the condition.
+
+ MI.eraseFromParent(); // The original extension is fully replaced by the select.
+ return true;
+}
+
+/// Return the mask type suitable for masking the provided vector type.
+/// This is simply a vector of 1-bit scalar elements with the same
+/// (possibly scalable) element count as VecTy. VecTy must be a vector.
+static LLT getMaskTypeFor(LLT VecTy) {
+ assert(VecTy.isVector());
+ ElementCount EC = VecTy.getElementCount(); // Reused as-is for the mask type.
+ return LLT::vector(EC, LLT::scalar(1));
+}
+
+/// Creates an all ones mask suitable for masking a vector of type VecTy with
+/// vector length VL, by emitting a G_VMSET_VL whose result has the mask type for VecTy and whose only source is VL.
+static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
+ MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI) { // NOTE(review): MRI is not referenced in this body.
+ LLT MaskTy = getMaskTypeFor(VecTy);
+ return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL});
+}
+
+/// Gets the two common "VL" operands: an all-ones mask and the vector length, returned as a {Mask, VL} pair.
+/// The vector type is taken from Dst and must be a scalable vector type.
+static std::pair<MachineInstrBuilder, Register>
+buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI) {
+ LLT VecTy = Dst.getLLTTy(MRI);
+ assert(VecTy.isScalableVector() && "Expecting scalable container type");
+ Register VL(RISCV::X0); // The physical register X0 is used as the VL operand; presumably this means VLMAX -- TODO confirm.
+ MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
+ return {Mask, VL};
+}
+
+static MachineInstrBuilder // Emits a G_SPLAT_VECTOR_SPLIT_I64_VL into Dst from the operands (Passthru, Lo, Hi, VL).
+buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo,
+ Register Hi, Register VL, MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI) { // NOTE(review): MRI is not referenced in this body.
+ // TODO: If the Hi bits of the splat are undefined, then it's fine to just
+ // splat Lo even if it might be sign extended. I don't think we have
+ // introduced a case where we're building an s64 whose upper bits are undef
+ // yet.
+
+ // Fall back to a stack store and stride x0 vector load. Only the generic instruction is emitted here.
+ // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
+ // preprocessDAG in SDAG.
+ return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
+ {Passthru, Lo, Hi, VL});
+}
+
+static MachineInstrBuilder // Splats a 64-bit Scalar into Dst by splitting it into two 32-bit parts.
+buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru,
+ const SrcOp &Scalar, Register VL,
+ MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
+ assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!"); // NOTE(review): the message says "VecTy" but the check is on the scalar's type.
+ auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar); // Two s32 results from the s64 scalar.
+ return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0), // Result 0 is passed as Lo,
+ Unmerge.getReg(1), VL, MIB, MRI); // result 1 as Hi.
+}
+
+// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
+// legal equivalently-sized i8 type, so we can use that as a go-between.
+// Splats of s1 types that have constant value can be legalized as VMSET_VL or
+// VMCLR_VL.
+bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
+ MachineIRBuilder &MIB) const {
+ assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);
+
+ MachineRegisterInfo &MRI = *MIB.getMRI();
+
+ Register Dst = MI.getOperand(0).getReg();
+ Register SplatVal = MI.getOperand(1).getReg();
+
+ LLT VecTy = MRI.getType(Dst);
+ LLT XLenTy(STI.getXLenVT());
+
+ // Handle case of s64 element vectors on rv32
+ if (XLenTy.getSizeInBits() == 32 &&
+ VecTy.getElementType().getSizeInBits() == 64) {
----------------
topperc wrote:
If s64 is legal because of the D extension and s64 vectors are legal, then I think we should keep the G_SPLAT_VECTOR as legal. Naively, this would cause us to go from GPR->mem->FPR->VFMV_VF, which isn't efficient but is correct.
I believe this would show up as (G_SPLAT_VECTOR (G_MERGE_VALUES Lo, Hi)). We could potentially add a post-legalize combine to turn it into G_SPLAT_VECTOR_SPLIT_I64_VL, or just special-case it in RISCVInstructionSelector.
https://github.com/llvm/llvm-project/pull/85938
- Previous message: [llvm] [RISCV][GISEL] Legalize, regbank select, and instruction select G_ZEXT, G_SEXT, G_ANYEXT, G_SPLAT_VECTOR, and G_ICMP (PR #85938)
- Next message: [llvm] [RISCV][GISEL] Legalize, regbank select, and instruction select G_ZEXT, G_SEXT, G_ANYEXT, G_SPLAT_VECTOR, and G_ICMP (PR #85938)
- Messages sorted by:
[ date ]
[ thread ]
[ subject ]
[ author ]
More information about the llvm-commits
mailing list