[llvm] a1e39f3 - [RISCV] Merge getLoadFP*Imm into a single function.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 14 13:11:20 PDT 2023


Author: Craig Topper
Date: 2023-03-14T13:11:11-07:00
New Revision: a1e39f35c50a5292f464cf59aa63b8ed0107dfce

URL: https://github.com/llvm/llvm-project/commit/a1e39f35c50a5292f464cf59aa63b8ed0107dfce
DIFF: https://github.com/llvm/llvm-project/commit/a1e39f35c50a5292f464cf59aa63b8ed0107dfce.diff

LOG: [RISCV] Merge getLoadFP*Imm into a single function.

We currently have 3 functions and 3 lookup tables. Having one function
and one table per type was the most expedient and obvious way to fix
several bugs.

This patch uses a single function and a single lookup
table. It uses APFloat::convert to convert from half or double
to single precision. If the conversion doesn't have any errors or
lose any information, we use the f32 table to finish the lookup.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D145897

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
    llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
    llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
index 783c5eb7630b7..e8edc96951229 100644
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -496,7 +496,7 @@ struct RISCVOperand final : public MCParsedAsmOperand {
       return isUImm5();
     if (Kind != KindTy::FPImmediate)
       return false;
-    int Idx = RISCVLoadFPImm::getLoadFP64Imm(
+    int Idx = RISCVLoadFPImm::getLoadFPImm(
         APFloat(APFloat::IEEEdouble(), APInt(64, getFPConst())));
     // Don't allow decimal version of the minimum value. It is a different value
     // for each supported data type.
@@ -985,7 +985,7 @@ struct RISCVOperand final : public MCParsedAsmOperand {
       return;
     }
 
-    int Imm = RISCVLoadFPImm::getLoadFP64Imm(
+    int Imm = RISCVLoadFPImm::getLoadFPImm(
         APFloat(APFloat::IEEEdouble(), APInt(64, getFPConst())));
     Inst.addOperand(MCOperand::createImm(Imm));
   }

diff  --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
index 3f9003c05e0f5..98c8e883e5960 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
@@ -214,87 +214,37 @@ bool RISCVRVC::uncompress(MCInst &OutInst, const MCInst &MI,
   return uncompressInst(OutInst, MI, STI);
 }
 
-// Lookup table for fli.h for entries 1-31. Entry 0(-1.0) is handled separately.
-// NOTE: The exponent for entry 1 is larger than entry 2 and 3 because they
-// are denormals.
-static constexpr std::pair<uint8_t, uint8_t> LoadFP16ImmArr[] = {
-    {0b00001, 0b00}, {0b00000, 0b01}, {0b00000, 0b10}, {0b00111, 0b00},
-    {0b01000, 0b00}, {0b01011, 0b00}, {0b01100, 0b00}, {0b01101, 0b00},
-    {0b01101, 0b01}, {0b01101, 0b10}, {0b01101, 0b11}, {0b01110, 0b00},
-    {0b01110, 0b01}, {0b01110, 0b10}, {0b01110, 0b11}, {0b01111, 0b00},
-    {0b01111, 0b01}, {0b01111, 0b10}, {0b01111, 0b11}, {0b10000, 0b00},
-    {0b10000, 0b01}, {0b10000, 0b10}, {0b10001, 0b00}, {0b10010, 0b00},
-    {0b10011, 0b00}, {0b10110, 0b00}, {0b10111, 0b00}, {0b11110, 0b00},
-    {0b11111, 0b00}, {0b11111, 0b00}, {0b11111, 0b10},
-};
-
-// Lookup table for fli.s for entries 1-31.
+// Lookup table for fli.s for entries 2-31.
 static constexpr std::pair<uint8_t, uint8_t> LoadFP32ImmArr[] = {
-    {0b00000001, 0b00}, {0b01101111, 0b00}, {0b01110000, 0b00},
-    {0b01110111, 0b00}, {0b01111000, 0b00}, {0b01111011, 0b00},
-    {0b01111100, 0b00}, {0b01111101, 0b00}, {0b01111101, 0b01},
-    {0b01111101, 0b10}, {0b01111101, 0b11}, {0b01111110, 0b00},
-    {0b01111110, 0b01}, {0b01111110, 0b10}, {0b01111110, 0b11},
-    {0b01111111, 0b00}, {0b01111111, 0b01}, {0b01111111, 0b10},
-    {0b01111111, 0b11}, {0b10000000, 0b00}, {0b10000000, 0b01},
-    {0b10000000, 0b10}, {0b10000001, 0b00}, {0b10000010, 0b00},
-    {0b10000011, 0b00}, {0b10000110, 0b00}, {0b10000111, 0b00},
-    {0b10001110, 0b00}, {0b10001111, 0b00}, {0b11111111, 0b00},
-    {0b11111111, 0b10},
-};
-
-// Lookup table for fli.d for entries 1-31.
-static constexpr std::pair<uint16_t, uint8_t> LoadFP64ImmArr[] = {
-    {0b00000000001, 0b00}, {0b01111101111, 0b00}, {0b01111110000, 0b00},
-    {0b01111110111, 0b00}, {0b01111111000, 0b00}, {0b01111111011, 0b00},
-    {0b01111111100, 0b00}, {0b01111111101, 0b00}, {0b01111111101, 0b01},
-    {0b01111111101, 0b10}, {0b01111111101, 0b11}, {0b01111111110, 0b00},
-    {0b01111111110, 0b01}, {0b01111111110, 0b10}, {0b01111111110, 0b11},
-    {0b01111111111, 0b00}, {0b01111111111, 0b01}, {0b01111111111, 0b10},
-    {0b01111111111, 0b11}, {0b10000000000, 0b00}, {0b10000000000, 0b01},
-    {0b10000000000, 0b10}, {0b10000000001, 0b00}, {0b10000000010, 0b00},
-    {0b10000000011, 0b00}, {0b10000000110, 0b00}, {0b10000000111, 0b00},
-    {0b10000001110, 0b00}, {0b10000001111, 0b00}, {0b11111111111, 0b00},
-    {0b11111111111, 0b10},
+    {0b01101111, 0b00}, {0b01110000, 0b00}, {0b01110111, 0b00},
+    {0b01111000, 0b00}, {0b01111011, 0b00}, {0b01111100, 0b00},
+    {0b01111101, 0b00}, {0b01111101, 0b01}, {0b01111101, 0b10},
+    {0b01111101, 0b11}, {0b01111110, 0b00}, {0b01111110, 0b01},
+    {0b01111110, 0b10}, {0b01111110, 0b11}, {0b01111111, 0b00},
+    {0b01111111, 0b01}, {0b01111111, 0b10}, {0b01111111, 0b11},
+    {0b10000000, 0b00}, {0b10000000, 0b01}, {0b10000000, 0b10},
+    {0b10000001, 0b00}, {0b10000010, 0b00}, {0b10000011, 0b00},
+    {0b10000110, 0b00}, {0b10000111, 0b00}, {0b10001110, 0b00},
+    {0b10001111, 0b00}, {0b11111111, 0b00}, {0b11111111, 0b10},
 };
 
-int RISCVLoadFPImm::getLoadFP16Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEhalf());
-
-  APInt Imm = FPImm.bitcastToAPInt();
-
-  if (Imm.extractBitsAsZExtValue(8, 0) != 0)
+int RISCVLoadFPImm::getLoadFPImm(APFloat FPImm) {
+  assert((&FPImm.getSemantics() == &APFloat::IEEEsingle() ||
+          &FPImm.getSemantics() == &APFloat::IEEEdouble() ||
+          &FPImm.getSemantics() == &APFloat::IEEEhalf()) &&
+         "Unexpected semantics");
+
+  // Handle the minimum normalized value which is different for each type.
+  if (FPImm.isSmallestNormalized())
+    return 1;
+
+  // Convert to single precision to use its lookup table.
+  bool LosesInfo;
+  APFloat::opStatus Status = FPImm.convert(
+      APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &LosesInfo);
+  if (Status != APFloat::opOK || LosesInfo)
     return -1;
 
-  bool Sign = Imm.extractBitsAsZExtValue(1, 15);
-  uint8_t Mantissa = Imm.extractBitsAsZExtValue(2, 8);
-  uint8_t Exp = Imm.extractBitsAsZExtValue(5, 10);
-
-  // The array isn't sorted so we must use std::find unlike fp32 and fp64.
-  auto EMI = llvm::find(LoadFP16ImmArr, std::make_pair(Exp, Mantissa));
-  if (EMI == std::end(LoadFP16ImmArr))
-    return -1;
-
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP16ImmArr), EMI) + 1;
-
-  // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
-  if (Sign) {
-    if (Entry == 16)
-      return 0;
-    return false;
-  }
-
-  // Entry 29 and 30 are both infinity, but 30 is the real infinity.
-  if (Entry == 29)
-    ++Entry;
-
-  return Entry;
-}
-
-int RISCVLoadFPImm::getLoadFP32Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEsingle());
-
   APInt Imm = FPImm.bitcastToAPInt();
 
   if (Imm.extractBitsAsZExtValue(21, 0) != 0)
@@ -309,38 +259,8 @@ int RISCVLoadFPImm::getLoadFP32Imm(const APFloat &FPImm) {
       EMI->second != Mantissa)
     return -1;
 
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP32ImmArr), EMI) + 1;
-
-  // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
-  if (Sign) {
-    if (Entry == 16)
-      return 0;
-    return false;
-  }
-
-  return Entry;
-}
-
-int RISCVLoadFPImm::getLoadFP64Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEdouble());
-
-  APInt Imm = FPImm.bitcastToAPInt();
-
-  if (Imm.extractBitsAsZExtValue(50, 0) != 0)
-    return -1;
-
-  bool Sign = Imm.extractBitsAsZExtValue(1, 63);
-  uint8_t Mantissa = Imm.extractBitsAsZExtValue(2, 50);
-  uint16_t Exp = Imm.extractBitsAsZExtValue(11, 52);
-
-  auto EMI = llvm::lower_bound(LoadFP64ImmArr, std::make_pair(Exp, Mantissa));
-  if (EMI == std::end(LoadFP64ImmArr) || EMI->first != Exp ||
-      EMI->second != Mantissa)
-    return -1;
-
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP64ImmArr), EMI) + 1;
+  // Table doesn't have entry 0 or 1.
+  int Entry = std::distance(std::begin(LoadFP32ImmArr), EMI) + 2;
 
   // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
   if (Sign) {
@@ -362,8 +282,8 @@ float RISCVLoadFPImm::getFPImm(unsigned Imm) {
     Imm = 16;
   }
 
-  uint32_t Exp = LoadFP32ImmArr[Imm - 1].first;
-  uint32_t Mantissa = LoadFP32ImmArr[Imm - 1].second;
+  uint32_t Exp = LoadFP32ImmArr[Imm - 2].first;
+  uint32_t Mantissa = LoadFP32ImmArr[Imm - 2].second;
 
   uint32_t I = Sign << 31 | Exp << 23 | Mantissa << 21;
   return bit_cast<float>(I);

diff  --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index cdb972b463f30..70fdc0e4ff120 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -349,20 +349,10 @@ inline static bool isValidRoundingMode(unsigned Mode) {
 namespace RISCVLoadFPImm {
 float getFPImm(unsigned Imm);
 
-/// getLoadFP32Imm - Return a 5-bit binary encoding of the 32-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP32Imm(const APFloat &FPImm);
-
-/// getLoadFP64Imm - Return a 5-bit binary encoding of the 64-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP64Imm(const APFloat &FPImm);
-
-/// getLoadFP16Imm - Return a 5-bit binary encoding of the 16-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP16Imm(const APFloat &FPImm);
+/// getLoadFPImm - Return a 5-bit binary encoding of the floating-point
+/// immediate value. If the value cannot be represented as a 5-bit binary
+/// encoding, then return -1.
+int getLoadFPImm(APFloat FPImm);
 } // namespace RISCVLoadFPImm
 
 namespace RISCVSysReg {

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 8f68dab55284c..c3c63ff45e9b9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1540,21 +1540,20 @@ bool RISCVTargetLowering::isOffsetFoldingLegal(
 }
 
 bool RISCVTargetLowering::isLegalZfaFPImm(const APFloat &Imm, EVT VT) const {
-  if (!Subtarget.hasStdExtZfa() || !VT.isSimple())
+  if (!Subtarget.hasStdExtZfa())
     return false;
 
-  switch (VT.getSimpleVT().SimpleTy) {
-  default:
-    return false;
-  case MVT::f16:
-    return (Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZvfh()) &&
-           RISCVLoadFPImm::getLoadFP16Imm(Imm) != -1;
-  case MVT::f32:
-    return RISCVLoadFPImm::getLoadFP32Imm(Imm) != -1;
-  case MVT::f64:
+  bool IsSupportedVT = false;
+  if (VT == MVT::f16) {
+    IsSupportedVT = Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZvfh();
+  } else if (VT == MVT::f32) {
+    IsSupportedVT = true;
+  } else if (VT == MVT::f64) {
     assert(Subtarget.hasStdExtD() && "Expect D extension");
-    return RISCVLoadFPImm::getLoadFP64Imm(Imm) != -1;
+    IsSupportedVT = true;
   }
+
+  return IsSupportedVT && RISCVLoadFPImm::getLoadFPImm(Imm) != -1;
 }
 
 bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td
index 96982a435af90..bac642218fe2b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td
@@ -179,20 +179,13 @@ def : InstAlias<"fgeq.h $rd, $rs, $rt",
 // Codegen patterns
 //===----------------------------------------------------------------------===//
 
-def fp32imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP32Imm(N->getValueAPF()),
+def fpimm_to_loadfpimm : SDNodeXForm<fpimm, [{
+  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFPImm(N->getValueAPF()),
                                    SDLoc(N), Subtarget->getXLenVT());}]>;
 
-def fp64imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP64Imm(N->getValueAPF()),
-                                   SDLoc(N), Subtarget->getXLenVT());}]>;
-
-def fp16imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP16Imm(N->getValueAPF()),
-                                   SDLoc(N), Subtarget->getXLenVT());}]>;
 
 let Predicates = [HasStdExtZfa] in {
-def : Pat<(f32 fpimm:$imm), (FLI_S (fp32imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f32 fpimm:$imm), (FLI_S (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_S, FPR32>;
 def: PatFprFpr<fmaximum, FMAXM_S, FPR32>;
@@ -216,7 +209,7 @@ def: PatSetCC<FPR32, strict_fsetcc, SETOLE, FLEQ_S>;
 } // Predicates = [HasStdExtZfa]
 
 let Predicates = [HasStdExtZfa, HasStdExtD] in {
-def : Pat<(f64 fpimm:$imm), (FLI_D (fp64imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f64 fpimm:$imm), (FLI_D (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_D, FPR64>;
 def: PatFprFpr<fmaximum, FMAXM_D, FPR64>;
@@ -246,7 +239,7 @@ def : Pat<(RISCVBuildPairF64 GPR:$rs1, GPR:$rs2),
 }
 
 let Predicates = [HasStdExtZfa, HasStdExtZfh] in {
-def : Pat<(f16 fpimm:$imm), (FLI_H (fp16imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f16 fpimm:$imm), (FLI_H (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_H, FPR16>;
 def: PatFprFpr<fmaximum, FMAXM_H, FPR16>;


        


More information about the llvm-commits mailing list