[llvm] [RISCV] Update implies for subtarget feature. (PR #75824)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 18 08:31:44 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Yeting Kuo (yetingk)

<details>
<summary>Changes</summary>

PR #<!-- -->75576 and #<!-- -->75735 updated some extension implications in llvm/lib/Support/RISCVISAInfo.cpp, but both of them missed the corresponding subtarget feature definitions.
This patch still preserves the predicates HasStdExtZfhOrZfhmin and HasStdExtZhinxOrZhinxmin, since they make error messages more readable. (Users might not know that zfh implies zfhmin.)

---
Full diff: https://github.com/llvm/llvm-project/pull/75824.diff


5 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVFeatures.td (+11-11) 
- (modified) llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp (+1-2) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+16-16) 
- (modified) llvm/lib/Target/RISCV/RISCVSubtarget.h (+4-10) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+1-1) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td
index 294927aecb94b8..60bb3ad9531114 100644
--- a/llvm/lib/Target/RISCV/RISCVFeatures.td
+++ b/llvm/lib/Target/RISCV/RISCVFeatures.td
@@ -107,15 +107,15 @@ def HasStdExtZfhmin : Predicate<"Subtarget->hasStdExtZfhmin()">,
 def FeatureStdExtZfh
     : SubtargetFeature<"zfh", "HasStdExtZfh", "true",
                        "'Zfh' (Half-Precision Floating-Point)",
-                       [FeatureStdExtF]>;
+                       [FeatureStdExtZfhmin]>;
 def HasStdExtZfh : Predicate<"Subtarget->hasStdExtZfh()">,
                              AssemblerPredicate<(all_of FeatureStdExtZfh),
                              "'Zfh' (Half-Precision Floating-Point)">;
 def NoStdExtZfh : Predicate<"!Subtarget->hasStdExtZfh()">;
 
 def HasStdExtZfhOrZfhmin
-    : Predicate<"Subtarget->hasStdExtZfhOrZfhmin()">,
-                AssemblerPredicate<(any_of FeatureStdExtZfh, FeatureStdExtZfhmin),
+    : Predicate<"Subtarget->hasStdExtZfhmin()">,
+                AssemblerPredicate<(all_of FeatureStdExtZfhmin),
                                    "'Zfh' (Half-Precision Floating-Point) or "
                                    "'Zfhmin' (Half-Precision Floating-Point Minimal)">;
 
@@ -146,15 +146,15 @@ def HasStdExtZhinxmin : Predicate<"Subtarget->hasStdExtZhinxmin()">,
 def FeatureStdExtZhinx
     : SubtargetFeature<"zhinx", "HasStdExtZhinx", "true",
                        "'Zhinx' (Half Float in Integer)",
-                       [FeatureStdExtZfinx]>;
+                       [FeatureStdExtZhinxmin]>;
 def HasStdExtZhinx : Predicate<"Subtarget->hasStdExtZhinx()">,
                                AssemblerPredicate<(all_of FeatureStdExtZhinx),
                                "'Zhinx' (Half Float in Integer)">;
 def NoStdExtZhinx : Predicate<"!Subtarget->hasStdExtZhinx()">;
 
 def HasStdExtZhinxOrZhinxmin
-    : Predicate<"Subtarget->hasStdExtZhinx() || Subtarget->hasStdExtZhinxmin()">,
-                AssemblerPredicate<(any_of FeatureStdExtZhinx, FeatureStdExtZhinxmin),
+    : Predicate<"Subtarget->hasStdExtZhinxmin()">,
+                AssemblerPredicate<(all_of FeatureStdExtZhinxmin),
                                    "'Zhinx' (Half Float in Integer) or "
                                    "'Zhinxmin' (Half Float in Integer Minimal)">;
 
@@ -487,16 +487,16 @@ def HasStdExtZvfbfwma : Predicate<"Subtarget->hasStdExtZvfbfwma()">,
 
 def HasVInstructionsBF16 : Predicate<"Subtarget->hasVInstructionsBF16()">;
 
-def FeatureStdExtZvfh
-    : SubtargetFeature<"zvfh", "HasStdExtZvfh", "true",
-                       "'Zvfh' (Vector Half-Precision Floating-Point)",
-                       [FeatureStdExtZve32f, FeatureStdExtZfhmin]>;
-
 def FeatureStdExtZvfhmin
     : SubtargetFeature<"zvfhmin", "HasStdExtZvfhmin", "true",
                        "'Zvfhmin' (Vector Half-Precision Floating-Point Minimal)",
                        [FeatureStdExtZve32f]>;
 
+def FeatureStdExtZvfh
+    : SubtargetFeature<"zvfh", "HasStdExtZvfh", "true",
+                       "'Zvfh' (Vector Half-Precision Floating-Point)",
+                       [FeatureStdExtZvfhmin, FeatureStdExtZfhmin]>;
+
 def HasVInstructionsF16 : Predicate<"Subtarget->hasVInstructionsF16()">;
 
 def HasVInstructionsF16Minimal : Predicate<"Subtarget->hasVInstructionsF16Minimal()">,
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 09b3ab96974c4e..098a320c91533b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -915,8 +915,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       Opc = RISCV::FMV_H_X;
       break;
     case MVT::f16:
-      Opc =
-          Subtarget->hasStdExtZhinxOrZhinxmin() ? RISCV::COPY : RISCV::FMV_H_X;
+      Opc = Subtarget->hasStdExtZhinxmin() ? RISCV::COPY : RISCV::FMV_H_X;
       break;
     case MVT::f32:
       Opc = Subtarget->hasStdExtZfinx() ? RISCV::COPY : RISCV::FMV_W_X;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 03e994586d0c44..22c61eb20885b8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -122,7 +122,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   if (Subtarget.is64Bit() && RV64LegalI32)
     addRegisterClass(MVT::i32, &RISCV::GPRRegClass);
 
-  if (Subtarget.hasStdExtZfhOrZfhmin())
+  if (Subtarget.hasStdExtZfhmin())
     addRegisterClass(MVT::f16, &RISCV::FPR16RegClass);
   if (Subtarget.hasStdExtZfbfmin())
     addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass);
@@ -130,7 +130,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     addRegisterClass(MVT::f32, &RISCV::FPR32RegClass);
   if (Subtarget.hasStdExtD())
     addRegisterClass(MVT::f64, &RISCV::FPR64RegClass);
-  if (Subtarget.hasStdExtZhinxOrZhinxmin())
+  if (Subtarget.hasStdExtZhinxmin())
     addRegisterClass(MVT::f16, &RISCV::GPRF16RegClass);
   if (Subtarget.hasStdExtZfinx())
     addRegisterClass(MVT::f32, &RISCV::GPRF32RegClass);
@@ -439,7 +439,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       ISD::FCEIL, ISD::FFLOOR, ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
       ISD::FROUNDEVEN};
 
-  if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
+  if (Subtarget.hasStdExtZfhminOrZhinxmin())
     setOperationAction(ISD::BITCAST, MVT::i16, Custom);
 
   static const unsigned ZfhminZfbfminPromoteOps[] = {
@@ -469,7 +469,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
   }
 
-  if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
+  if (Subtarget.hasStdExtZfhminOrZhinxmin()) {
     if (Subtarget.hasStdExtZfhOrZhinx()) {
       setOperationAction(FPLegalNodeTypes, MVT::f16, Legal);
       setOperationAction(FPRndMode, MVT::f16,
@@ -1322,7 +1322,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       // Custom-legalize bitcasts from fixed-length vectors to scalar types.
       setOperationAction(ISD::BITCAST, {MVT::i8, MVT::i16, MVT::i32, MVT::i64},
                          Custom);
-      if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
+      if (Subtarget.hasStdExtZfhminOrZhinxmin())
         setOperationAction(ISD::BITCAST, MVT::f16, Custom);
       if (Subtarget.hasStdExtFOrZfinx())
         setOperationAction(ISD::BITCAST, MVT::f32, Custom);
@@ -1388,7 +1388,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   if (Subtarget.hasStdExtZbkb())
     setTargetDAGCombine(ISD::BITREVERSE);
-  if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
+  if (Subtarget.hasStdExtZfhminOrZhinxmin())
     setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
   if (Subtarget.hasStdExtFOrZfinx())
     setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
@@ -2099,7 +2099,7 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
                                        bool ForCodeSize) const {
   bool IsLegalVT = false;
   if (VT == MVT::f16)
-    IsLegalVT = Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin();
+    IsLegalVT = Subtarget.hasStdExtZfhminOrZhinxmin();
   else if (VT == MVT::f32)
     IsLegalVT = Subtarget.hasStdExtFOrZfinx();
   else if (VT == MVT::f64)
@@ -2171,7 +2171,7 @@ MVT RISCVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
   // Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled.
   // We might still end up using a GPR but that will be decided based on ABI.
   if (VT == MVT::f16 && Subtarget.hasStdExtFOrZfinx() &&
-      !Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
+      !Subtarget.hasStdExtZfhminOrZhinxmin())
     return MVT::f32;
 
   MVT PartVT = TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
@@ -2188,7 +2188,7 @@ unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context
   // Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled.
   // We might still end up using a GPR but that will be decided based on ABI.
   if (VT == MVT::f16 && Subtarget.hasStdExtFOrZfinx() &&
-      !Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
+      !Subtarget.hasStdExtZfhminOrZhinxmin())
     return 1;
 
   return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
@@ -5761,7 +5761,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     EVT Op0VT = Op0.getValueType();
     MVT XLenVT = Subtarget.getXLenVT();
     if (VT == MVT::f16 && Op0VT == MVT::i16 &&
-        Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
+        Subtarget.hasStdExtZfhminOrZhinxmin()) {
       SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0);
       SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0);
       return FPConv;
@@ -11527,11 +11527,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     EVT Op0VT = Op0.getValueType();
     MVT XLenVT = Subtarget.getXLenVT();
     if (VT == MVT::i16 && Op0VT == MVT::f16 &&
-        Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
+        Subtarget.hasStdExtZfhminOrZhinxmin()) {
       SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
     } else if (VT == MVT::i16 && Op0VT == MVT::bf16 &&
-        Subtarget.hasStdExtZfbfmin()) {
+               Subtarget.hasStdExtZfbfmin()) {
       SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
     } else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() &&
@@ -18632,7 +18632,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       // TODO: Support fixed vectors up to XLen for P extension?
       if (VT.isVector())
         break;
-      if (VT == MVT::f16 && Subtarget.hasStdExtZhinxOrZhinxmin())
+      if (VT == MVT::f16 && Subtarget.hasStdExtZhinxmin())
         return std::make_pair(0U, &RISCV::GPRF16RegClass);
       if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
         return std::make_pair(0U, &RISCV::GPRF32RegClass);
@@ -18640,7 +18640,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
         return std::make_pair(0U, &RISCV::GPRPF64RegClass);
       return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
     case 'f':
-      if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16)
+      if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
         return std::make_pair(0U, &RISCV::FPR16RegClass);
       if (Subtarget.hasStdExtF() && VT == MVT::f32)
         return std::make_pair(0U, &RISCV::FPR32RegClass);
@@ -18753,7 +18753,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
       }
       if (VT == MVT::f32 || VT == MVT::Other)
         return std::make_pair(FReg, &RISCV::FPR32RegClass);
-      if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16) {
+      if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16) {
         unsigned RegNo = FReg - RISCV::F0_F;
         unsigned HReg = RISCV::F0_H + RegNo;
         return std::make_pair(HReg, &RISCV::FPR16RegClass);
@@ -19100,7 +19100,7 @@ bool RISCVTargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,
 
   switch (FPVT.getSimpleVT().SimpleTy) {
   case MVT::f16:
-    return Subtarget.hasStdExtZfhOrZfhmin();
+    return Subtarget.hasStdExtZfhmin();
   case MVT::f32:
     return Subtarget.hasStdExtF();
   case MVT::f64:
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h
index 23d56cfa6e4e52..7540218633bfcb 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -143,16 +143,12 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
   bool hasStdExtZvl() const { return ZvlLen != 0; }
   bool hasStdExtFOrZfinx() const { return HasStdExtF || HasStdExtZfinx; }
   bool hasStdExtDOrZdinx() const { return HasStdExtD || HasStdExtZdinx; }
-  bool hasStdExtZfhOrZfhmin() const { return HasStdExtZfh || HasStdExtZfhmin; }
   bool hasStdExtZfhOrZhinx() const { return HasStdExtZfh || HasStdExtZhinx; }
-  bool hasStdExtZhinxOrZhinxmin() const {
-    return HasStdExtZhinx || HasStdExtZhinxmin;
-  }
-  bool hasStdExtZfhOrZfhminOrZhinxOrZhinxmin() const {
-    return hasStdExtZfhOrZfhmin() || hasStdExtZhinxOrZhinxmin();
+  bool hasStdExtZfhminOrZhinxmin() const {
+    return HasStdExtZfhmin || HasStdExtZhinxmin;
   }
   bool hasHalfFPLoadStoreMove() const {
-    return hasStdExtZfhOrZfhmin() || HasStdExtZfbfmin;
+    return HasStdExtZfhmin || HasStdExtZfbfmin;
   }
   bool is64Bit() const { return IsRV64; }
   MVT getXLenVT() const {
@@ -201,9 +197,7 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
   // Vector codegen related methods.
   bool hasVInstructions() const { return HasStdExtZve32x; }
   bool hasVInstructionsI64() const { return HasStdExtZve64x; }
-  bool hasVInstructionsF16Minimal() const {
-    return HasStdExtZvfhmin || HasStdExtZvfh;
-  }
+  bool hasVInstructionsF16Minimal() const { return HasStdExtZvfhmin; }
   bool hasVInstructionsF16() const { return HasStdExtZvfh; }
   bool hasVInstructionsBF16() const { return HasStdExtZvfbfmin; }
   bool hasVInstructionsF32() const { return HasStdExtZve32f; }
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index efc8350064a6e3..96ecc771863e56 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -334,7 +334,7 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
       return RISCVRegisterClass::GPRRC;
 
     Type *ScalarTy = Ty->getScalarType();
-    if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhOrZfhmin()) ||
+    if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhmin()) ||
         (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
         (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
       return RISCVRegisterClass::FPRRC;

``````````

</details>


https://github.com/llvm/llvm-project/pull/75824


More information about the llvm-commits mailing list