[llvm] [RISCV][GISEL] Legalize, regbank select, and instruction select G_ZEXT, G_SEXT, G_ANYEXT, G_SPLAT_VECTOR, and G_ICMP (PR #85938)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 1 17:33:33 PDT 2024


================
@@ -570,7 +591,146 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
     auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3));
     MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val));
   }
+  MI.eraseFromParent();
+  return true;
+}
+
+// Custom-lower extensions from mask vectors by using a vselect either with 1
+// for zero/any-extension or -1 for sign-extension:
+//   (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
+// Note that any-extension is lowered identically to zero-extension.
+bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
+                                     MachineIRBuilder &MIB) const {
+
+  unsigned Opc = MI.getOpcode();
+  assert(Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
+         Opc == TargetOpcode::G_ANYEXT);
+
+  MachineRegisterInfo &MRI = *MIB.getMRI();
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+
+  LLT DstTy = MRI.getType(Dst);
+  int64_t ExtTrueVal =
+      Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_ANYEXT ? 1 : -1;
+  LLT DstEltTy = DstTy.getElementType();
+  auto SplatZero = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, 0));
+  auto SplatTrue =
+      MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, ExtTrueVal));
+  MIB.buildSelect(Dst, Src, SplatTrue, SplatZero);
+
+  MI.eraseFromParent();
+  return true;
+}
+
+/// Return the type of the mask type suitable for masking the provided
+/// vector type.  This is simply an i1 element type vector of the same
+/// (possibly scalable) length.
+static LLT getMaskTypeFor(LLT VecTy) {
+  assert(VecTy.isVector());
+  ElementCount EC = VecTy.getElementCount();
+  return LLT::vector(EC, LLT::scalar(1));
+}
+
+/// Creates an all ones mask suitable for masking a vector of type VecTy with
+/// vector length VL.
+static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
+                                            MachineIRBuilder &MIB,
+                                            MachineRegisterInfo &MRI) {
+  LLT MaskTy = getMaskTypeFor(VecTy);
+  return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL});
+}
+
+/// Gets the two common "VL" operands: an all-ones mask and the vector length.
+/// VecTy is a scalable vector type.
+static std::pair<MachineInstrBuilder, Register>
+buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
+                  MachineRegisterInfo &MRI) {
+  LLT VecTy = Dst.getLLTTy(MRI);
+  assert(VecTy.isScalableVector() && "Expecting scalable container type");
+  Register VL(RISCV::X0);
+  MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
+  return {Mask, VL};
+}
+
+static MachineInstrBuilder
+buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo,
+                         Register Hi, Register VL, MachineIRBuilder &MIB,
+                         MachineRegisterInfo &MRI) {
+  // TODO: If the Hi bits of the splat are undefined, then it's fine to just
+  // splat Lo even if it might be sign extended. I don't think we have
+  // introduced a case where we build an s64 where the upper bits are undef
+  // yet.
+
+  // Fall back to a stack store and stride x0 vector load.
+  // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
+  // preprocessDAG in SDAG.
+  return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
+                        {Passthru, Lo, Hi, VL});
+}
+
+static MachineInstrBuilder
+buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru,
+                         const SrcOp &Scalar, Register VL,
+                         MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
+  assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!");
+  auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar);
+  return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0),
+                                  Unmerge.getReg(1), VL, MIB, MRI);
+}
+
+// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
+// legal equivalently-sized i8 type, so we can use that as a go-between.
+// Splats of s1 types that have constant value can be legalized as VMSET_VL or
+// VMCLR_VL.
+bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
+                                             MachineIRBuilder &MIB) const {
+  assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);
+
+  MachineRegisterInfo &MRI = *MIB.getMRI();
+
+  Register Dst = MI.getOperand(0).getReg();
+  Register SplatVal = MI.getOperand(1).getReg();
+
+  LLT VecTy = MRI.getType(Dst);
+  LLT XLenTy(STI.getXLenVT());
+
+  // Handle case of s64 element vectors on rv32
+  if (XLenTy.getSizeInBits() == 32 &&
+      VecTy.getElementType().getSizeInBits() == 64) {
----------------
topperc wrote:

If s64 is legal because of the D extension and s64 vectors are legal, then I think we should keep G_SPLAT_VECTOR legal. Naively, this would cause us to go GPR->mem->FPR->VFMV_VF, which isn't efficient but is correct.

I believe this would show up as (G_SPLAT_VECTOR (G_MERGE_VALUES Lo, Hi)). We could potentially add a post-legalize combine to turn it into G_SPLAT_VECTOR_SPLIT_I64_VL, or just special-case it in RISCVInstructionSelector.

https://github.com/llvm/llvm-project/pull/85938


More information about the llvm-commits mailing list