[llvm] [AMDGPU] Unify handling of BITOP3 operation (PR #132019)

Jakub Chlanda via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 19 06:10:37 PDT 2025


https://github.com/jchlanda created https://github.com/llvm/llvm-project/pull/132019

Factor the `BITOP3` matching helper out into a single templated implementation shared by GlobalISel and SelectionDAG instruction selection.

>From 046e68fa38d5b135569e2666ecd7e89add3004e2 Mon Sep 17 00:00:00 2001
From: Jakub Chlanda <jakub at codeplay.com>
Date: Wed, 19 Mar 2025 13:00:41 +0000
Subject: [PATCH] [AMDGPU] Unify handling of BITOP3 operation

---
 .../lib/Target/AMDGPU/AMDGPUGlobalISelUtils.h | 212 +++++++++++++++++-
 llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp | 124 +---------
 .../AMDGPU/AMDGPUInstructionSelector.cpp      | 120 +---------
 3 files changed, 216 insertions(+), 240 deletions(-)

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.h b/llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.h
index 70cfdacec700c..c200dcaba7c59 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.h
@@ -10,7 +10,10 @@
 #define LLVM_LIB_TARGET_AMDGPU_AMDGPUGLOBALISELUTILS_H
 
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
 #include "llvm/CodeGen/Register.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
 #include <utility>
 
 namespace llvm {
@@ -51,7 +54,214 @@ class IntrinsicLaneMaskAnalyzer {
 
 void buildReadAnyLane(MachineIRBuilder &B, Register SgprDst, Register VgprSrc,
                       const RegisterBankInfo &RBI);
+
+// Only the Register (GlobalISel) and SDValue (SelectionDAG) specializations
+// below are meaningful. The primary template is declared but never defined,
+// so instantiating it for any other type is a compile-time error rather than
+// producing a type whose constructors are all deleted.
+template <typename T> struct BitOp3Helper;
+
+// GlobalISel hooks for BitOp3_Op. Operands are looked up through COPYs.
+template <> struct BitOp3Helper<Register> {
+  explicit BitOp3Helper(const MachineRegisterInfo *MRI) : MRI(MRI) {}
+  bool isAllOnes(Register R) const {
+    return mi_match(R, *MRI, MIPatternMatch::m_AllOnesInt());
+  }
+  bool isZero(Register R) const {
+    return mi_match(R, *MRI, MIPatternMatch::m_ZeroInt());
+  }
+  // If R is a bitwise 'not' of some value, return that value in LHS.
+  bool isNot(Register R, Register &LHS) const {
+    if (!mi_match(R, *MRI, m_Not(MIPatternMatch::m_Reg(LHS))))
+      return false;
+    LHS = getSrcRegIgnoringCopies(LHS, *MRI);
+    return true;
+  }
+  // Both source operands of the instruction defining R.
+  std::pair<Register, Register> getLHSRHS(Register R) const {
+    const MachineInstr *MI = MRI->getVRegDef(R);
+    return {getSrcRegIgnoringCopies(MI->getOperand(1).getReg(), *MRI),
+            getSrcRegIgnoringCopies(MI->getOperand(2).getReg(), *MRI)};
+  }
+  // Translate the defining generic opcode into its ISD counterpart.
+  unsigned getOpcode(Register R) const {
+    switch (MRI->getVRegDef(R)->getOpcode()) {
+    case TargetOpcode::G_AND:
+      return ISD::AND;
+    case TargetOpcode::G_OR:
+      return ISD::OR;
+    case TargetOpcode::G_XOR:
+      return ISD::XOR;
+    default:
+      // DELETED_NODE signals an unsupported opcode.
+      return ISD::DELETED_NODE;
+    }
+  }
+  const MachineRegisterInfo *MRI;
+};
+
+// SelectionDAG hooks for BitOp3_Op. None of these hooks reads MRI; the member
+// only mirrors the constructor interface of the Register specialization so
+// both can be built the same way.
+template <> struct BitOp3Helper<SDValue> {
+  explicit BitOp3Helper(const MachineRegisterInfo *MRI = nullptr) : MRI(MRI) {}
+
+  // True if Op is an integer constant with every bit set.
+  bool isAllOnes(SDValue Op) const { return isAllOnesConstant(Op); }
+
+  // True if Op is the integer constant zero.
+  bool isZero(SDValue Op) const { return isNullConstant(Op); }
+
+  // If Op is (xor X, all-ones), i.e. a bitwise 'not', return X in LHS. Only a
+  // scalar all-ones constant in operand 1 is recognized.
+  bool isNot(SDValue Op, SDValue &LHS) const {
+    if (Op.getOpcode() != ISD::XOR || !isAllOnesConstant(Op.getOperand(1)))
+      return false;
+    LHS = Op.getOperand(0);
+    return true;
+  }
+
+  // The two operands of the binary node In.
+  std::pair<SDValue, SDValue> getLHSRHS(SDValue In) const {
+    return {In.getOperand(0), In.getOperand(1)};
+  }
+
+  // Return the opcode of Op if it is one the matcher folds, otherwise
+  // ISD::DELETED_NODE.
+  unsigned getOpcode(SDValue Op) const {
+    switch (Op.getOpcode()) {
+    case ISD::AND:
+    case ISD::OR:
+    case ISD::XOR:
+      return Op.getOpcode();
+    default:
+      // DELETED_NODE signals an unsupported opcode.
+      return ISD::DELETED_NODE;
+    }
+  }
+
+  [[maybe_unused]] const MachineRegisterInfo *MRI;
+};
+
+// Match a tree of AND/OR/XOR operations rooted at In whose leaves fit in Src
+// (at most three distinct values) and return the number of matched operations
+// plus the 8-bit truth table over Src, or {0, 0} if In does not match. Helper
+// supplies the IR-specific queries (GlobalISel or SelectionDAG).
+template <typename T>
+std::pair<unsigned, uint8_t> BitOp3_Op(BitOp3Helper<T> &Helper, T In,
+                                       SmallVectorImpl<T> &Src) {
+  unsigned NumOpcodes = 0;
+  uint8_t LHSBits, RHSBits;
+
+  auto getOperandBits = [&Helper, &In, &Src](T Op, uint8_t &Bits) -> bool {
+    // Define truth table given Src0, Src1, Src2 bits permutations:
+    //                          0     0     0
+    //                          0     0     1
+    //                          0     1     0
+    //                          0     1     1
+    //                          1     0     0
+    //                          1     0     1
+    //                          1     1     0
+    //                          1     1     1
+    const uint8_t SrcBits[3] = {0xf0, 0xcc, 0xaa};
+
+    if (Helper.isAllOnes(Op)) {
+      Bits = 0xff;
+      return true;
+    }
+    if (Helper.isZero(Op)) {
+      Bits = 0;
+      return true;
+    }
+
+    for (unsigned I = 0; I < Src.size(); ++I) {
+      // Try to find existing reused operand
+      if (Src[I] == Op) {
+        Bits = SrcBits[I];
+        return true;
+      }
+      // Try to replace parent operator
+      if (Src[I] == In) {
+        Bits = SrcBits[I];
+        Src[I] = Op;
+        return true;
+      }
+    }
+
+    if (Src.size() == 3) {
+      // No room left for operands. Try one last time, there can be a 'not' of
+      // one of our source operands. In this case we can compute the bits
+      // without growing Src vector.
+      T LHS;
+      if (Helper.isNot(Op, LHS)) {
+        for (unsigned I = 0; I < Src.size(); ++I) {
+          if (Src[I] == LHS) {
+            Bits = ~SrcBits[I];
+            return true;
+          }
+        }
+      }
+
+      return false;
+    }
+
+    Bits = SrcBits[Src.size()];
+    Src.push_back(Op);
+    return true;
+  };
+
+  // Query once; for GlobalISel this looks up the defining instruction.
+  const unsigned Opc = Helper.getOpcode(In);
+  switch (Opc) {
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR: {
+    auto [LHS, RHS] = Helper.getLHSRHS(In);
+
+    SmallVector<T, 3> Backup(Src.begin(), Src.end());
+    if (!getOperandBits(LHS, LHSBits) || !getOperandBits(RHS, RHSBits)) {
+      Src = Backup;
+      return std::make_pair(0, 0);
+    }
+
+    // Recursion is naturally limited by the size of the operand vector. The
+    // same Helper serves every level; it carries no per-node state.
+    auto Op = BitOp3_Op(Helper, LHS, Src);
+    if (Op.first) {
+      NumOpcodes += Op.first;
+      LHSBits = Op.second;
+    }
+
+    Op = BitOp3_Op(Helper, RHS, Src);
+    if (Op.first) {
+      NumOpcodes += Op.first;
+      RHSBits = Op.second;
+    }
+    break;
+  }
+  default:
+    return std::make_pair(0, 0);
+  }
+
+  uint8_t TTbl;
+  switch (Opc) {
+  case ISD::AND:
+    TTbl = LHSBits & RHSBits;
+    break;
+  case ISD::OR:
+    TTbl = LHSBits | RHSBits;
+    break;
+  case ISD::XOR:
+    TTbl = LHSBits ^ RHSBits;
+    break;
+  default:
+    llvm_unreachable("Unhandled opcode");
+  }
+
+  return std::make_pair(NumOpcodes + 1, TTbl);
 }
-}
+
+} // namespace AMDGPU
+} // namespace llvm
 
 #endif
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 536bf0c208752..32f4cf07cf2a0 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -13,6 +13,7 @@
 
 #include "AMDGPUISelDAGToDAG.h"
 #include "AMDGPU.h"
+#include "AMDGPUGlobalISelUtils.h"
 #include "AMDGPUInstrInfo.h"
 #include "AMDGPUSubtarget.h"
 #include "AMDGPUTargetMachine.h"
@@ -3688,133 +3689,14 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods(SDValue In, SDValue &Src,
   return true;
 }
 
-// Match BITOP3 operation and return a number of matched instructions plus
-// truth table.
-static std::pair<unsigned, uint8_t> BitOp3_Op(SDValue In,
-                                              SmallVectorImpl<SDValue> &Src) {
-  unsigned NumOpcodes = 0;
-  uint8_t LHSBits, RHSBits;
-
-  auto getOperandBits = [&Src, In](SDValue Op, uint8_t &Bits) -> bool {
-    // Define truth table given Src0, Src1, Src2 bits permutations:
-    //                          0     0     0
-    //                          0     0     1
-    //                          0     1     0
-    //                          0     1     1
-    //                          1     0     0
-    //                          1     0     1
-    //                          1     1     0
-    //                          1     1     1
-    const uint8_t SrcBits[3] = { 0xf0, 0xcc, 0xaa };
-
-    if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
-      if (C->isAllOnes()) {
-        Bits = 0xff;
-        return true;
-      }
-      if (C->isZero()) {
-        Bits = 0;
-        return true;
-      }
-    }
-
-    for (unsigned I = 0; I < Src.size(); ++I) {
-      // Try to find existing reused operand
-      if (Src[I] == Op) {
-        Bits = SrcBits[I];
-        return true;
-      }
-      // Try to replace parent operator
-      if (Src[I] == In) {
-        Bits = SrcBits[I];
-        Src[I] = Op;
-        return true;
-      }
-    }
-
-    if (Src.size() == 3) {
-      // No room left for operands. Try one last time, there can be a 'not' of
-      // one of our source operands. In this case we can compute the bits
-      // without growing Src vector.
-      if (Op.getOpcode() == ISD::XOR) {
-        if (auto *C = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-          if (C->isAllOnes()) {
-            SDValue LHS = Op.getOperand(0);
-            for (unsigned I = 0; I < Src.size(); ++I) {
-              if (Src[I] == LHS) {
-                Bits = ~SrcBits[I];
-                return true;
-              }
-            }
-          }
-        }
-      }
-
-      return false;
-    }
-
-    Bits = SrcBits[Src.size()];
-    Src.push_back(Op);
-    return true;
-  };
-
-  switch (In.getOpcode()) {
-  case ISD::AND:
-  case ISD::OR:
-  case ISD::XOR: {
-    SDValue LHS = In.getOperand(0);
-    SDValue RHS = In.getOperand(1);
-
-    SmallVector<SDValue, 3> Backup(Src.begin(), Src.end());
-    if (!getOperandBits(LHS, LHSBits) ||
-        !getOperandBits(RHS, RHSBits)) {
-      Src = Backup;
-      return std::make_pair(0, 0);
-    }
-
-    // Recursion is naturally limited by the size of the operand vector.
-    auto Op = BitOp3_Op(LHS, Src);
-    if (Op.first) {
-      NumOpcodes += Op.first;
-      LHSBits = Op.second;
-    }
-
-    Op = BitOp3_Op(RHS, Src);
-    if (Op.first) {
-      NumOpcodes += Op.first;
-      RHSBits = Op.second;
-    }
-    break;
-  }
-  default:
-    return std::make_pair(0, 0);
-  }
-
-  uint8_t TTbl;
-  switch (In.getOpcode()) {
-  case ISD::AND:
-    TTbl = LHSBits & RHSBits;
-    break;
-  case ISD::OR:
-    TTbl = LHSBits | RHSBits;
-    break;
-  case ISD::XOR:
-    TTbl = LHSBits ^ RHSBits;
-    break;
-  default:
-    break;
-  }
-
-  return std::make_pair(NumOpcodes + 1, TTbl);
-}
-
 bool AMDGPUDAGToDAGISel::SelectBITOP3(SDValue In, SDValue &Src0, SDValue &Src1,
                                       SDValue &Src2, SDValue &Tbl) const {
   SmallVector<SDValue, 3> Src;
   uint8_t TTbl;
   unsigned NumOpcodes;
 
-  std::tie(NumOpcodes, TTbl) = BitOp3_Op(In, Src);
+  AMDGPU::BitOp3Helper<SDValue> Helper;
+  std::tie(NumOpcodes, TTbl) = AMDGPU::BitOp3_Op(Helper, In, Src);
 
   // Src.empty() case can happen if all operands are all zero or all ones.
   // Normally it shall be optimized out before reaching this.
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index 745621fc1e089..da2a6b96779a6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -3759,123 +3759,6 @@ bool AMDGPUInstructionSelector::selectWaveAddress(MachineInstr &MI) const {
   return true;
 }
 
-// Match BITOP3 operation and return a number of matched instructions plus
-// truth table.
-static std::pair<unsigned, uint8_t> BitOp3_Op(Register R,
-                                              SmallVectorImpl<Register> &Src,
-                                              const MachineRegisterInfo &MRI) {
-  unsigned NumOpcodes = 0;
-  uint8_t LHSBits, RHSBits;
-
-  auto getOperandBits = [&Src, R, &MRI](Register Op, uint8_t &Bits) -> bool {
-    // Define truth table given Src0, Src1, Src2 bits permutations:
-    //                          0     0     0
-    //                          0     0     1
-    //                          0     1     0
-    //                          0     1     1
-    //                          1     0     0
-    //                          1     0     1
-    //                          1     1     0
-    //                          1     1     1
-    const uint8_t SrcBits[3] = { 0xf0, 0xcc, 0xaa };
-
-    if (mi_match(Op, MRI, m_AllOnesInt())) {
-      Bits = 0xff;
-      return true;
-    }
-    if (mi_match(Op, MRI, m_ZeroInt())) {
-      Bits = 0;
-      return true;
-    }
-
-    for (unsigned I = 0; I < Src.size(); ++I) {
-      // Try to find existing reused operand
-      if (Src[I] == Op) {
-        Bits = SrcBits[I];
-        return true;
-      }
-      // Try to replace parent operator
-      if (Src[I] == R) {
-        Bits = SrcBits[I];
-        Src[I] = Op;
-        return true;
-      }
-    }
-
-    if (Src.size() == 3) {
-      // No room left for operands. Try one last time, there can be a 'not' of
-      // one of our source operands. In this case we can compute the bits
-      // without growing Src vector.
-      Register LHS;
-      if (mi_match(Op, MRI, m_Not(m_Reg(LHS)))) {
-        LHS = getSrcRegIgnoringCopies(LHS, MRI);
-        for (unsigned I = 0; I < Src.size(); ++I) {
-          if (Src[I] == LHS) {
-            Bits = ~SrcBits[I];
-            return true;
-          }
-        }
-      }
-
-      return false;
-    }
-
-    Bits = SrcBits[Src.size()];
-    Src.push_back(Op);
-    return true;
-  };
-
-  MachineInstr *MI = MRI.getVRegDef(R);
-  switch (MI->getOpcode()) {
-  case TargetOpcode::G_AND:
-  case TargetOpcode::G_OR:
-  case TargetOpcode::G_XOR: {
-    Register LHS = getSrcRegIgnoringCopies(MI->getOperand(1).getReg(), MRI);
-    Register RHS = getSrcRegIgnoringCopies(MI->getOperand(2).getReg(), MRI);
-
-    SmallVector<Register, 3> Backup(Src.begin(), Src.end());
-    if (!getOperandBits(LHS, LHSBits) ||
-        !getOperandBits(RHS, RHSBits)) {
-      Src = Backup;
-      return std::make_pair(0, 0);
-    }
-
-    // Recursion is naturally limited by the size of the operand vector.
-    auto Op = BitOp3_Op(LHS, Src, MRI);
-    if (Op.first) {
-      NumOpcodes += Op.first;
-      LHSBits = Op.second;
-    }
-
-    Op = BitOp3_Op(RHS, Src, MRI);
-    if (Op.first) {
-      NumOpcodes += Op.first;
-      RHSBits = Op.second;
-    }
-    break;
-  }
-  default:
-    return std::make_pair(0, 0);
-  }
-
-  uint8_t TTbl;
-  switch (MI->getOpcode()) {
-  case TargetOpcode::G_AND:
-    TTbl = LHSBits & RHSBits;
-    break;
-  case TargetOpcode::G_OR:
-    TTbl = LHSBits | RHSBits;
-    break;
-  case TargetOpcode::G_XOR:
-    TTbl = LHSBits ^ RHSBits;
-    break;
-  default:
-    break;
-  }
-
-  return std::make_pair(NumOpcodes + 1, TTbl);
-}
-
 bool AMDGPUInstructionSelector::selectBITOP3(MachineInstr &MI) const {
   if (!Subtarget->hasBitOp3Insts())
     return false;
@@ -3890,7 +3773,8 @@ bool AMDGPUInstructionSelector::selectBITOP3(MachineInstr &MI) const {
   uint8_t TTbl;
   unsigned NumOpcodes;
 
-  std::tie(NumOpcodes, TTbl) = BitOp3_Op(DstReg, Src, *MRI);
+  AMDGPU::BitOp3Helper<Register> Helper(MRI);
+  std::tie(NumOpcodes, TTbl) = AMDGPU::BitOp3_Op(Helper, DstReg, Src);
 
   // Src.empty() case can happen if all operands are all zero or all ones.
   // Normally it shall be optimized out before reaching this.



More information about the llvm-commits mailing list