[llvm] [AMDGPU] MCExpr printing helper with KnownBits support (PR #95951)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 04:29:35 PDT 2024


================
@@ -303,3 +305,360 @@ const AMDGPUMCExpr *AMDGPUMCExpr::createOccupancy(unsigned InitOcc,
                  CreateExpr(InitOcc), NumSGPRs, NumVGPRs},
                 Ctx);
 }
+
+/// Converts the tri-state outcome of a comparison into 64-bit KnownBits.
+///
+/// A determined outcome becomes the constant 1 (true) or 0 (false). An
+/// undetermined outcome becomes a 1-bit all-unknown value zero-extended to
+/// 64 bits.
+static KnownBits fromOptionalToKnownBits(std::optional<bool> CompareResult) {
+  const unsigned Width = 64;
+
+  // Guard clause: nothing is known about the boolean itself.
+  if (!CompareResult.has_value())
+    return KnownBits(/*BitWidth=*/1).zext(Width);
+
+  const APInt BoolValue(Width, *CompareResult ? 1 : 0, /*isSigned=*/false);
+  return KnownBits::makeConstant(BoolValue);
+}
+
+using KnownBitsMap = std::unordered_map<const MCExpr *, KnownBits>;
+
+/// Recursively computes 64-bit KnownBits for \p Expr and for every
+/// sub-expression it visits, recording each result in \p KBM keyed by the
+/// expression pointer.
+///
+/// Expressions that evaluateAsAbsolute() can resolve are recorded as
+/// constants without descending further. \p depth bounds the recursion: once
+/// it reaches zero the expression is recorded as fully unknown. Opcodes and
+/// target expression kinds without a dedicated case are also recorded as
+/// fully unknown.
+void KnownBitsMapHelper(const MCExpr *Expr, KnownBitsMap &KBM, unsigned depth) {
+  const unsigned BitWidth = 64;
+  const APInt False(BitWidth, 0, /*isSigned=*/false);
+
+  // Fast path: the whole expression already folds to a constant.
+  int64_t Val;
+  if (Expr->evaluateAsAbsolute(Val)) {
+    APInt APValue(BitWidth, Val, /*isSigned=*/true);
+    KBM[Expr] = KnownBits::makeConstant(APValue);
+    return;
+  }
+
+  // Recursion budget exhausted: record the expression as fully unknown.
+  if (depth == 0) {
+    KBM[Expr] = KnownBits(BitWidth);
+    return;
+  }
+
+  depth--;
+
+  switch (Expr->getKind()) {
+  case MCExpr::ExprKind::Binary: {
+    const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
+    const MCExpr *LHS = BExpr->getLHS();
+    const MCExpr *RHS = BExpr->getRHS();
+
+    KnownBitsMapHelper(LHS, KBM, depth);
+    KnownBitsMapHelper(RHS, KBM, depth);
+    KnownBits LHSKnown = KBM[LHS];
+    KnownBits RHSKnown = KBM[RHS];
+    std::optional<bool> CompareRes;
+
+    switch (BExpr->getOpcode()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case MCBinaryExpr::Opcode::Add:
+      KBM[Expr] =
+          KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
+                                      /*NUW=*/false, LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::And:
+      KBM[Expr] = LHSKnown & RHSKnown;
+      return;
+    case MCBinaryExpr::Opcode::Div:
+      KBM[Expr] = KnownBits::sdiv(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::EQ:
+      CompareRes = KnownBits::eq(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::NE:
+      CompareRes = KnownBits::ne(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::GT:
+      CompareRes = KnownBits::sgt(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::GTE:
+      CompareRes = KnownBits::sge(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::LAnd: {
+      // Each operand is first reduced to "is it non-zero?"; the result is
+      // only determined when both answers are known.
+      std::optional<bool> LHSBool =
+          KnownBits::ne(LHSKnown, KnownBits::makeConstant(False));
+      std::optional<bool> RHSBool =
+          KnownBits::ne(RHSKnown, KnownBits::makeConstant(False));
+      if (LHSBool && RHSBool) {
+        CompareRes = *LHSBool && *RHSBool;
+      }
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    }
+    case MCBinaryExpr::Opcode::LOr: {
+      // (L != 0) || (R != 0) is equivalent to (L | R) != 0.
+      KnownBits Bits = LHSKnown | RHSKnown;
+      CompareRes = KnownBits::ne(Bits, KnownBits::makeConstant(False));
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    }
+    case MCBinaryExpr::Opcode::LT:
+      CompareRes = KnownBits::slt(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::LTE:
+      CompareRes = KnownBits::sle(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::Mod:
+      KBM[Expr] = KnownBits::srem(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Mul:
+      KBM[Expr] = KnownBits::mul(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Or:
+      KBM[Expr] = LHSKnown | RHSKnown;
+      return;
+    case MCBinaryExpr::Opcode::Shl:
+      KBM[Expr] = KnownBits::shl(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::AShr:
+      KBM[Expr] = KnownBits::ashr(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::LShr:
+      KBM[Expr] = KnownBits::lshr(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Sub:
+      KBM[Expr] =
+          KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
+                                      /*NUW=*/false, LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Xor:
+      KBM[Expr] = LHSKnown ^ RHSKnown;
+      return;
+    }
+  }
+  case MCExpr::ExprKind::Constant: {
+    const MCConstantExpr *CE = cast<MCConstantExpr>(Expr);
+    APInt APValue(BitWidth, CE->getValue(), /*isSigned=*/true);
+    KBM[Expr] = KnownBits::makeConstant(APValue);
+    return;
+  }
+  case MCExpr::ExprKind::SymbolRef: {
+    const MCSymbolRefExpr *RExpr = cast<MCSymbolRefExpr>(Expr);
+    const MCSymbol &Sym = RExpr->getSymbol();
+    if (!Sym.isVariable()) {
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    }
+
+    // Variable value retrieval is not for actual use but only for knownbits
+    // analysis.
+    KnownBitsMapHelper(Sym.getVariableValue(/*SetUsed=*/false), KBM, depth);
+    KBM[Expr] = KBM[Sym.getVariableValue(/*SetUsed=*/false)];
+    return;
+  }
+  case MCExpr::ExprKind::Unary: {
+    const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
+    KnownBitsMapHelper(UExpr->getSubExpr(), KBM, depth);
+    KnownBits KB = KBM[UExpr->getSubExpr()];
+
+    switch (UExpr->getOpcode()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case MCUnaryExpr::Opcode::Minus: {
+      // Arithmetic negation is 0 - x. Merely forcing the sign bit to one (as
+      // KnownBits::makeNegative() does) would be unsound: -x is non-negative
+      // whenever x is zero or negative, and it would conflict with an
+      // operand already known to be non-negative.
+      KBM[Expr] = KnownBits::computeForAddSub(
+          /*Add=*/false, /*NSW=*/false,
+          /*NUW=*/false, KnownBits::makeConstant(False), KB);
+      return;
+    }
+    case MCUnaryExpr::Opcode::Not: {
+      // Bitwise complement expressed as XOR with an all-ones constant.
+      KnownBits AllOnes(BitWidth);
+      AllOnes.setAllOnes();
+      KBM[Expr] = KB ^ AllOnes;
+      return;
+    }
+    case MCUnaryExpr::Opcode::Plus: {
+      // Unary plus leaves the value unchanged, so the operand's known bits
+      // carry over as-is; the sign bit must not be forced to zero.
+      KBM[Expr] = KB;
+      return;
+    }
+    }
+  }
+  case MCExpr::ExprKind::Target: {
+    const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
+
+    switch (AGVK->getKind()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case AMDGPUMCExpr::VariantKind::AGVK_Or: {
+      // Seed with the first argument, then OR in the known bits of each
+      // argument returned by getArgs().
+      KnownBitsMapHelper(AGVK->getSubExpr(0), KBM, depth);
+      KnownBits KB = KBM[AGVK->getSubExpr(0)];
+      for (const MCExpr *Arg : AGVK->getArgs()) {
+        KnownBitsMapHelper(Arg, KBM, depth);
+        KB |= KBM[Arg];
+      }
+      KBM[Expr] = KB;
+      return;
+    }
+    case AMDGPUMCExpr::VariantKind::AGVK_Max: {
+      // Seed with the first argument, then fold every argument in with an
+      // unsigned-max known-bits combination.
+      KnownBitsMapHelper(AGVK->getSubExpr(0), KBM, depth);
+      KnownBits KB = KBM[AGVK->getSubExpr(0)];
+      for (const MCExpr *Arg : AGVK->getArgs()) {
+        KnownBitsMapHelper(Arg, KBM, depth);
+        KB = KnownBits::umax(KB, KBM[Arg]);
+      }
+      KBM[Expr] = KB;
+      return;
+    }
+    case AMDGPUMCExpr::VariantKind::AGVK_ExtraSGPRs:
+    case AMDGPUMCExpr::VariantKind::AGVK_TotalNumVGPRs:
+    case AMDGPUMCExpr::VariantKind::AGVK_AlignTo:
+    case AMDGPUMCExpr::VariantKind::AGVK_Occupancy: {
+      // These kinds are only tracked when they fold to a constant; otherwise
+      // they are recorded as fully unknown.
+      int64_t Val;
+      if (AGVK->evaluateAsAbsolute(Val)) {
+        APInt APValue(BitWidth, Val, /*isSigned=*/false);
+        KBM[Expr] = KnownBits::makeConstant(APValue);
+        return;
+      }
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    }
+    }
+  }
+  }
+}
+
+static const MCExpr *TryFoldHelper(const MCExpr *Expr, KnownBitsMap &KBM,
+                                   MCContext &Ctx) {
+  if (!KBM.count(Expr))
+    return Expr;
+
+  auto valueCheckKnownBits = [](KnownBits &KB, unsigned Value) -> bool {
----------------
Pierre-vh wrote:

```suggestion
  auto ValueCheckKnownBits = [](KnownBits &KB, unsigned Value) -> bool {
```

https://github.com/llvm/llvm-project/pull/95951


More information about the llvm-commits mailing list