[llvm] AMDGPU/GlobalISelDivergenceLowering: select divergent i1 phis (PR #78482)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 22 05:35:14 PST 2024


================
@@ -42,14 +45,152 @@ class AMDGPUGlobalISelDivergenceLowering : public MachineFunctionPass {
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
+    AU.addRequired<MachineDominatorTree>();
+    AU.addRequired<MachinePostDominatorTree>();
+    AU.addRequired<MachineUniformityAnalysisPass>();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
 };
 
+class DivergenceLoweringHelper : public PhiLoweringHelper {
+public:
+  DivergenceLoweringHelper(MachineFunction *MF, MachineDominatorTree *DT,
+                           MachinePostDominatorTree *PDT,
+                           MachineUniformityInfo *MUI);
+
+private:
+  MachineUniformityInfo *MUI = nullptr;
+
+public:
+  void markAsLaneMask(Register DstReg) const override;
+  void getCandidatesForLowering(
+      SmallVectorImpl<MachineInstr *> &Vreg1Phis) const override;
+  void collectIncomingValuesFromPhi(
+      const MachineInstr *MI,
+      SmallVectorImpl<Incoming> &Incomings) const override;
+  void replaceDstReg(Register NewReg, Register OldReg,
+                     MachineBasicBlock *MBB) override;
+  void buildMergeLaneMasks(MachineBasicBlock &MBB,
+                           MachineBasicBlock::iterator I, const DebugLoc &DL,
+                           Register DstReg, Register PrevReg,
+                           Register CurReg) override;
+  void constrainAsLaneMask(Incoming &In) override;
+};
+
+DivergenceLoweringHelper::DivergenceLoweringHelper(
+    MachineFunction *MF, MachineDominatorTree *DT,
+    MachinePostDominatorTree *PDT, MachineUniformityInfo *MUI)
+    : PhiLoweringHelper(MF, DT, PDT), MUI(MUI) {}
+
+// _(s1) -> SReg_32/64(s1)
+void DivergenceLoweringHelper::markAsLaneMask(Register DstReg) const {
+  assert(MRI->getType(DstReg) == LLT::scalar(1));
+
+  if (MRI->getRegClassOrNull(DstReg)) {
+    MRI->constrainRegClass(DstReg, ST->getBoolRC());
+    return;
+  }
+
+  MRI->setRegClass(DstReg, ST->getBoolRC());
----------------
Pierre-vh wrote:

I think you need to check the return value of `constrainRegClass`, as it can return `nullptr` on failure.

https://github.com/llvm/llvm-project/pull/78482


More information about the llvm-commits mailing list