[llvm] [AMDGPU] MCExpr printing helper with KnownBits support (PR #95951)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 13:06:53 PDT 2024


================
@@ -303,3 +305,369 @@ const AMDGPUMCExpr *AMDGPUMCExpr::createOccupancy(unsigned InitOcc,
                  CreateExpr(InitOcc), NumSGPRs, NumVGPRs},
                 Ctx);
 }
+
+/// Turn the outcome of a KnownBits comparison into a 64-bit KnownBits value.
+///
+/// \param CompareResult  The comparison outcome; empty when the comparison
+///                       could not be decided.
+/// \returns The constant 1 or 0 when \p CompareResult holds true or false
+///          respectively; otherwise a 1-bit KnownBits zero-extended to 64
+///          bits.
+static KnownBits fromOptionalToKnownBits(std::optional<bool> CompareResult) {
+  static constexpr unsigned BitWidth = 64;
+
+  // Undecided comparison: start from a single bit and widen it.
+  if (!CompareResult.has_value()) {
+    KnownBits SingleBit(/*BitWidth=*/1);
+    return SingleBit.zext(BitWidth);
+  }
+
+  // Decided comparison: materialize the boolean as a 64-bit constant.
+  const APInt BoolValue(BitWidth, *CompareResult ? 1 : 0, /*isSigned=*/false);
+  return KnownBits::makeConstant(BoolValue);
+}
+
+/// Cache of the KnownBits computed for each visited sub-expression.
+using KnownBitsMap = DenseMap<const MCExpr *, KnownBits>;
+
+/// Forward declaration: the per-kind helpers below call back into this
+/// function for their operands. Declared static because it is a file-local
+/// helper; a later definition inherits the internal linkage from this
+/// declaration.
+static void knownBitsMapHelper(const MCExpr *Expr, KnownBitsMap &KBM,
+                               unsigned Depth = 0);
+
+/// Compute the KnownBits of a binary expression and record them in \p KBM.
+///
+/// Both operands are first visited through knownBitsMapHelper (at
+/// \p Depth + 1), then their cached KnownBits are combined according to the
+/// opcode. Comparison and logical opcodes yield a 64-bit boolean via
+/// fromOptionalToKnownBits. Opcodes without an explicit case store a freshly
+/// constructed 64-bit KnownBits.
+///
+/// \param Expr   Expression to analyse; must be an MCBinaryExpr.
+/// \param KBM    Map receiving the result for \p Expr (and its operands).
+/// \param Depth  Current recursion depth, forwarded to knownBitsMapHelper.
+static void binaryOpKnownBitsMapHelper(const MCExpr *Expr, KnownBitsMap &KBM,
+                                       unsigned Depth) {
+  static constexpr unsigned BitWidth = 64;
+  const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
+  const MCExpr *LHS = BExpr->getLHS();
+  const MCExpr *RHS = BExpr->getRHS();
+
+  knownBitsMapHelper(LHS, KBM, Depth + 1);
+  knownBitsMapHelper(RHS, KBM, Depth + 1);
+  // Copy the operand results out of the map before writing KBM[Expr]; the
+  // insertion may grow the map.
+  KnownBits LHSKnown = KBM[LHS];
+  KnownBits RHSKnown = KBM[RHS];
+  std::optional<bool> CompareRes;
+
+  switch (BExpr->getOpcode()) {
+  default:
+    KBM[Expr] = KnownBits(BitWidth);
+    return;
+  case MCBinaryExpr::Opcode::Add:
+    KBM[Expr] = KnownBits::add(LHSKnown, RHSKnown);
+    return;
+  case MCBinaryExpr::Opcode::And:
+    KBM[Expr] = LHSKnown & RHSKnown;
+    return;
+  case MCBinaryExpr::Opcode::Div:
+    KBM[Expr] = KnownBits::sdiv(LHSKnown, RHSKnown);
+    return;
+  case MCBinaryExpr::Opcode::EQ:
+    CompareRes = KnownBits::eq(LHSKnown, RHSKnown);
+    KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+    return;
+  case MCBinaryExpr::Opcode::NE:
+    CompareRes = KnownBits::ne(LHSKnown, RHSKnown);
+    KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+    return;
+  case MCBinaryExpr::Opcode::GT:
+    CompareRes = KnownBits::sgt(LHSKnown, RHSKnown);
+    KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+    return;
+  case MCBinaryExpr::Opcode::GTE:
+    CompareRes = KnownBits::sge(LHSKnown, RHSKnown);
+    KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+    return;
+  case MCBinaryExpr::Opcode::LAnd: {
+    const APInt False(BitWidth, 0, /*isSigned=*/false);
+    std::optional<bool> LHSBool =
+        KnownBits::ne(LHSKnown, KnownBits::makeConstant(False));
+    std::optional<bool> RHSBool =
+        KnownBits::ne(RHSKnown, KnownBits::makeConstant(False));
+    // A single operand known to be zero decides the result, even when the
+    // other operand is unknown. Otherwise both must be known non-zero.
+    if ((LHSBool && !*LHSBool) || (RHSBool && !*RHSBool))
+      CompareRes = false;
+    else if (LHSBool && RHSBool)
+      CompareRes = true;
+    KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+    return;
+  }
+  case MCBinaryExpr::Opcode::LOr: {
+    const APInt False(BitWidth, 0, /*isSigned=*/false);
+    // (LHS || RHS) holds exactly when (LHS | RHS) is non-zero.
+    KnownBits Bits = LHSKnown | RHSKnown;
+    CompareRes = KnownBits::ne(Bits, KnownBits::makeConstant(False));
+    KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+    return;
+  }
+  case MCBinaryExpr::Opcode::LT:
+    CompareRes = KnownBits::slt(LHSKnown, RHSKnown);
+    KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+    return;
+  case MCBinaryExpr::Opcode::LTE:
+    CompareRes = KnownBits::sle(LHSKnown, RHSKnown);
+    KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+    return;
+  case MCBinaryExpr::Opcode::Mod:
+    KBM[Expr] = KnownBits::srem(LHSKnown, RHSKnown);
+    return;
+  case MCBinaryExpr::Opcode::Mul:
+    KBM[Expr] = KnownBits::mul(LHSKnown, RHSKnown);
+    return;
+  case MCBinaryExpr::Opcode::Or:
+    KBM[Expr] = LHSKnown | RHSKnown;
+    return;
+  case MCBinaryExpr::Opcode::Shl:
+    KBM[Expr] = KnownBits::shl(LHSKnown, RHSKnown);
+    return;
+  case MCBinaryExpr::Opcode::AShr:
+    KBM[Expr] = KnownBits::ashr(LHSKnown, RHSKnown);
+    return;
+  case MCBinaryExpr::Opcode::LShr:
+    KBM[Expr] = KnownBits::lshr(LHSKnown, RHSKnown);
+    return;
+  case MCBinaryExpr::Opcode::Sub:
+    KBM[Expr] = KnownBits::sub(LHSKnown, RHSKnown);
+    return;
+  case MCBinaryExpr::Opcode::Xor:
+    KBM[Expr] = LHSKnown ^ RHSKnown;
+    return;
+  }
+}
+
+/// Compute the KnownBits of a unary expression and record them in \p KBM.
+///
+/// The operand is first visited through knownBitsMapHelper (at \p Depth + 1)
+/// and its cached KnownBits are then transformed according to the opcode.
+/// Opcodes without an explicit case store a freshly constructed 64-bit
+/// KnownBits.
+///
+/// \param Expr   Expression to analyse; must be an MCUnaryExpr.
+/// \param KBM    Map receiving the result for \p Expr (and its operand).
+/// \param Depth  Current recursion depth, forwarded to knownBitsMapHelper.
+static void unaryOpKnownBitsMapHelper(const MCExpr *Expr, KnownBitsMap &KBM,
+                                      unsigned Depth) {
+  static constexpr unsigned BitWidth = 64;
+  const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
+  knownBitsMapHelper(UExpr->getSubExpr(), KBM, Depth + 1);
+  KnownBits KB = KBM[UExpr->getSubExpr()];
+
+  switch (UExpr->getOpcode()) {
+  default:
+    KBM[Expr] = KnownBits(BitWidth);
+    return;
+  case MCUnaryExpr::Opcode::Minus: {
+    // Arithmetic negation is 0 - x. Merely forcing the sign bit on would be
+    // wrong for operands that are zero or already negative.
+    const APInt Zero(BitWidth, 0, /*isSigned=*/false);
+    KBM[Expr] = KnownBits::sub(KnownBits::makeConstant(Zero), KB);
+    return;
+  }
+  case MCUnaryExpr::Opcode::Not: {
+    // Bitwise complement expressed as x ^ ~0.
+    KnownBits AllOnes(BitWidth);
+    AllOnes.setAllOnes();
+    KBM[Expr] = KB ^ AllOnes;
+    return;
+  }
+  case MCUnaryExpr::Opcode::Plus: {
+    // Unary plus leaves the value unchanged, so the operand's known bits
+    // (including its sign bit) carry over as-is.
+    KBM[Expr] = KB;
+    return;
+  }
+  }
+}
+
+void targetOpKnownBitsMapHelper(const MCExpr *Expr, KnownBitsMap &KBM,
----------------
arsenm wrote:

Can you just move this up to avoid the forward declaration? Also, make it static.

https://github.com/llvm/llvm-project/pull/95951


More information about the llvm-commits mailing list