[llvm] [RISCV] Don't look through EXTRACT_VECTOR_ELT in lowerScalarInsert if the element types are different. (PR #78668)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 18 20:42:19 PST 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/78668

If the element type of the vector we're extracting from (via EXTRACT_VECTOR_ELT) doesn't match the element type of the vector we're inserting into, we can't directly insert or extract the subvector.

>From fe76c3ed7ef9c1b8bf80a301c246b4fc809265eb Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 18 Jan 2024 20:38:56 -0800
Subject: [PATCH] [RISCV] Don't look through EXTRACT_VECTOR_ELT in
 lowerScalarInsert if the element types are different.

If the element type of the vector we're extracting from doesn't match
the element type of the vector we're inserting into, we can't directly
insert or extract the subvector.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 28 +++++++++++--------
 .../CodeGen/RISCV/rvv/fold-binary-reduce.ll   | 23 +++++++++++++++
 2 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 019ea93aac9dc66..b41e2f40dc72f01 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4040,19 +4040,23 @@ static SDValue lowerScalarInsert(SDValue Scalar, SDValue VL, MVT VT,
   if (Scalar.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
       isNullConstant(Scalar.getOperand(1))) {
     SDValue ExtractedVal = Scalar.getOperand(0);
-    MVT ExtractedVT = ExtractedVal.getSimpleValueType();
-    MVT ExtractedContainerVT = ExtractedVT;
-    if (ExtractedContainerVT.isFixedLengthVector()) {
-      ExtractedContainerVT = getContainerForFixedLengthVector(
-          DAG, ExtractedContainerVT, Subtarget);
-      ExtractedVal = convertToScalableVector(ExtractedContainerVT, ExtractedVal,
-                                             DAG, Subtarget);
-    }
-    if (ExtractedContainerVT.bitsLE(VT))
-      return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, ExtractedVal,
+    // The element types must be the same.
+    if (ExtractedVal.getValueType().getVectorElementType() ==
+        VT.getVectorElementType()) {
+      MVT ExtractedVT = ExtractedVal.getSimpleValueType();
+      MVT ExtractedContainerVT = ExtractedVT;
+      if (ExtractedContainerVT.isFixedLengthVector()) {
+        ExtractedContainerVT = getContainerForFixedLengthVector(
+            DAG, ExtractedContainerVT, Subtarget);
+        ExtractedVal = convertToScalableVector(ExtractedContainerVT,
+                                               ExtractedVal, DAG, Subtarget);
+      }
+      if (ExtractedContainerVT.bitsLE(VT))
+        return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru,
+                           ExtractedVal, DAG.getConstant(0, DL, XLenVT));
+      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal,
                          DAG.getConstant(0, DL, XLenVT));
-    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal,
-                       DAG.getConstant(0, DL, XLenVT));
+    }
   }
 
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll b/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll
index 8fd8c2548ff34d4..351c0bab9dca893 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fold-binary-reduce.ll
@@ -343,3 +343,26 @@ declare i64 @llvm.smax.i64(i64, i64)
 declare i64 @llvm.smin.i64(i64, i64)
 declare float @llvm.maxnum.f32(float ,float)
 declare float @llvm.minnum.f32(float ,float)
+
+define void @crash(<2 x i32> %0) { ; Regression test for PR #78668: scalar comes from an i32-element vector, reduction is over i16 elements.
+; CHECK-LABEL: crash:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    sb a0, 0(zero)
+; CHECK-NEXT:    ret
+entry:
+  %1 = extractelement <2 x i32> %0, i64 0 ; element 0 of a vector whose element type (i32) differs from the reduction's (i16)
+  %2 = tail call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> zeroinitializer) ; i16-element add reduction
+  %3 = zext i16 %2 to i32
+  %op.rdx = add i32 %1, %3 ; expected output: %1 goes through a0 (vmv.x.s / vmv.s.x) as the vredsum.vs start value instead of reusing the <2 x i32> register directly
+  %conv18.us = trunc i32 %op.rdx to i8
+  store i8 %conv18.us, ptr null, align 1
+  ret void
+}
+declare i16 @llvm.vector.reduce.add.v4i16(<4 x i16>)



More information about the llvm-commits mailing list