[llvm] [RISCV][WIP] Emit code for strcmp for unaligned strings when one string is constant (PR #86645)

Mikhail Gudim via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 26 02:20:05 PDT 2024


https://github.com/mgudim created https://github.com/llvm/llvm-project/pull/86645

None

>From 2f8f5c14e4e09b983f35edd0978de2b2db6b8191 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Fri, 22 Mar 2024 04:07:03 -0400
Subject: [PATCH] [RISCV][WIP] Emit code for strcmp for unaligned strings when
 one string is constant.

---
 llvm/lib/Target/RISCV/CMakeLists.txt          |   1 +
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 156 ++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |   1 +
 llvm/lib/Target/RISCV/RISCVInstrInfo.td       |  24 +++
 .../RISCV/RISCVSelectionDAGTargetInfo.cpp     | 127 ++++++++++++++
 .../RISCV/RISCVSelectionDAGTargetInfo.h       |  33 ++++
 llvm/lib/Target/RISCV/RISCVSubtarget.cpp      |   1 +
 llvm/lib/Target/RISCV/RISCVSubtarget.h        |   3 +-
 8 files changed, 345 insertions(+), 1 deletion(-)
 create mode 100644 llvm/lib/Target/RISCV/RISCVSelectionDAGTargetInfo.cpp
 create mode 100644 llvm/lib/Target/RISCV/RISCVSelectionDAGTargetInfo.h

diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt
index 8715403f3839a6..6229029106ae2a 100644
--- a/llvm/lib/Target/RISCV/CMakeLists.txt
+++ b/llvm/lib/Target/RISCV/CMakeLists.txt
@@ -52,6 +52,7 @@ add_llvm_target(RISCVCodeGen
   RISCVPushPopOptimizer.cpp
   RISCVRegisterInfo.cpp
   RISCVSubtarget.cpp
+  RISCVSelectionDAGTargetInfo.cpp
   RISCVTargetMachine.cpp
   RISCVTargetObjectFile.cpp
   RISCVTargetTransformInfo.cpp
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e647f56416bfa6..b467d6d72e883e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17655,6 +17655,121 @@ static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB,
   return DoneMBB;
 }
 
+// Expands PseudoSTRCMPI -- strcmp between a run-time string and a constant C
+// string -- into an unrolled byte-by-byte comparison.
+//
+// Operands: 0 = result, 1 = base register of the run-time string, 2 = global
+// holding the constant string, 3 = position of the constant string in the
+// original call (0 = first argument, 1 = second argument).
+//
+// The expansion computes strcmp(Other, Const) and negates it at the end when
+// the constant string was the first argument:
+//
+//   MBB:     lbu t0, 0(base);  addi d0, t0, -Const[0];  bnez d0, Exit
+//   Check1:  lbu t1, 1(base);  addi d1, t1, -Const[1];  bnez d1, Exit
+//   ...
+//   Tail:    lbu tN, N(base)             ; all bytes matched: Other[N] - 0
+//   Exit:    res = phi(d0, d1, ..., tN)  ; negated if operand 3 is 0
+//
+// Other needs no separate end-of-string check: Const has no embedded nul, so
+// a nul in Other yields a non-zero difference and exits before any byte past
+// Other's terminator is read.
+static MachineBasicBlock *emitSTRCMPI(MachineInstr &MI, MachineBasicBlock *MBB,
+                                      const RISCVSubtarget &Subtarget) {
+  const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
+  MachineFunction &MF = *MBB->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  DebugLoc DL = MI.getDebugLoc();
+  const BasicBlock *LLVM_BB = MBB->getBasicBlock();
+
+  Register DstReg = MI.getOperand(0).getReg();
+  Register BaseReg = MI.getOperand(1).getReg();
+  const auto *GV = cast<GlobalVariable>(MI.getOperand(2).getGlobal());
+  StringRef Str = cast<ConstantDataArray>(GV->getInitializer())->getAsCString();
+  // The expansion yields strcmp(Other, Const), which is already correct when
+  // the constant string is the second argument (operand 3 == 1).
+  bool NeedToNegateResult = MI.getOperand(3).getImm() == 0;
+  assert(!MI.memoperands_empty() && "expected a memory operand");
+  const MachineMemOperand &MMO = *MI.memoperands()[0];
+
+  int64_t NumOfBytes = Str.size();
+  assert(NumOfBytes > 0 && "empty constant strings are not supported");
+  assert(isInt<12>(NumOfBytes) && "byte offsets must fit in simm12");
+
+  // Create the new blocks in layout order right after MBB so that each
+  // compare block falls through to the next: Check1..CheckN-1, Tail, Exit.
+  MachineFunction::iterator InsertPt = std::next(MBB->getIterator());
+  auto CreateBlock = [&]() {
+    MachineBasicBlock *NewMBB = MF.CreateMachineBasicBlock(LLVM_BB);
+    MF.insert(InsertPt, NewMBB);
+    return NewMBB;
+  };
+  SmallVector<MachineBasicBlock *, 8> CheckMBBs = {MBB};
+  for (int64_t I = 1; I < NumOfBytes; ++I)
+    CheckMBBs.push_back(CreateBlock());
+  MachineBasicBlock *TailMBB = CreateBlock();
+  MachineBasicBlock *ExitMBB = CreateBlock();
+
+  // Everything after MI, and MBB's successors, move to ExitMBB.
+  ExitMBB->splice(ExitMBB->end(), MBB, std::next(MI.getIterator()), MBB->end());
+  ExitMBB->transferSuccessorsAndUpdatePHIs(MBB);
+
+  Register PHIReg = NeedToNegateResult
+                        ? MRI.createVirtualRegister(&RISCV::GPRRegClass)
+                        : DstReg;
+  MachineInstrBuilder PHI_MIB =
+      BuildMI(*ExitMBB, ExitMBB->begin(), DL, TII.get(RISCV::PHI), PHIReg);
+  if (NeedToNegateResult)
+    BuildMI(*ExitMBB, ExitMBB->getFirstNonPHI(), DL, TII.get(RISCV::SUB),
+            DstReg)
+        .addReg(RISCV::X0)
+        .addReg(PHIReg);
+
+  // Appends "Reg = lbu Offset(BaseReg)" to Block and returns Reg.
+  auto EmitLoadByte = [&](MachineBasicBlock *Block, int64_t Offset) {
+    Register Reg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+    MachineMemOperand *ByteMMO = MF.getMachineMemOperand(
+        MMO.getPointerInfo().getWithOffset(Offset), MachineMemOperand::MOLoad,
+        LLT::scalar(8), Align(1));
+    BuildMI(*Block, Block->end(), DL, TII.get(RISCV::LBU), Reg)
+        .addReg(BaseReg)
+        .addImm(Offset)
+        .addMemOperand(ByteMMO);
+    return Reg;
+  };
+
+  for (int64_t I = 0; I < NumOfBytes; ++I) {
+    MachineBasicBlock *CurrMBB = CheckMBBs[I];
+    MachineBasicBlock *NextMBB =
+        I + 1 < NumOfBytes ? CheckMBBs[I + 1] : TailMBB;
+    // LBU zero-extends, so compare against the unsigned byte value; plain
+    // char may be signed.
+    int64_t Byte = static_cast<uint8_t>(Str[I]);
+
+    Register LoadedByteReg = EmitLoadByte(CurrMBB, I);
+    Register DiffReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+    BuildMI(*CurrMBB, CurrMBB->end(), DL, TII.get(RISCV::ADDI), DiffReg)
+        .addReg(LoadedByteReg)
+        .addImm(-Byte);
+    // Any mismatch, including Other's terminator, ends the comparison.
+    BuildMI(*CurrMBB, CurrMBB->end(), DL, TII.get(RISCV::BNE))
+        .addReg(DiffReg)
+        .addReg(RISCV::X0)
+        .addMBB(ExitMBB);
+    CurrMBB->addSuccessor(ExitMBB);
+    CurrMBB->addSuccessor(NextMBB);
+    PHI_MIB.addReg(DiffReg).addMBB(CurrMBB);
+  }
+
+  // All bytes of Const matched, so the result is Other[N] - '\0'.
+  Register LastByteReg = EmitLoadByte(TailMBB, NumOfBytes);
+  TailMBB->addSuccessor(ExitMBB);
+  PHI_MIB.addReg(LastByteReg).addMBB(TailMBB);
+
+  MI.eraseFromParent();
+  return ExitMBB;
+}
+
 MachineBasicBlock *
 RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
                                                  MachineBasicBlock *BB) const {
@@ -17737,6 +17890,8 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
   case RISCV::PseudoFROUND_D_INX:
   case RISCV::PseudoFROUND_D_IN32X:
     return emitFROUND(MI, BB, Subtarget);
+  case RISCV::PseudoSTRCMPI:
+    return emitSTRCMPI(MI, BB, Subtarget);
   case TargetOpcode::STATEPOINT:
   case TargetOpcode::STACKMAP:
   case TargetOpcode::PATCHPOINT:
@@ -19512,6 +19667,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(SWAP_CSR)
   NODE_NAME_CASE(CZERO_EQZ)
   NODE_NAME_CASE(CZERO_NEZ)
+  NODE_NAME_CASE(STRCMP)
   NODE_NAME_CASE(SF_VC_XV_SE)
   NODE_NAME_CASE(SF_VC_IV_SE)
   NODE_NAME_CASE(SF_VC_VV_SE)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index a38463f810270a..52dda10a56a666 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -456,6 +456,7 @@ enum NodeType : unsigned {
   TH_LDD,
   TH_SWD,
   TH_SDD,
+  STRCMP
 };
 // clang-format on
 } // namespace RISCVISD
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index e753c1f1add0c6..1a48bac0966648 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -137,6 +137,7 @@ def SPMem : MemOperand<SP>;
 
 def GPRCMem : MemOperand<GPRC>;
 
+
 class SImmAsmOperand<int width, string suffix = "">
     : ImmAsmOperand<"S", width, suffix> {
 }
@@ -1952,6 +1953,29 @@ def : Pat<(shl (zext GPR:$rs), uimm5:$shamt),
           (SRLI (i64 (SLLI GPR:$rs, 32)), (ImmSubFrom32 uimm5:$shamt))>;
 }
 
+// strcmp(str1, str2) where one argument is a constant C string; the result is
+// an XLenVT integer. Operands 1 and 2 are the two string pointers.
+def riscv_strcmp : SDNode<"RISCVISD::STRCMP",
+                          SDTypeProfile<1, 2, [SDTCisVT<0, XLenVT>,
+                                               SDTCisPtrTy<1>, SDTCisPtrTy<2>]>,
+                          [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
+// Expanded by emitSTRCMPI. $str1 is the run-time string, $str2 the global
+// holding the constant string, and $constant_str_idx is 0 if the constant was
+// the first strcmp argument and 1 if it was the second.
+let usesCustomInserter = 1, mayLoad = 1, mayStore = 0, hasSideEffects = 0 in
+def PseudoSTRCMPI : Pseudo<(outs GPR:$rd),
+                           (ins GPR:$str1, i64imm:$str2,
+                                i64imm:$constant_str_idx), []>;
+
+// Constant string first.
+def : Pat<(XLenVT (riscv_strcmp tglobaladdr:$str1, iPTR:$str2)),
+          (PseudoSTRCMPI GPR:$str2, tglobaladdr:$str1, 0)>;
+
+// Constant string second.
+def : Pat<(XLenVT (riscv_strcmp iPTR:$str1, tglobaladdr:$str2)),
+          (PseudoSTRCMPI GPR:$str1, tglobaladdr:$str2, 1)>;
+
 //===----------------------------------------------------------------------===//
 // Standard extensions
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/RISCV/RISCVSelectionDAGTargetInfo.cpp b/llvm/lib/Target/RISCV/RISCVSelectionDAGTargetInfo.cpp
new file mode 100644
index 00000000000000..cf217092a5ef08
--- /dev/null
+++ b/llvm/lib/Target/RISCV/RISCVSelectionDAGTargetInfo.cpp
@@ -0,0 +1,127 @@
+//===-- RISCVSelectionDAGTargetInfo.cpp - RISCV SelectionDAG Info ---------===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the RISCVSelectionDAGTargetInfo class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "RISCVSelectionDAGTargetInfo.h"
+#include "RISCVSubtarget.h"
+#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Type.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-selectiondag-target-info"
+
+// Strcmp against a constant string is only expanded inline when the string is
+// shorter than this limit; the default of 0 disables the expansion.
+static cl::opt<unsigned> MaxStrcmpSpecializeLength(
+    "riscv-max-strcmp-specialize-length", cl::Hidden, cl::init(0),
+    cl::desc("Only specialize strcmp against constant strings shorter than "
+             "this length (0 disables the specialization)"));
+
+/// True if \p GA is the start (offset 0) of a short, non-empty constant C
+/// string whose contents the strcmp custom inserter can expand inline.
+static bool canSpecializeStrcmp(const GlobalAddressSDNode *GA) {
+  if (GA->getOffset() != 0)
+    return false;
+  const auto *GV = dyn_cast<GlobalVariable>(GA->getGlobal());
+  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+    return false;
+  const auto *CDA = dyn_cast<ConstantDataArray>(GV->getInitializer());
+  if (!CDA || !CDA->isCString())
+    return false;
+  // Every byte offset, including the terminator's, must fit in LBU's simm12.
+  StringRef CString = CDA->getAsCString();
+  return !CString.empty() && CString.size() < MaxStrcmpSpecializeLength &&
+         isInt<12>(CString.size());
+}
+
+/// Replaces strcmp(Src1, Src2) with a RISCVISD::STRCMP node when one operand
+/// is a constant C string accepted by canSpecializeStrcmp and at least one
+/// operand's known alignment is below XLEN bytes. Otherwise returns
+/// {SDValue(), Chain} so the generic code keeps the library call.
+std::pair<SDValue, SDValue>
+RISCVSelectionDAGTargetInfo::EmitTargetCodeForStrcmp(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Src1,
+    SDValue Src2, MachinePointerInfo Op1PtrInfo,
+    MachinePointerInfo Op2PtrInfo) const {
+  // A null first member tells the caller to emit the library call.
+  const std::pair<SDValue, SDValue> KeepLibCall(SDValue(), Chain);
+
+  // A limit of 0 (the default) disables the specialization.
+  if (MaxStrcmpSpecializeLength == 0)
+    return KeepLibCall;
+
+  MachineFunction &MF = DAG.getMachineFunction();
+  const auto &Subtarget = MF.getSubtarget<RISCVSubtarget>();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  const DataLayout &DLayout = DAG.getDataLayout();
+  MVT XLenVT = Subtarget.getXLenVT();
+
+  // An operand counts as unaligned when its IR pointer is known and aligned
+  // to fewer than XLEN bytes; an operand without an IR value does not count.
+  auto IsUnaligned = [&](const MachinePointerInfo &PtrInfo) {
+    const auto *V = dyn_cast_if_present<const Value *>(PtrInfo.V);
+    return V && V->getPointerAlignment(DLayout) < XLenVT.getSizeInBits() / 8;
+  };
+  // The specialization only targets calls where at least one string is not
+  // known to be XLEN-byte aligned.
+  if (!IsUnaligned(Op1PtrInfo) && !IsUnaligned(Op2PtrInfo))
+    return KeepLibCall;
+
+  // Find the constant C string, preferring the first argument. Other is the
+  // string read at run time and MPI describes its memory.
+  const auto *GA1 = dyn_cast<GlobalAddressSDNode>(Src1);
+  const auto *GA2 = dyn_cast<GlobalAddressSDNode>(Src2);
+
+  const GlobalAddressSDNode *CStringGA = nullptr;
+  SDValue Other;
+  MachinePointerInfo MPI;
+  bool ConstantStringIsSecond = false;
+  if (GA1 && canSpecializeStrcmp(GA1)) {
+    CStringGA = GA1;
+    Other = Src2;
+    MPI = Op2PtrInfo;
+  } else if (GA2 && canSpecializeStrcmp(GA2)) {
+    CStringGA = GA2;
+    Other = Src1;
+    MPI = Op1PtrInfo;
+    ConstantStringIsSecond = true;
+  } else {
+    return KeepLibCall;
+  }
+
+  // Reference the constant string directly; the custom inserter reads its
+  // initializer to emit the per-byte comparisons.
+  SDValue TGA = DAG.getTargetGlobalAddress(
+      CStringGA->getGlobal(), DL, TLI.getPointerTy(DLayout), 0,
+      CStringGA->getTargetFlags());
+
+  // Preserve the call's argument order: the selection patterns use it to
+  // decide whether the result has to be negated.
+  SDValue Str1 = ConstantStringIsSecond ? Other : TGA;
+  SDValue Str2 = ConstantStringIsSecond ? TGA : Other;
+
+  // Describe the access as byte loads (align 1) from the run-time string.
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      MPI, MachineMemOperand::MOLoad, LLT(MVT::i8), Align(1));
+
+  // TODO: what should be the MemVT? The expansion reads several bytes, but
+  // the node currently advertises a single i8.
+  SDValue STRCMPNode = DAG.getMemIntrinsicNode(
+      RISCVISD::STRCMP, DL, DAG.getVTList(XLenVT, MVT::Other),
+      {Chain, Str1, Str2}, MVT::i8, MMO);
+
+  // Result 0 is the comparison value, result 1 the output chain.
+  return {STRCMPNode, STRCMPNode.getValue(1)};
+}
diff --git a/llvm/lib/Target/RISCV/RISCVSelectionDAGTargetInfo.h b/llvm/lib/Target/RISCV/RISCVSelectionDAGTargetInfo.h
new file mode 100644
index 00000000000000..1b95ff0e81a5a1
--- /dev/null
+++ b/llvm/lib/Target/RISCV/RISCVSelectionDAGTargetInfo.h
@@ -0,0 +1,33 @@
+//===-- RISCVSelectionDAGTargetInfo.h - RISCV SelectionDAG Info -*- C++ -*-===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the RISCV subclass for SelectionDAGTargetInfo.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_RISCV_RISCVSELECTIONDAGTARGETINFO_H
+#define LLVM_LIB_TARGET_RISCV_RISCVSELECTIONDAGTARGETINFO_H
+
+#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
+
+namespace llvm {
+class RISCVSelectionDAGTargetInfo : public SelectionDAGTargetInfo {
+public:
+  /// Emits RISCVISD::STRCMP for strcmp against a short constant C string, or
+  /// returns {SDValue(), Chain} to keep the library call.
+  std::pair<SDValue, SDValue>
+  EmitTargetCodeForStrcmp(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
+                          SDValue Src1, SDValue Src2,
+                          MachinePointerInfo Op1PtrInfo,
+                          MachinePointerInfo Op2PtrInfo) const override;
+};
+
+} // end namespace llvm
+
+#endif // LLVM_LIB_TARGET_RISCV_RISCVSELECTIONDAGTARGETINFO_H
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
index d3236bb07d56d5..00ec619b760fa6 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
@@ -16,6 +16,7 @@
 #include "GISel/RISCVRegisterBankInfo.h"
 #include "RISCV.h"
 #include "RISCVFrameLowering.h"
+#include "RISCVSelectionDAGTargetInfo.h"
 #include "RISCVTargetMachine.h"
 #include "llvm/CodeGen/MacroFusion.h"
 #include "llvm/CodeGen/ScheduleDAGMutation.h"
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h
index ba108912d93400..e4ad26d70c933f 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -17,6 +17,7 @@
 #include "RISCVFrameLowering.h"
 #include "RISCVISelLowering.h"
 #include "RISCVInstrInfo.h"
+#include "RISCVSelectionDAGTargetInfo.h"
 #include "llvm/CodeGen/GlobalISel/CallLowering.h"
 #include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
@@ -86,7 +87,7 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
   RISCVInstrInfo InstrInfo;
   RISCVRegisterInfo RegInfo;
   RISCVTargetLowering TLInfo;
-  SelectionDAGTargetInfo TSInfo;
+  RISCVSelectionDAGTargetInfo TSInfo;
 
   /// Initializes using the passed in CPU and feature strings so that we can
   /// use initializer lists for subtarget initialization.



More information about the llvm-commits mailing list