[llvm] [RISCV] Handle FP riscv_masked_strided_load with 0 stride. (PR #84576)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 8 14:20:00 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

Previously, we tried to create an integer extending load. We need to create a non-extending FP load instead.

Fixes #<!-- -->84541.

---
Full diff: https://github.com/llvm/llvm-project/pull/84576.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+8-3) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store-asm.ll (+39) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9b748cdcf74511..fa37306a49990f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -9080,15 +9080,20 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
     SDValue Result, Chain;
 
     // TODO: We restrict this to unmasked loads currently in consideration of
-    // the complexity of hanlding all falses masks.
-    if (IsUnmasked && isNullConstant(Stride)) {
-      MVT ScalarVT = ContainerVT.getVectorElementType();
+    // the complexity of handling all falses masks.
+    MVT ScalarVT = ContainerVT.getVectorElementType();
+    if (IsUnmasked && isNullConstant(Stride) && ContainerVT.isInteger()) {
       SDValue ScalarLoad =
           DAG.getExtLoad(ISD::ZEXTLOAD, DL, XLenVT, Load->getChain(), Ptr,
                          ScalarVT, Load->getMemOperand());
       Chain = ScalarLoad.getValue(1);
       Result = lowerScalarSplat(SDValue(), ScalarLoad, VL, ContainerVT, DL, DAG,
                                 Subtarget);
+    } else if (IsUnmasked && isNullConstant(Stride) && isTypeLegal(ScalarVT)) {
+      SDValue ScalarLoad = DAG.getLoad(ScalarVT, DL, Load->getChain(), Ptr,
+                                       Load->getMemOperand());
+      Chain = ScalarLoad.getValue(1);
+      Result = DAG.getSplat(ContainerVT, DL, ScalarLoad);
     } else {
       SDValue IntID = DAG.getTargetConstant(
           IsUnmasked ? Intrinsic::riscv_vlse : Intrinsic::riscv_vlse_mask, DL,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store-asm.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store-asm.ll
index 17d64c86dd53a4..c38406bafa8a97 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store-asm.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store-asm.ll
@@ -915,3 +915,42 @@ bb4:                                              ; preds = %bb4, %bb2
 bb16:                                             ; preds = %bb4, %bb
   ret void
 }
+
+define void @gather_zero_stride_fp(ptr noalias nocapture %A, ptr noalias nocapture readonly %B) {
+; CHECK-LABEL: gather_zero_stride_fp:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    lui a2, 1
+; CHECK-NEXT:    add a2, a0, a2
+; CHECK-NEXT:    vsetivli zero, 8, e32, m1, ta, ma
+; CHECK-NEXT:  .LBB15_1: # %vector.body
+; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    flw fa5, 0(a1)
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfadd.vf v8, v8, fa5
+; CHECK-NEXT:    vse32.v v8, (a0)
+; CHECK-NEXT:    addi a0, a0, 128
+; CHECK-NEXT:    addi a1, a1, 640
+; CHECK-NEXT:    bne a0, a2, .LBB15_1
+; CHECK-NEXT:  # %bb.2: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.ind = phi <8 x i64> [ zeroinitializer, %entry ], [ %vec.ind.next, %vector.body ]
+  %i = mul nuw nsw <8 x i64> %vec.ind, <i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5>
+  %i1 = getelementptr inbounds float, ptr %B, <8 x i64> %i
+  %wide.masked.gather = call <8 x float> @llvm.masked.gather.v8f32.v32p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x float> undef)
+  %i2 = getelementptr inbounds float, ptr %A, i64 %index
+  %wide.load = load <8 x float>, ptr %i2, align 4
+  %i4 = fadd <8 x float> %wide.load, %wide.masked.gather
+  store <8 x float> %i4, ptr %i2, align 4
+  %index.next = add nuw i64 %index, 32
+  %vec.ind.next = add <8 x i64> %vec.ind, <i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32>
+  %i6 = icmp eq i64 %index.next, 1024
+  br i1 %i6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/84576


More information about the llvm-commits mailing list