[llvm] [AMDGPU] MCExpr printing helper with KnownBits support (PR #95951)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 12 12:19:57 PDT 2024


================
@@ -303,3 +305,360 @@ const AMDGPUMCExpr *AMDGPUMCExpr::createOccupancy(unsigned InitOcc,
                  CreateExpr(InitOcc), NumSGPRs, NumVGPRs},
                 Ctx);
 }
+
+/// Widens the optional outcome of a comparison into a 64-bit KnownBits value.
+///
+/// A decided outcome yields the constant 1 (true) or 0 (false). An undecided
+/// outcome yields a 1-bit fully-unknown value zero-extended to 64 bits.
+static KnownBits fromOptionalToKnownBits(std::optional<bool> CompareResult) {
+  const unsigned BitWidth = 64;
+
+  // Undecided comparison: only a single boolean bit is in play, so start from
+  // a 1-bit unknown value and widen it.
+  if (!CompareResult.has_value())
+    return KnownBits(/*BitWidth=*/1).zext(BitWidth);
+
+  // Decided comparison: materialise the boolean as a 64-bit constant.
+  return KnownBits::makeConstant(
+      APInt(BitWidth, *CompareResult ? 1 : 0, /*isSigned=*/false));
+}
+
+/// Associates each visited expression node with the KnownBits computed for it.
+using KnownBitsMap = std::unordered_map<const MCExpr *, KnownBits>;
+
+/// Recursively computes 64-bit KnownBits for \p Expr and the sub-expressions
+/// it visits, storing one entry per visited node in \p KBM.
+///
+/// An expression that evaluates to an absolute value is recorded as a constant
+/// without visiting its operands. Otherwise the node is analysed according to
+/// its kind (binary, constant, symbol reference, unary, or AMDGPU target
+/// expression); opcodes and variants that are not modelled are recorded as
+/// fully unknown.
+///
+/// \param Expr  Expression to analyse.
+/// \param KBM   Map receiving the result for \p Expr and for every
+///              sub-expression visited on the way.
+/// \param depth Remaining recursion budget. When it is zero and \p Expr is not
+///              an absolute value, \p Expr is recorded as fully unknown and
+///              its operands are not visited.
+void KnownBitsMapHelper(const MCExpr *Expr, KnownBitsMap &KBM, unsigned depth) {
+  const unsigned BitWidth = 64;
+  const APInt False(BitWidth, 0, /*isSigned=*/false);
+
+  // Fast path: anything the MC layer can already resolve is a plain constant.
+  int64_t Val;
+  if (Expr->evaluateAsAbsolute(Val)) {
+    APInt APValue(BitWidth, Val, /*isSigned=*/true);
+    KBM[Expr] = KnownBits::makeConstant(APValue);
+    return;
+  }
+
+  // Recursion budget exhausted: give up on this subtree.
+  if (depth == 0) {
+    KBM[Expr] = KnownBits(BitWidth);
+    return;
+  }
+
+  depth--;
+
+  switch (Expr->getKind()) {
+  case MCExpr::ExprKind::Binary: {
+    const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
+    const MCExpr *LHS = BExpr->getLHS();
+    const MCExpr *RHS = BExpr->getRHS();
+
+    // Analyse both operands first; their results are combined below.
+    KnownBitsMapHelper(LHS, KBM, depth);
+    KnownBitsMapHelper(RHS, KBM, depth);
+    KnownBits LHSKnown = KBM[LHS];
+    KnownBits RHSKnown = KBM[RHS];
+    std::optional<bool> CompareRes;
+
+    // Every case of this inner switch returns, so control never reaches the
+    // following outer case label.
+    switch (BExpr->getOpcode()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case MCBinaryExpr::Opcode::Add:
+      KBM[Expr] =
+          KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
+                                      /*NUW=*/false, LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::And:
+      KBM[Expr] = LHSKnown & RHSKnown;
+      return;
+    case MCBinaryExpr::Opcode::Div:
+      KBM[Expr] = KnownBits::sdiv(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::EQ:
+      CompareRes = KnownBits::eq(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::NE:
+      CompareRes = KnownBits::ne(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::GT:
+      CompareRes = KnownBits::sgt(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::GTE:
+      CompareRes = KnownBits::sge(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::LAnd: {
+      // The logical AND is decided only when both operands are individually
+      // known to be zero or non-zero.
+      std::optional<bool> LHSBool =
+          KnownBits::ne(LHSKnown, KnownBits::makeConstant(False));
+      std::optional<bool> RHSBool =
+          KnownBits::ne(RHSKnown, KnownBits::makeConstant(False));
+      if (LHSBool && RHSBool) {
+        CompareRes = *LHSBool && *RHSBool;
+      }
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    }
+    case MCBinaryExpr::Opcode::LOr: {
+      // a || b is true exactly when (a | b) is non-zero.
+      KnownBits Bits = LHSKnown | RHSKnown;
+      CompareRes = KnownBits::ne(Bits, KnownBits::makeConstant(False));
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    }
+    case MCBinaryExpr::Opcode::LT:
+      CompareRes = KnownBits::slt(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::LTE:
+      CompareRes = KnownBits::sle(LHSKnown, RHSKnown);
+      KBM[Expr] = fromOptionalToKnownBits(CompareRes);
+      return;
+    case MCBinaryExpr::Opcode::Mod:
+      KBM[Expr] = KnownBits::srem(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Mul:
+      KBM[Expr] = KnownBits::mul(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Or:
+      KBM[Expr] = LHSKnown | RHSKnown;
+      return;
+    case MCBinaryExpr::Opcode::Shl:
+      KBM[Expr] = KnownBits::shl(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::AShr:
+      KBM[Expr] = KnownBits::ashr(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::LShr:
+      KBM[Expr] = KnownBits::lshr(LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Sub:
+      KBM[Expr] =
+          KnownBits::computeForAddSub(/*Add=*/false, /*NSW=*/false,
+                                      /*NUW=*/false, LHSKnown, RHSKnown);
+      return;
+    case MCBinaryExpr::Opcode::Xor:
+      KBM[Expr] = LHSKnown ^ RHSKnown;
+      return;
+    }
+  }
+  case MCExpr::ExprKind::Constant: {
+    const MCConstantExpr *CE = cast<MCConstantExpr>(Expr);
+    APInt APValue(BitWidth, CE->getValue(), /*isSigned=*/true);
+    KBM[Expr] = KnownBits::makeConstant(APValue);
+    return;
+  }
+  case MCExpr::ExprKind::SymbolRef: {
+    const MCSymbolRefExpr *RExpr = cast<MCSymbolRefExpr>(Expr);
+    const MCSymbol &Sym = RExpr->getSymbol();
+    // A symbol without an assigned expression carries no information.
+    if (!Sym.isVariable()) {
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    }
+
+    // Variable value retrieval is not for actual use but only for knownbits
+    // analysis, hence SetUsed=false.
+    KnownBitsMapHelper(Sym.getVariableValue(/*SetUsed=*/false), KBM, depth);
+    KBM[Expr] = KBM[Sym.getVariableValue(/*SetUsed=*/false)];
+    return;
+  }
+  case MCExpr::ExprKind::Unary: {
+    const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
+    KnownBitsMapHelper(UExpr->getSubExpr(), KBM, depth);
+    KnownBits KB = KBM[UExpr->getSubExpr()];
+
+    switch (UExpr->getOpcode()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case MCUnaryExpr::Opcode::Minus: {
+      // Negation is 0 - x. KnownBits::makeNegative() would merely force the
+      // sign bit to one, which is wrong whenever x can be zero or negative.
+      KBM[Expr] = KnownBits::computeForAddSub(
+          /*Add=*/false, /*NSW=*/false, /*NUW=*/false,
+          KnownBits::makeConstant(False), KB);
+      return;
+    }
+    case MCUnaryExpr::Opcode::Not: {
+      // Bitwise NOT is XOR with an all-ones constant.
+      KnownBits AllOnes(BitWidth);
+      AllOnes.setAllOnes();
+      KBM[Expr] = KB ^ AllOnes;
+      return;
+    }
+    case MCUnaryExpr::Opcode::Plus: {
+      // Unary plus is the identity, so the operand's bits carry over
+      // unchanged; it says nothing about the sign of the value.
+      KBM[Expr] = KB;
+      return;
+    }
+    }
+  }
+  case MCExpr::ExprKind::Target: {
+    const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
+
+    switch (AGVK->getKind()) {
+    default:
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    case AMDGPUMCExpr::VariantKind::AGVK_Or: {
+      // Seed with operand 0, then OR in the known bits of every argument.
+      KnownBitsMapHelper(AGVK->getSubExpr(0), KBM, depth);
+      KnownBits KB = KBM[AGVK->getSubExpr(0)];
+      for (const MCExpr *Arg : AGVK->getArgs()) {
+        KnownBitsMapHelper(Arg, KBM, depth);
+        KB |= KBM[Arg];
+      }
+      KBM[Expr] = KB;
+      return;
+    }
+    case AMDGPUMCExpr::VariantKind::AGVK_Max: {
+      // Seed with operand 0, then fold an unsigned max over every argument.
+      KnownBitsMapHelper(AGVK->getSubExpr(0), KBM, depth);
+      KnownBits KB = KBM[AGVK->getSubExpr(0)];
+      for (const MCExpr *Arg : AGVK->getArgs()) {
+        KnownBitsMapHelper(Arg, KBM, depth);
+        KB = KnownBits::umax(KB, KBM[Arg]);
+      }
+      KBM[Expr] = KB;
+      return;
+    }
+    case AMDGPUMCExpr::VariantKind::AGVK_ExtraSGPRs:
+    case AMDGPUMCExpr::VariantKind::AGVK_TotalNumVGPRs:
+    case AMDGPUMCExpr::VariantKind::AGVK_AlignTo:
+    case AMDGPUMCExpr::VariantKind::AGVK_Occupancy: {
+      // These variants are not modelled bit-by-bit: constant if resolvable,
+      // otherwise fully unknown.
+      // NOTE(review): this repeats the evaluateAsAbsolute() attempt made on
+      // the same expression at the top of the function, so it presumably
+      // cannot succeed here - confirm and simplify.
+      int64_t Val;
+      if (AGVK->evaluateAsAbsolute(Val)) {
+        APInt APValue(BitWidth, Val, /*isSigned=*/false);
+        KBM[Expr] = KnownBits::makeConstant(APValue);
+        return;
+      }
+      KBM[Expr] = KnownBits(BitWidth);
+      return;
+    }
+    }
+  }
+  }
+}
+
+static const MCExpr *TryFoldHelper(const MCExpr *Expr, KnownBitsMap &KBM,
+                                   MCContext &Ctx) {
+  if (!KBM.count(Expr))
+    return Expr;
+
+  auto valueCheckKnownBits = [](KnownBits &KB, unsigned Value) -> bool {
+    if (!KB.isConstant())
+      return false;
+
+    return Value == KB.getConstant();
+  };
+
+  if (Expr->getKind() == MCExpr::ExprKind::Constant)
+    return Expr;
+
+  // Resolving unary operations to constants may make the value more ambiguous.
+  // For example, `~62` becomes `-63`; however, to me it's more ambiguous if a
+  // bit mask value is represented through a negative number.
+  if (Expr->getKind() != MCExpr::ExprKind::Unary) {
+    if (KBM[Expr].isConstant()) {
+      APInt ConstVal = KBM[Expr].getConstant();
+      return MCConstantExpr::create(ConstVal.getSExtValue(), Ctx);
+    }
+
+    int64_t EvalValue;
+    if (Expr->evaluateAsAbsolute(EvalValue)) {
+      return MCConstantExpr::create(EvalValue, Ctx);
+    }
+  }
+
+  switch (Expr->getKind()) {
+  default:
+    return Expr;
+  case MCExpr::ExprKind::Binary: {
+    const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
+    const MCExpr *LHS = BExpr->getLHS();
+    const MCExpr *RHS = BExpr->getRHS();
+
+    switch (BExpr->getOpcode()) {
+    default:
+      return Expr;
+    case MCBinaryExpr::Opcode::Sub: {
+      if (valueCheckKnownBits(KBM[RHS], 0)) {
+        return TryFoldHelper(LHS, KBM, Ctx);
+      }
+      break;
+    }
+    case MCBinaryExpr::Opcode::Add:
+    case MCBinaryExpr::Opcode::Or: {
+      if (valueCheckKnownBits(KBM[LHS], 0)) {
+        return TryFoldHelper(RHS, KBM, Ctx);
+      }
+      if (valueCheckKnownBits(KBM[RHS], 0)) {
+        return TryFoldHelper(LHS, KBM, Ctx);
+      }
+      break;
+    }
+    case MCBinaryExpr::Opcode::Mul: {
+      if (valueCheckKnownBits(KBM[LHS], 1)) {
+        return TryFoldHelper(RHS, KBM, Ctx);
+      }
+      if (valueCheckKnownBits(KBM[RHS], 1)) {
+        return TryFoldHelper(LHS, KBM, Ctx);
+      }
+      break;
+    }
+    case MCBinaryExpr::Opcode::Shl:
+    case MCBinaryExpr::Opcode::AShr:
+    case MCBinaryExpr::Opcode::LShr: {
+      if (valueCheckKnownBits(KBM[RHS], 0)) {
----------------
arsenm wrote:

No braces around a single-line body.

https://github.com/llvm/llvm-project/pull/95951


More information about the llvm-commits mailing list