[llvm] [RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (PR #82506)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 08:32:31 PST 2024


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/82506

This fixes the miscompile reported in #82430 by telling isSimpleVIDSequence to
sign extend to XLen instead of to the type of the indices, since the "sequence"
of indices generated by a strided load will be at XLen.

This was the simplest way I could think of to get isSimpleVIDSequence to
treat the indices as if they were zero-extended to XLenVT.

Another way to do this would be to factor the "get constant integers" part out
of isSimpleVIDSequence and handle the values as APInts, so that we can
zero-extend them separately.


>From 0330d1825001799846ec09e69d76a800fb3b9a53 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 22 Feb 2024 00:23:03 +0800
Subject: [PATCH] [RISCV] Fix mgather -> riscv.masked.strided.load combine not
 extending indices

This fixes the miscompile reported in #82430 by telling isSimpleVIDSequence to
sign extend to XLen instead of to the type of the indices, since the "sequence"
of indices generated by a strided load will be at XLen.

This was the simplest way I could think of to get isSimpleVIDSequence to
treat the indices as if they were zero-extended to XLenVT.

Another way to do this would be to factor the "get constant integers" part out
of isSimpleVIDSequence and handle the values as APInts, so that we can
zero-extend them separately.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 20 +++++++++++--------
 .../RISCV/rvv/fixed-vectors-masked-gather.ll  |  8 ++------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f7275eb7c77bb3..75be97ff32bbe5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3240,7 +3240,8 @@ static std::optional<uint64_t> getExactInteger(const APFloat &APF,
 // Note that this method will also match potentially unappealing index
 // sequences, like <i32 0, i32 50939494>, however it is left to the caller to
 // determine whether this is worth generating code for.
-static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
+static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
+                                                      unsigned EltSizeInBits) {
   unsigned NumElts = Op.getNumOperands();
   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
   bool IsInteger = Op.getValueType().isInteger();
@@ -3248,7 +3249,7 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
   std::optional<unsigned> SeqStepDenom;
   std::optional<int64_t> SeqStepNum, SeqAddend;
   std::optional<std::pair<uint64_t, unsigned>> PrevElt;
-  unsigned EltSizeInBits = Op.getValueType().getScalarSizeInBits();
+  assert(EltSizeInBits >= Op.getValueType().getScalarSizeInBits());
   for (unsigned Idx = 0; Idx < NumElts; Idx++) {
     // Assume undef elements match the sequence; we just have to be careful
     // when interpolating across them.
@@ -3261,14 +3262,14 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
       if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
         return std::nullopt;
       Val = Op.getConstantOperandVal(Idx) &
-            maskTrailingOnes<uint64_t>(EltSizeInBits);
+            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
     } else {
       // The BUILD_VECTOR must be all constants.
       if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
         return std::nullopt;
       if (auto ExactInteger = getExactInteger(
               cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
-              EltSizeInBits))
+              Op.getScalarValueSizeInBits()))
         Val = *ExactInteger;
       else
         return std::nullopt;
@@ -3324,11 +3325,11 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
     uint64_t Val;
     if (IsInteger) {
       Val = Op.getConstantOperandVal(Idx) &
-            maskTrailingOnes<uint64_t>(EltSizeInBits);
+            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
     } else {
       Val = *getExactInteger(
           cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
-          EltSizeInBits);
+          Op.getScalarValueSizeInBits());
     }
     uint64_t ExpectedVal =
         (int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
@@ -3598,7 +3599,7 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
   // Try and match index sequences, which we can lower to the vid instruction
   // with optional modifications. An all-undef vector is matched by
   // getSplatValue, above.
-  if (auto SimpleVID = isSimpleVIDSequence(Op)) {
+  if (auto SimpleVID = isSimpleVIDSequence(Op, Op.getScalarValueSizeInBits())) {
     int64_t StepNumerator = SimpleVID->StepNumerator;
     unsigned StepDenominator = SimpleVID->StepDenominator;
     int64_t Addend = SimpleVID->Addend;
@@ -15978,7 +15979,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 
     if (Index.getOpcode() == ISD::BUILD_VECTOR &&
         MGN->getExtensionType() == ISD::NON_EXTLOAD && isTypeLegal(VT)) {
-      if (std::optional<VIDSequence> SimpleVID = isSimpleVIDSequence(Index);
+      // The sequence will be XLenVT, not the type of Index. Tell
+      // isSimpleVIDSequence this so we avoid overflow.
+      if (std::optional<VIDSequence> SimpleVID =
+              isSimpleVIDSequence(Index, Subtarget.getXLen());
           SimpleVID && SimpleVID->StepDenominator == 1) {
         const int64_t StepNumerator = SimpleVID->StepNumerator;
         const int64_t Addend = SimpleVID->Addend;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index 1724b48dd6be9e..2628672ee6b722 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -15086,14 +15086,10 @@ define <32 x i64> @mgather_strided_split(ptr %base) {
   ret <32 x i64> %x
 }
 
-; FIXME: This is a miscompile triggered by the mgather ->
-; riscv.masked.strided.load combine. In order for it to trigger we need either a
-; strided gather that RISCVGatherScatterLowering doesn't pick up, or a new
-; strided gather generated by the widening sew combine.
 define <4 x i32> @masked_gather_widen_sew_negative_stride(ptr %base) {
 ; RV32V-LABEL: masked_gather_widen_sew_negative_stride:
 ; RV32V:       # %bb.0:
-; RV32V-NEXT:    addi a0, a0, -128
+; RV32V-NEXT:    addi a0, a0, 128
 ; RV32V-NEXT:    li a1, -128
 ; RV32V-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
 ; RV32V-NEXT:    vlse64.v v8, (a0), a1
@@ -15101,7 +15097,7 @@ define <4 x i32> @masked_gather_widen_sew_negative_stride(ptr %base) {
 ;
 ; RV64V-LABEL: masked_gather_widen_sew_negative_stride:
 ; RV64V:       # %bb.0:
-; RV64V-NEXT:    addi a0, a0, -128
+; RV64V-NEXT:    addi a0, a0, 128
 ; RV64V-NEXT:    li a1, -128
 ; RV64V-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
 ; RV64V-NEXT:    vlse64.v v8, (a0), a1



More information about the llvm-commits mailing list