[llvm] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. (PR #81248)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 18 21:41:12 PST 2024


================
@@ -13396,41 +13416,55 @@ struct NodeExtensionHelper {
   /// Helper function to get the narrow type for \p Root.
   /// The narrow type is the type of \p Root where we divided the size of each
   /// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
-  /// \pre The size of the type of the elements of Root must be a multiple of 2
-  /// and be greater than 16.
-  static MVT getNarrowType(const SDNode *Root) {
+  /// \pre Both the narrow type and the original type should be legal.
+  static MVT getNarrowType(const SDNode *Root, ExtKind SupportsExt) {
     MVT VT = Root->getSimpleValueType(0);
 
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-    assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
-    MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
-                                    VT.getVectorElementCount());
+
+    MVT EltVT = SupportsExt == ExtKind::FPExt
+                    ? MVT::getFloatingPointVT(NarrowSize)
+                    : MVT::getIntegerVT(NarrowSize);
+
+    assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
+           "Trying to extend something we can't represent");
+    MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
     return NarrowVT;
   }
 
-  /// Return the opcode required to materialize the folding of the sign
-  /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+  /// Return the opcode required to materialize the folding for
   /// both operands for \p Opcode.
   /// Put differently, get the opcode to materialize:
-  /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
-  /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+  /// - ExtKind::SExt: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+  /// - ExtKind::ZExt: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
   /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
-  static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+  static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
-      return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
+                                          : RISCVISD::VWADDU_VL;
     case ISD::MUL:
     case RISCVISD::MUL_VL:
-      return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
+                                          : RISCVISD::VWMULU_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
+                                          : RISCVISD::VWSUBU_VL;
+    case RISCVISD::FADD_VL:
+    case RISCVISD::VFWADD_W_VL:
+      return RISCVISD::VFWADD_VL;
----------------
sun-jacobi wrote:

It might be safer to do so. Thanks!

https://github.com/llvm/llvm-project/pull/81248


More information about the llvm-commits mailing list