[llvm] [SandboxIR] Implement ConstantFP (PR #106648)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 30 11:06:04 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/106648

>From be73ad06b12dbfe779267c37c7734cccd8580e41 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 16 Aug 2024 13:16:23 -0700
Subject: [PATCH] [SandboxIR] Implement ConstantFP

This patch implements sandboxir::ConstantFP mirroring llvm::ConstantFP.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       |  94 ++++++++++-
 .../llvm/SandboxIR/SandboxIRValues.def        |   1 +
 llvm/include/llvm/SandboxIR/Type.h            |   3 +-
 llvm/lib/SandboxIR/SandboxIR.cpp              |  52 ++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 155 ++++++++++++++++++
 5 files changed, 303 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 0f7752eda6d66f..2ed7243fa612f4 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -113,6 +113,7 @@ namespace sandboxir {
 
 class BasicBlock;
 class ConstantInt;
+class ConstantFP;
 class Context;
 class Function;
 class Instruction;
@@ -597,6 +598,108 @@ class ConstantInt : public Constant {
 #endif
 };
 
+// TODO: This should inherit from ConstantData.
+class ConstantFP final : public Constant {
+  ConstantFP(llvm::ConstantFP *C, Context &Ctx)
+      : Constant(ClassID::ConstantFP, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// This returns a ConstantFP, or a vector containing a splat of a ConstantFP,
+  /// for the specified value in the specified type. This should only be used
+  /// for simple constant values like 2.0/1.0 etc, that are known-valid both as
+  /// host double and as the target format.
+  static Constant *get(Type *Ty, double V);
+
+  /// If Ty is a vector type, return a Constant with a splat of the given
+  /// value. Otherwise return a ConstantFP for the given value.
+  static Constant *get(Type *Ty, const APFloat &V);
+
+  /// Like get(Type *, double), but the value is given in textual form by
+  /// \p Str. Mirrors the StringRef overload of llvm::ConstantFP::get().
+  static Constant *get(Type *Ty, StringRef Str);
+
+  /// \Returns the ConstantFP holding \p V in \p Ctx. Unlike the overloads
+  /// above, the result is always a scalar ConstantFP.
+  static ConstantFP *get(const APFloat &V, Context &Ctx);
+
+  /// NaN factories mirroring llvm::ConstantFP: getNaN() returns a NaN,
+  /// getQNaN() a quiet NaN and getSNaN() a signaling NaN of type \p Ty. The
+  /// sign bit is set if \p Negative; \p Payload optionally sets payload bits.
+  static Constant *getNaN(Type *Ty, bool Negative = false,
+                          uint64_t Payload = 0);
+  static Constant *getQNaN(Type *Ty, bool Negative = false,
+                           APInt *Payload = nullptr);
+  static Constant *getSNaN(Type *Ty, bool Negative = false,
+                           APInt *Payload = nullptr);
+
+  /// \Returns a zero of type \p Ty, or -0.0 if \p Negative is set.
+  static Constant *getZero(Type *Ty, bool Negative = false);
+
+  /// \Returns -0.0 of type \p Ty.
+  static Constant *getNegativeZero(Type *Ty);
+
+  /// \Returns +infinity of type \p Ty, or -infinity if \p Negative is set.
+  static Constant *getInfinity(Type *Ty, bool Negative = false);
+
+  /// Return true if Ty is big enough to represent V.
+  static bool isValueValidForType(Type *Ty, const APFloat &V);
+
+  /// \Returns the floating point value of this constant.
+  const APFloat &getValueAPF() const {
+    return cast<llvm::ConstantFP>(Val)->getValueAPF();
+  }
+  /// Same as getValueAPF(); kept to mirror llvm::ConstantFP's API.
+  const APFloat &getValue() const {
+    return cast<llvm::ConstantFP>(Val)->getValue();
+  }
+
+  /// Return true if the value is positive or negative zero.
+  bool isZero() const { return cast<llvm::ConstantFP>(Val)->isZero(); }
+
+  /// Return true if the sign bit is set.
+  bool isNegative() const { return cast<llvm::ConstantFP>(Val)->isNegative(); }
+
+  /// Return true if the value is infinity.
+  bool isInfinity() const { return cast<llvm::ConstantFP>(Val)->isInfinity(); }
+
+  /// Return true if the value is a NaN.
+  bool isNaN() const { return cast<llvm::ConstantFP>(Val)->isNaN(); }
+
+  /// We don't rely on operator== working on double values, as it returns true
+  /// for things that are clearly not equal, like -0.0 and 0.0.
+  /// As such, this method can be used to do an exact bit-for-bit comparison of
+  /// two floating point values.  The version with a double operand is retained
+  /// because it's so convenient to write isExactlyValue(2.0), but please use
+  /// it only for simple constants.
+  bool isExactlyValue(const APFloat &V) const {
+    return cast<llvm::ConstantFP>(Val)->isExactlyValue(V);
+  }
+
+  bool isExactlyValue(double V) const {
+    return cast<llvm::ConstantFP>(Val)->isExactlyValue(V);
+  }
+
+  /// For isa/dyn_cast.
+  static bool classof(const sandboxir::Value *From) {
+    return From->getSubclassID() == ClassID::ConstantFP;
+  }
+
+  // TODO: Better name: getOperandNo(const Use&). Should be private.
+  unsigned getUseOperandNo(const Use &Use) const final {
+    llvm_unreachable("ConstantFP has no operands!");
+  }
+#ifndef NDEBUG
+  void verify() const override {
+    assert(isa<llvm::ConstantFP>(Val) && "Expected a ConstantFP!");
+  }
+  void dumpOS(raw_ostream &OS) const override {
+    dumpCommonPrefix(OS);
+    dumpCommonSuffix(OS);
+  }
+#endif
+};
+
 /// Iterator for `Instruction`s in a `BasicBlock.
 /// \Returns an sandboxir::Instruction & when derereferenced.
 class BBIterator {
@@ -3156,7 +3245,10 @@ class Context {
   Constant *getOrCreateConstant(llvm::Constant *LLVMC) {
     return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
   }
-  friend class ConstantInt; // For getOrCreateConstant().
+  // Friends for getOrCreateConstant().
+#define DEF_CONST(ID, CLASS) friend class CLASS;
+#include "llvm/SandboxIR/SandboxIRValues.def"
+
   /// Create a sandboxir::BasicBlock for an existing LLVM IR \p BB. This will
   /// also create all contents of the block.
   BasicBlock *createBasicBlock(llvm::BasicBlock *BB);
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index d29fc3b5e95871..2fc24ed71c4cf6 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -26,6 +26,7 @@ DEF_USER(User, User)
 DEF_VALUE(Block, BasicBlock)
 DEF_CONST(Constant, Constant)
 DEF_CONST(ConstantInt, ConstantInt)
+DEF_CONST(ConstantFP, ConstantFP)
 
 #ifndef DEF_INSTR
 #define DEF_INSTR(ID, OPCODE, CLASS)
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 4588cd2f738876..89e787f5f5d4b2 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -27,6 +27,7 @@ class PointerType;
 class VectorType;
 class FunctionType;
 #define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
+#define DEF_CONST(ID, CLASS) class CLASS;
 #include "llvm/SandboxIR/SandboxIRValues.def"
 
 /// Just like llvm::Type these are immutable, unique, never get freed and can
@@ -42,7 +43,7 @@ class Type {
   friend class ConstantInt;  // For LLVMTy.
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
-  // TODO: Friend DEF_CONST()
+#define DEF_CONST(ID, CLASS) friend class CLASS;
 #include "llvm/SandboxIR/SandboxIRValues.def"
   Context &Ctx;
 
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index bf224b73f3bad2..6bdc580f751d18 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2248,6 +2248,64 @@ ConstantInt *ConstantInt::get(Type *Ty, uint64_t V, bool IsSigned) {
   return cast<ConstantInt>(Ty->getContext().getOrCreateConstant(LLVMC));
 }
 
+// The ConstantFP factories forward to their llvm::ConstantFP counterparts and
+// wrap the result with Context::getOrCreateConstant(), so requesting the same
+// constant twice yields the same sandboxir object.
+
+Constant *ConstantFP::get(Type *Ty, double V) {
+  auto *LLVMC = llvm::ConstantFP::get(Ty->LLVMTy, V);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+Constant *ConstantFP::get(Type *Ty, const APFloat &V) {
+  auto *LLVMC = llvm::ConstantFP::get(Ty->LLVMTy, V);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+Constant *ConstantFP::get(Type *Ty, StringRef Str) {
+  auto *LLVMC = llvm::ConstantFP::get(Ty->LLVMTy, Str);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+ConstantFP *ConstantFP::get(const APFloat &V, Context &Ctx) {
+  auto *LLVMC = llvm::ConstantFP::get(Ctx.LLVMCtx, V);
+  return cast<ConstantFP>(Ctx.getOrCreateConstant(LLVMC));
+}
+
+Constant *ConstantFP::getNaN(Type *Ty, bool Negative, uint64_t Payload) {
+  auto *LLVMC = llvm::ConstantFP::getNaN(Ty->LLVMTy, Negative, Payload);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+Constant *ConstantFP::getQNaN(Type *Ty, bool Negative, APInt *Payload) {
+  auto *LLVMC = llvm::ConstantFP::getQNaN(Ty->LLVMTy, Negative, Payload);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+Constant *ConstantFP::getSNaN(Type *Ty, bool Negative, APInt *Payload) {
+  auto *LLVMC = llvm::ConstantFP::getSNaN(Ty->LLVMTy, Negative, Payload);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+Constant *ConstantFP::getZero(Type *Ty, bool Negative) {
+  auto *LLVMC = llvm::ConstantFP::getZero(Ty->LLVMTy, Negative);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+Constant *ConstantFP::getNegativeZero(Type *Ty) {
+  auto *LLVMC = llvm::ConstantFP::getNegativeZero(Ty->LLVMTy);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
+  auto *LLVMC = llvm::ConstantFP::getInfinity(Ty->LLVMTy, Negative);
+  return Ty->getContext().getOrCreateConstant(LLVMC);
+}
+
+bool ConstantFP::isValueValidForType(Type *Ty, const APFloat &V) {
+  return llvm::ConstantFP::isValueValidForType(Ty->LLVMTy, V);
+}
+
 FunctionType *Function::getFunctionType() const {
   return cast<FunctionType>(
       Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
@@ -2339,6 +2387,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
       It->second = std::unique_ptr<ConstantInt>(new ConstantInt(CI, *this));
       return It->second.get();
     }
+    if (auto *CF = dyn_cast<llvm::ConstantFP>(C)) {
+      It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
+      return It->second.get();
+    }
     if (auto *F = dyn_cast<llvm::Function>(LLVMV))
       It->second = std::unique_ptr<Function>(new Function(F, *this));
     else
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index c543846eb2686e..01fe21eb5cfa43 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -130,6 +130,161 @@ define void @foo(i32 %v0) {
   EXPECT_NE(FortyThree, FortyTwo);
 }
 
+TEST_F(SandboxIRTest, ConstantFP) {
+  parseIR(C, R"IR(
+define void @foo(float %v0, double %v1) {
+  %fadd0 = fadd float %v0, 42.0
+  %fadd1 = fadd double %v1, 43.0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = BB.begin();
+  auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *FortyTwo = cast<sandboxir::ConstantFP>(FAdd0->getOperand(1));
+  [[maybe_unused]] auto *FortyThree =
+      cast<sandboxir::ConstantFP>(FAdd1->getOperand(1));
+
+  auto *FloatTy = sandboxir::Type::getFloatTy(Ctx);
+  auto *DoubleTy = sandboxir::Type::getDoubleTy(Ctx);
+  auto *LLVMFloatTy = Type::getFloatTy(C);
+  auto *LLVMDoubleTy = Type::getDoubleTy(C);
+  // Check that creating an identical constant gives us the same object.
+  auto *NewFortyTwo = sandboxir::ConstantFP::get(FloatTy, 42.0);
+  EXPECT_EQ(NewFortyTwo, FortyTwo);
+  // Check get(Type, double).
+  auto *FortyFour =
+      cast<sandboxir::ConstantFP>(sandboxir::ConstantFP::get(FloatTy, 44.0));
+  auto *LLVMFortyFour =
+      cast<llvm::ConstantFP>(llvm::ConstantFP::get(LLVMFloatTy, 44.0));
+  EXPECT_NE(FortyFour, FortyTwo);
+  EXPECT_EQ(FortyFour, Ctx.getValue(LLVMFortyFour));
+  // Check get(Type, APFloat).
+  auto *FortyFive = cast<sandboxir::ConstantFP>(
+      sandboxir::ConstantFP::get(DoubleTy, APFloat(45.0)));
+  auto *LLVMFortyFive = cast<llvm::ConstantFP>(
+      llvm::ConstantFP::get(LLVMDoubleTy, APFloat(45.0)));
+  EXPECT_EQ(FortyFive, Ctx.getValue(LLVMFortyFive));
+  // Check get(Type, StringRef).
+  auto *FortySix = sandboxir::ConstantFP::get(FloatTy, "46.0");
+  EXPECT_EQ(FortySix, Ctx.getValue(llvm::ConstantFP::get(LLVMFloatTy, "46.0")));
+  // Check get(APFloat).
+  auto *FortySeven = sandboxir::ConstantFP::get(APFloat(47.0), Ctx);
+  EXPECT_EQ(FortySeven, Ctx.getValue(llvm::ConstantFP::get(C, APFloat(47.0))));
+  // Check getNaN().
+  {
+    auto *NaN = sandboxir::ConstantFP::getNaN(FloatTy);
+    EXPECT_EQ(NaN, Ctx.getValue(llvm::ConstantFP::getNaN(LLVMFloatTy)));
+  }
+  {
+    auto *NaN = sandboxir::ConstantFP::getNaN(FloatTy, /*Negative=*/true);
+    EXPECT_EQ(NaN, Ctx.getValue(llvm::ConstantFP::getNaN(LLVMFloatTy,
+                                                         /*Negative=*/true)));
+  }
+  {
+    auto *NaN = sandboxir::ConstantFP::getNaN(FloatTy, /*Negative=*/true,
+                                              /*Payload=*/1);
+    EXPECT_EQ(NaN, Ctx.getValue(llvm::ConstantFP::getNaN(
+                       LLVMFloatTy, /*Negative=*/true, /*Payload=*/1)));
+  }
+  // Check getQNaN().
+  {
+    auto *QNaN = sandboxir::ConstantFP::getQNaN(FloatTy);
+    EXPECT_EQ(QNaN, Ctx.getValue(llvm::ConstantFP::getQNaN(LLVMFloatTy)));
+  }
+  {
+    auto *QNaN = sandboxir::ConstantFP::getQNaN(FloatTy, /*Negative=*/true);
+    EXPECT_EQ(QNaN, Ctx.getValue(llvm::ConstantFP::getQNaN(LLVMFloatTy,
+                                                           /*Negative=*/true)));
+  }
+  {
+    APInt Payload(1, 1);
+    auto *QNaN =
+        sandboxir::ConstantFP::getQNaN(FloatTy, /*Negative=*/true, &Payload);
+    EXPECT_EQ(QNaN, Ctx.getValue(llvm::ConstantFP::getQNaN(
+                        LLVMFloatTy, /*Negative=*/true, &Payload)));
+  }
+  // Check getSNaN().
+  {
+    auto *SNaN = sandboxir::ConstantFP::getSNaN(FloatTy);
+    EXPECT_EQ(SNaN, Ctx.getValue(llvm::ConstantFP::getSNaN(LLVMFloatTy)));
+  }
+  {
+    auto *SNaN = sandboxir::ConstantFP::getSNaN(FloatTy, /*Negative=*/true);
+    EXPECT_EQ(SNaN, Ctx.getValue(llvm::ConstantFP::getSNaN(LLVMFloatTy,
+                                                           /*Negative=*/true)));
+  }
+  {
+    APInt Payload(1, 1);
+    auto *SNaN =
+        sandboxir::ConstantFP::getSNaN(FloatTy, /*Negative=*/true, &Payload);
+    EXPECT_EQ(SNaN, Ctx.getValue(llvm::ConstantFP::getSNaN(
+                        LLVMFloatTy, /*Negative=*/true, &Payload)));
+  }
+
+  // Check getZero().
+  {
+    auto *Zero = sandboxir::ConstantFP::getZero(FloatTy);
+    EXPECT_EQ(Zero, Ctx.getValue(llvm::ConstantFP::getZero(LLVMFloatTy)));
+  }
+  {
+    auto *Zero = sandboxir::ConstantFP::getZero(FloatTy, /*Negative=*/true);
+    EXPECT_EQ(Zero, Ctx.getValue(llvm::ConstantFP::getZero(LLVMFloatTy,
+                                                           /*Negative=*/true)));
+  }
+
+  // Check getNegativeZero().
+  auto *NegZero = cast<sandboxir::ConstantFP>(
+      sandboxir::ConstantFP::getNegativeZero(FloatTy));
+  EXPECT_EQ(NegZero,
+            Ctx.getValue(llvm::ConstantFP::getNegativeZero(LLVMFloatTy)));
+
+  // Check getInfinity().
+  {
+    auto *Inf = sandboxir::ConstantFP::getInfinity(FloatTy);
+    EXPECT_EQ(Inf, Ctx.getValue(llvm::ConstantFP::getInfinity(LLVMFloatTy)));
+  }
+  {
+    auto *Inf = sandboxir::ConstantFP::getInfinity(FloatTy, /*Negative=*/true);
+    EXPECT_EQ(Inf, Ctx.getValue(llvm::ConstantFP::getInfinity(
+                       LLVMFloatTy, /*Negative=*/true)));
+  }
+
+  // Check isValueValidForType().
+  APFloat V(1.1);
+  EXPECT_EQ(sandboxir::ConstantFP::isValueValidForType(FloatTy, V),
+            llvm::ConstantFP::isValueValidForType(LLVMFloatTy, V));
+  // Check getValueAPF().
+  EXPECT_EQ(FortyFour->getValueAPF(), LLVMFortyFour->getValueAPF());
+  // Check getValue().
+  EXPECT_EQ(FortyFour->getValue(), LLVMFortyFour->getValue());
+  // Check isZero().
+  EXPECT_EQ(FortyFour->isZero(), LLVMFortyFour->isZero());
+  EXPECT_TRUE(
+      cast<sandboxir::ConstantFP>(sandboxir::ConstantFP::getZero(FloatTy))
+          ->isZero());
+  EXPECT_TRUE(NegZero->isZero());
+  // Check isNegative().
+  EXPECT_TRUE(NegZero->isNegative());
+  // Check isInfinity().
+  EXPECT_TRUE(
+      cast<sandboxir::ConstantFP>(sandboxir::ConstantFP::getInfinity(FloatTy))
+          ->isInfinity());
+  // Check isNaN().
+  EXPECT_TRUE(
+      cast<sandboxir::ConstantFP>(sandboxir::ConstantFP::getNaN(FloatTy))
+          ->isNaN());
+  // Check isExactlyValue(APFloat).
+  EXPECT_TRUE(NegZero->isExactlyValue(NegZero->getValueAPF()));
+  // Check isExactlyValue(double).
+  EXPECT_TRUE(NegZero->isExactlyValue(-0.0));
+}
+
 TEST_F(SandboxIRTest, Use) {
   parseIR(C, R"IR(
 define i32 @foo(i32 %v0, i32 %v1) {



More information about the llvm-commits mailing list