[llvm-branch-commits] [llvm] release/18.x: [RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (#82506) (PR #82572)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Feb 21 19:58:05 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: None (llvmbot)

<details>
<summary>Changes</summary>

Backport of commits 2cd59bdc891ab59a1abfe5205feb45791a530a47, 11d115d0569b212dfeb7fe6485be48070e068e19, and 815644b4dd882ade2e5649d4f97c3dd6f7aea200.

Requested by: @patrick-rivos

---
Full diff: https://github.com/llvm/llvm-project/pull/82572.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+12-8) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll (+43) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 80447d03c000b8..a0cec426002b6f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3192,7 +3192,8 @@ static std::optional<uint64_t> getExactInteger(const APFloat &APF,
 // Note that this method will also match potentially unappealing index
 // sequences, like <i32 0, i32 50939494>, however it is left to the caller to
 // determine whether this is worth generating code for.
-static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
+static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
+                                                      unsigned EltSizeInBits) {
   unsigned NumElts = Op.getNumOperands();
   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
   bool IsInteger = Op.getValueType().isInteger();
@@ -3200,7 +3201,7 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
   std::optional<unsigned> SeqStepDenom;
   std::optional<int64_t> SeqStepNum, SeqAddend;
   std::optional<std::pair<uint64_t, unsigned>> PrevElt;
-  unsigned EltSizeInBits = Op.getValueType().getScalarSizeInBits();
+  assert(EltSizeInBits >= Op.getValueType().getScalarSizeInBits());
   for (unsigned Idx = 0; Idx < NumElts; Idx++) {
     // Assume undef elements match the sequence; we just have to be careful
     // when interpolating across them.
@@ -3213,14 +3214,14 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
       if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
         return std::nullopt;
       Val = Op.getConstantOperandVal(Idx) &
-            maskTrailingOnes<uint64_t>(EltSizeInBits);
+            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
     } else {
       // The BUILD_VECTOR must be all constants.
       if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
         return std::nullopt;
       if (auto ExactInteger = getExactInteger(
               cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
-              EltSizeInBits))
+              Op.getScalarValueSizeInBits()))
         Val = *ExactInteger;
       else
         return std::nullopt;
@@ -3276,11 +3277,11 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
     uint64_t Val;
     if (IsInteger) {
       Val = Op.getConstantOperandVal(Idx) &
-            maskTrailingOnes<uint64_t>(EltSizeInBits);
+            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
     } else {
       Val = *getExactInteger(
           cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
-          EltSizeInBits);
+          Op.getScalarValueSizeInBits());
     }
     uint64_t ExpectedVal =
         (int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
@@ -3550,7 +3551,7 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
   // Try and match index sequences, which we can lower to the vid instruction
   // with optional modifications. An all-undef vector is matched by
   // getSplatValue, above.
-  if (auto SimpleVID = isSimpleVIDSequence(Op)) {
+  if (auto SimpleVID = isSimpleVIDSequence(Op, Op.getScalarValueSizeInBits())) {
     int64_t StepNumerator = SimpleVID->StepNumerator;
     unsigned StepDenominator = SimpleVID->StepDenominator;
     int64_t Addend = SimpleVID->Addend;
@@ -15562,7 +15563,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 
     if (Index.getOpcode() == ISD::BUILD_VECTOR &&
         MGN->getExtensionType() == ISD::NON_EXTLOAD && isTypeLegal(VT)) {
-      if (std::optional<VIDSequence> SimpleVID = isSimpleVIDSequence(Index);
+      // The sequence will be XLenVT, not the type of Index. Tell
+      // isSimpleVIDSequence this so we avoid overflow.
+      if (std::optional<VIDSequence> SimpleVID =
+              isSimpleVIDSequence(Index, Subtarget.getXLen());
           SimpleVID && SimpleVID->StepDenominator == 1) {
         const int64_t StepNumerator = SimpleVID->StepNumerator;
         const int64_t Addend = SimpleVID->Addend;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index 890707c6337fad..88c299a19fb4e8 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -15086,5 +15086,48 @@ define <32 x i64> @mgather_strided_split(ptr %base) {
   ret <32 x i64> %x
 }
 
+define <4 x i32> @masked_gather_widen_sew_negative_stride(ptr %base) {
+; RV32V-LABEL: masked_gather_widen_sew_negative_stride:
+; RV32V:       # %bb.0:
+; RV32V-NEXT:    addi a0, a0, 136
+; RV32V-NEXT:    li a1, -136
+; RV32V-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
+; RV32V-NEXT:    vlse64.v v8, (a0), a1
+; RV32V-NEXT:    ret
+;
+; RV64V-LABEL: masked_gather_widen_sew_negative_stride:
+; RV64V:       # %bb.0:
+; RV64V-NEXT:    addi a0, a0, 136
+; RV64V-NEXT:    li a1, -136
+; RV64V-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
+; RV64V-NEXT:    vlse64.v v8, (a0), a1
+; RV64V-NEXT:    ret
+;
+; RV32ZVE32F-LABEL: masked_gather_widen_sew_negative_stride:
+; RV32ZVE32F:       # %bb.0:
+; RV32ZVE32F-NEXT:    lui a1, 16393
+; RV32ZVE32F-NEXT:    addi a1, a1, -888
+; RV32ZVE32F-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; RV32ZVE32F-NEXT:    vmv.s.x v9, a1
+; RV32ZVE32F-NEXT:    vluxei8.v v8, (a0), v9
+; RV32ZVE32F-NEXT:    ret
+;
+; RV64ZVE32F-LABEL: masked_gather_widen_sew_negative_stride:
+; RV64ZVE32F:       # %bb.0:
+; RV64ZVE32F-NEXT:    addi a1, a0, 136
+; RV64ZVE32F-NEXT:    lw a2, 140(a0)
+; RV64ZVE32F-NEXT:    lw a3, 0(a0)
+; RV64ZVE32F-NEXT:    lw a0, 4(a0)
+; RV64ZVE32F-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; RV64ZVE32F-NEXT:    vlse32.v v8, (a1), zero
+; RV64ZVE32F-NEXT:    vslide1down.vx v8, v8, a2
+; RV64ZVE32F-NEXT:    vslide1down.vx v8, v8, a3
+; RV64ZVE32F-NEXT:    vslide1down.vx v8, v8, a0
+; RV64ZVE32F-NEXT:    ret
+  %ptrs = getelementptr i32, ptr %base, <4 x i64> <i64 34, i64 35, i64 0, i64 1>
+  %x = call <4 x i32> @llvm.masked.gather.v4i32.v32p0(<4 x ptr> %ptrs, i32 8, <4 x i1> shufflevector(<4 x i1> insertelement(<4 x i1> poison, i1 true, i32 0), <4 x i1> poison, <4 x i32> zeroinitializer), <4 x i32> poison)
+  ret <4 x i32> %x
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; RV64: {{.*}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/82572


More information about the llvm-branch-commits mailing list