[llvm] [AMDGPU] MCExpr printing helper with KnownBits support (PR #95951)
Pierre van Houtryve via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 13 04:29:35 PDT 2024
================
@@ -303,3 +305,360 @@ const AMDGPUMCExpr *AMDGPUMCExpr::createOccupancy(unsigned InitOcc,
CreateExpr(InitOcc), NumSGPRs, NumVGPRs},
Ctx);
}
+
+/// Converts the tri-state result of a KnownBits comparison into a 64-bit
+/// KnownBits value.
+///
+/// \param CompareResult true/false when the comparison could be decided, or
+///        std::nullopt when it could not.
+/// \returns the constant 1 or 0 for a decided comparison; otherwise a 1-bit
+///          unknown value zero-extended to 64 bits.
+static KnownBits fromOptionalToKnownBits(std::optional<bool> CompareResult) {
+  const unsigned BitWidth = 64;
+
+  // Undecided comparison: start from a single unknown bit and widen it.
+  if (!CompareResult.has_value())
+    return KnownBits(/*BitWidth=*/1).zext(BitWidth);
+
+  // Decided comparison: materialize the boolean as a 64-bit constant.
+  const uint64_t BoolValue = *CompareResult ? 1 : 0;
+  return KnownBits::makeConstant(
+      APInt(BitWidth, BoolValue, /*isSigned=*/false));
+}
+
+/// Maps each visited MCExpr node to the KnownBits computed for it.
+using KnownBitsMap = std::unordered_map<const MCExpr *, KnownBits>;
+
+/// Recursively computes 64-bit KnownBits for \p Expr and the sub-expressions
+/// it visits, storing every result in \p KBM keyed by expression pointer.
+///
+/// An expression that evaluates to an absolute value is stored as a constant.
+/// When \p depth reaches zero, or when an opcode / target variant kind is
+/// handled by a `default` case, the expression is stored as fully unknown.
+///
+/// \param Expr  expression to analyze.
+/// \param KBM   map receiving the result for \p Expr and visited children.
+/// \param depth remaining recursion budget; decremented before descending.
+void KnownBitsMapHelper(const MCExpr *Expr, KnownBitsMap &KBM, unsigned depth) {
+  const unsigned BitWidth = 64;
+  const APInt False(BitWidth, 0, /*isSigned=*/false);
+
+  // Anything that already folds to an absolute value is fully known.
+  int64_t Val;
+  if (Expr->evaluateAsAbsolute(Val)) {
+    APInt APValue(BitWidth, Val, /*isSigned=*/true);
+    KBM[Expr] = KnownBits::makeConstant(APValue);
+    return;
+  }
+
+  // Recursion budget exhausted: record "nothing known" and stop.
+  if (depth == 0) {
+    KBM[Expr] = KnownBits(BitWidth);
+    return;
+  }
+
+  depth--;
+
+  switch (Expr->getKind()) {
+  case MCExpr::ExprKind::Binary: {
+    const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
+    const MCExpr *LHS = BExpr->getLHS();
+    const MCExpr *RHS = BExpr->getRHS();
+
+    // Analyze both operands first; their results are read back from the map.
+    KnownBitsMapHelper(LHS, KBM, depth);
+    KnownBitsMapHelper(RHS, KBM, depth);
+    KnownBits LHSKnown = KBM[LHS];
+    KnownBits RHSKnown = KBM[RHS];
+    std::optional<bool> CompareRes;
+
+    switch (BExpr->getOpcode()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case MCBinaryExpr::Opcode::Add:
+      KBM[Expr] =
+          KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
+                                      /*NUW=*/false, LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::And:
+      KBM[Expr] = LHSKnown & RHSKnown;
+      return;
+    case MCBinaryExpr::Opcode::Div:
+      KBM[Expr] = KnownBits::sdiv(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::EQ:
+      CompareRes = KnownBits::eq(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::NE:
+      CompareRes = KnownBits::ne(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::GT:
+      CompareRes = KnownBits::sgt(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::GTE:
+      CompareRes = KnownBits::sge(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::LAnd: {
+      // Decidable only when both operands are individually known to be
+      // zero or non-zero; otherwise CompareRes stays empty (unknown).
+      std::optional<bool> LHSBool =
+          KnownBits::ne(LHSKnown, KnownBits::makeConstant(False));
+      std::optional<bool> RHSBool =
+          KnownBits::ne(RHSKnown, KnownBits::makeConstant(False));
+      if (LHSBool && RHSBool) {
+        CompareRes = *LHSBool && *RHSBool;
+      }
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    }
+    case MCBinaryExpr::Opcode::LOr: {
+      // The logical OR is evaluated as "(LHS | RHS) != 0".
+      KnownBits Bits = LHSKnown | RHSKnown;
+      CompareRes = KnownBits::ne(Bits, KnownBits::makeConstant(False));
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    }
+    case MCBinaryExpr::Opcode::LT:
+      CompareRes = KnownBits::slt(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::LTE:
+      CompareRes = KnownBits::sle(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::Mod:
+      KBM[Expr] = KnownBits::srem(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Mul:
+      KBM[Expr] = KnownBits::mul(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Or:
+      KBM[Expr] = LHSKnown | RHSKnown;
+      return;
+    case MCBinaryExpr::Opcode::Shl:
+      KBM[Expr] = KnownBits::shl(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::AShr:
+      KBM[Expr] = KnownBits::ashr(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::LShr:
+      KBM[Expr] = KnownBits::lshr(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Sub:
+      KBM[Expr] =
+          KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
+                                      /*NUW=*/false, LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Xor:
+      KBM[Expr] = LHSKnown ^ RHSKnown;
+      return;
+    }
+  }
+  case MCExpr::ExprKind::Constant: {
+    const MCConstantExpr *CE = cast<MCConstantExpr>(Expr);
+    APInt APValue(BitWidth, CE->getValue(), /*isSigned=*/true);
+    KBM[Expr] = KnownBits::makeConstant(APValue);
+    return;
+  }
+  case MCExpr::ExprKind::SymbolRef: {
+    const MCSymbolRefExpr *RExpr = cast<MCSymbolRefExpr>(Expr);
+    const MCSymbol &Sym = RExpr->getSymbol();
+    // A symbol that is not a variable has no expression to look through.
+    if (!Sym.isVariable()) {
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    }
+
+    // Variable value retrieval is not for actual use but only for knownbits
+    // analysis.
+    KnownBitsMapHelper(Sym.getVariableValue(/*SetUsed=*/false), KBM, depth);
+    KBM[Expr] = KBM[Sym.getVariableValue(/*SetUsed=*/false)];
+    return;
+  }
+  case MCExpr::ExprKind::Unary: {
+    const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
+    KnownBitsMapHelper(UExpr->getSubExpr(), KBM, depth);
+    KnownBits KB = KBM[UExpr->getSubExpr()];
+
+    switch (UExpr->getOpcode()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case MCUnaryExpr::Opcode::Minus: {
+      // Arithmetic negation is modeled as "0 - X". Merely forcing the sign
+      // bit to one would be unsound: negating a negative or zero value does
+      // not give a negative result, and the operand's other known bits do
+      // not carry over unchanged to the negated value.
+      KBM[Expr] = KnownBits::computeForAddSub(
+          /*Add=*/false, /*NSW=*/false, /*NUW=*/false,
+          KnownBits::makeConstant(False), KB);
+      return;
+    }
+    case MCUnaryExpr::Opcode::Not: {
+      // Bitwise NOT is expressed as XOR with an all-ones constant.
+      KnownBits AllOnes(BitWidth);
+      AllOnes.setAllOnes();
+      KBM[Expr] = KB ^ AllOnes;
+      return;
+    }
+    case MCUnaryExpr::Opcode::Plus: {
+      // Unary plus is the identity, so the operand's known bits (including
+      // whatever is known about its sign) are passed through unchanged.
+      KBM[Expr] = KB;
+      return;
+    }
+    }
+  }
+  case MCExpr::ExprKind::Target: {
+    const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
+
+    switch (AGVK->getKind()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case AMDGPUMCExpr::VariantKind::AGVK_Or: {
+      // Seed with the first argument, then OR in every argument.
+      KnownBitsMapHelper(AGVK->getSubExpr(0), KBM, depth);
+      KnownBits KB = KBM[AGVK->getSubExpr(0)];
+      for (const MCExpr *Arg : AGVK->getArgs()) {
+        KnownBitsMapHelper(Arg, KBM, depth);
+        KB |= KBM[Arg];
+      }
+      KBM[Expr] = KB;
+      return;
+    }
+    case AMDGPUMCExpr::VariantKind::AGVK_Max: {
+      // Seed with the first argument, then fold every argument with umax.
+      KnownBitsMapHelper(AGVK->getSubExpr(0), KBM, depth);
+      KnownBits KB = KBM[AGVK->getSubExpr(0)];
+      for (const MCExpr *Arg : AGVK->getArgs()) {
+        KnownBitsMapHelper(Arg, KBM, depth);
+        KB = KnownBits::umax(KB, KBM[Arg]);
+      }
+      KBM[Expr] = KB;
+      return;
+    }
+    case AMDGPUMCExpr::VariantKind::AGVK_ExtraSGPRs:
+    case AMDGPUMCExpr::VariantKind::AGVK_TotalNumVGPRs:
+    case AMDGPUMCExpr::VariantKind::AGVK_AlignTo:
+    case AMDGPUMCExpr::VariantKind::AGVK_Occupancy: {
+      // These kinds are not decomposed: they are either a constant when
+      // they evaluate to an absolute value, or fully unknown.
+      int64_t Val;
+      if (AGVK->evaluateAsAbsolute(Val)) {
+        APInt APValue(BitWidth, Val, /*isSigned=*/false);
+        KBM[Expr] = KnownBits::makeConstant(APValue);
+        return;
+      }
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    }
+    }
+  }
+  }
+}
+
+static const MCExpr *TryFoldHelper(const MCExpr *Expr, KnownBitsMap &KBM,
+ MCContext &Ctx) {
+ if (!KBM.count(Expr))
+ return Expr;
+
+ auto valueCheckKnownBits = [](KnownBits &KB, unsigned Value) -> bool {
----------------
Pierre-vh wrote:
```suggestion
auto ValueCheckKnownBits = [](KnownBits &KB, unsigned Value) -> bool {
```
https://github.com/llvm/llvm-project/pull/95951
More information about the llvm-commits
mailing list