[llvm] 2a827e4 - [RISCV] Fix crash when a vector add has a 4x sext and zext operand.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 31 15:17:46 PDT 2022


Author: Craig Topper
Date: 2022-10-31T15:10:27-07:00
New Revision: 2a827e4a988b614bc6f70abe00308ceeb50dcd0a

URL: https://github.com/llvm/llvm-project/commit/2a827e4a988b614bc6f70abe00308ceeb50dcd0a
DIFF: https://github.com/llvm/llvm-project/commit/2a827e4a988b614bc6f70abe00308ceeb50dcd0a.diff

LOG: [RISCV] Fix crash when a vector add has a 4x sext and zext operand.

We can narrow one of the extends and keep the other original by
using a vwaddu.wv or vwadd.wv.

We were previously forgetting to keep the original operand and
instead took the source of its extend. This resulted in a type
mismatch that later failed with an impossible physical register copy.

To fix this, I've refactored some code so that the information about
whether the source needs to be extended at all is kept around longer,
allowing it to be used in materialize().

Differential Revision: https://reviews.llvm.org/D137106

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5ecf46806ceb6..9db4a3fe32fb8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8450,24 +8450,29 @@ struct NodeExtensionHelper {
     return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
   }
 
-  /// Get or create a value that can feed \p Root with the given \p ExtOpc.
-  /// If \p ExtOpc is None, this returns the source of this operand.
+  /// Get or create a value that can feed \p Root with the given extension \p
+  /// SExt. If \p SExt is None, this returns the source of this operand.
   /// \see ::getSource().
   SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
-                                Optional<unsigned> ExtOpc) const {
+                                Optional<bool> SExt) const {
+    if (!SExt.has_value())
+      return OrigOperand;
+
+    MVT NarrowVT = getNarrowType(Root);
+
     SDValue Source = getSource();
-    if (!ExtOpc)
+    if (Source.getValueType() == NarrowVT)
       return Source;
 
-    MVT NarrowVT = getNarrowType(Root);
+    unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+
     // If we need an extension, we should be changing the type.
-    assert(Source.getValueType() != NarrowVT && "Needless extension");
     SDLoc DL(Root);
     auto [Mask, VL] = getMaskAndVL(Root);
     switch (OrigOperand.getOpcode()) {
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
-      return DAG.getNode(*ExtOpc, DL, NarrowVT, Source, Mask, VL);
+      return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                          DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
@@ -8712,13 +8717,10 @@ struct NodeExtensionHelper {
 struct CombineResult {
   /// Opcode to be generated when materializing the combine.
   unsigned TargetOpcode;
-  /// Extension opcode to be applied to the source of LHS when materializing
-  /// TargetOpcode.
-  /// \see NodeExtensionHelper::getSource().
-  Optional<unsigned> LHSExtOpc;
-  /// Extension opcode to be applied to the source of RHS when materializing
-  /// TargetOpcode.
-  Optional<unsigned> RHSExtOpc;
+  // No value means no extension is needed. If extension is needed, the value
+  // indicates if it needs to be sign extended.
+  Optional<bool> SExtLHS;
+  Optional<bool> SExtRHS;
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
@@ -8729,13 +8731,8 @@ struct CombineResult {
   CombineResult(unsigned TargetOpcode, SDNode *Root,
                 const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
                 const NodeExtensionHelper &RHS, Optional<bool> SExtRHS)
-      : TargetOpcode(TargetOpcode), Root(Root), LHS(LHS), RHS(RHS) {
-    MVT NarrowVT = NodeExtensionHelper::getNarrowType(Root);
-    if (SExtLHS && LHS.getSource().getValueType() != NarrowVT)
-      LHSExtOpc = *SExtLHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-    if (SExtRHS && RHS.getSource().getValueType() != NarrowVT)
-      RHSExtOpc = *SExtRHS ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-  }
+      : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
+        Root(Root), LHS(LHS), RHS(RHS) {}
 
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
@@ -8745,8 +8742,8 @@ struct CombineResult {
     std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
     Merge = Root->getOperand(2);
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, LHSExtOpc),
-                       RHS.getOrCreateExtendedOp(Root, DAG, RHSExtOpc), Merge,
+                       LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
                        Mask, VL);
   }
 };

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
index 787251565f282..976273863be8d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll
@@ -859,3 +859,19 @@ define <2 x i64> @vwaddu_vx_v2i64_i64(<2 x i32>* %x, i64* %y) nounwind {
   %g = add <2 x i64> %e, %f
   ret <2 x i64> %g
 }
+
+define <4 x i64> @crash(<4 x i16> %x, <4 x i16> %y) {
+; CHECK-LABEL: crash:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vsext.vf4 v10, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vzext.vf2 v8, v9
+; CHECK-NEXT:    vwaddu.wv v10, v10, v8
+; CHECK-NEXT:    vmv2r.v v8, v10
+; CHECK-NEXT:    ret
+  %a = sext <4 x i16> %x to <4 x i64>
+  %b = zext <4 x i16> %y to <4 x i64>
+  %c = add <4 x i64> %a, %b
+  ret <4 x i64> %c
+}


        


More information about the llvm-commits mailing list