[llvm] [RISCV] Select zero splats of EEW=64 on RV32 (PR #107205)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 4 02:33:14 PDT 2024


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/107205

An EEW=64 splat of zero on RV32 will have a bitcast in the middle, so it isn't picked up by findVSplat. But we can add a special case for it which allows some .vx and .vi patterns to be picked up.


>From daa227bac14b80f47e82bd4c2c53691a917c4005 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 4 Sep 2024 17:28:50 +0800
Subject: [PATCH] [RISCV] Select zero splats of EEW=64 on RV32

An EEW=64 splat of zero on RV32 will have a bitcast in the middle, so it isn't picked up by findVSplat. But we can add a special case for it which allows some .vx and .vi patterns to be picked up.
---
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp   |  27 ++++
 .../rvv/fixed-vectors-masked-load-int.ll      |  56 +++-----
 .../rvv/fixed-vectors-masked-store-int.ll     | 122 ++++++------------
 3 files changed, 87 insertions(+), 118 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 78006885421603..125edd24be5cab 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3452,10 +3452,37 @@ bool RISCVDAGToDAGISel::selectVSplat(SDValue N, SDValue &SplatVal) {
   return true;
 }
 
+// Look for splats of zero. On RV32 with EEW=64 a bitcast may sit in between.
+//
+//   t72: nxv16i32 = RISCVISD::VMV_V_X_VL ...
+//   t73: v32i32 = extract_subvector t72, Constant:i32<0>
+//   t21: v16i64 = bitcast t73
+//   t42: nxv8i64 = insert_subvector undef:nxv8i64, t21, Constant:i32<0>
+static bool isZeroSplat(SDValue N) {
+  if (N.getOpcode() == ISD::INSERT_SUBVECTOR && N.getOperand(0).isUndef())
+    N = N.getOperand(1);
+  if (N.getOpcode() == ISD::BITCAST)
+    N = N.getOperand(0);
+  if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR)
+    N = N.getOperand(0);
+
+  return (N.getOpcode() == RISCVISD::VMV_V_X_VL ||
+          N.getOpcode() == RISCVISD::VMV_S_X_VL) &&
+         isa<ConstantSDNode>(N.getOperand(1)) &&
+         N.getConstantOperandVal(1) == 0;
+}
+
 static bool selectVSplatImmHelper(SDValue N, SDValue &SplatVal,
                                   SelectionDAG &DAG,
                                   const RISCVSubtarget &Subtarget,
                                   std::function<bool(int64_t)> ValidateImm) {
+  // Explicitly look for zero splats to handle the EEW=64 on RV32 case.
+  if (isZeroSplat(N) && ValidateImm(0)) {
+    SplatVal =
+        DAG.getConstant(0, SDLoc(N), Subtarget.getXLenVT(), /*isTarget=*/true);
+    return true;
+  }
+
   SDValue Splat = findVSplat(N);
   if (!Splat || !isa<ConstantSDNode>(Splat.getOperand(1)))
     return false;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
index ad075e4b4e198c..2f20caa6eb1894 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-load-int.ll
@@ -397,43 +397,22 @@ define void @masked_load_v32i32(ptr %a, ptr %m_ptr, ptr %res_ptr) nounwind {
 declare <32 x i32> @llvm.masked.load.v32i32(ptr, i32, <32 x i1>, <32 x i32>)
 
 define void @masked_load_v32i64(ptr %a, ptr %m_ptr, ptr %res_ptr) nounwind {
-; RV32-LABEL: masked_load_v32i64:
-; RV32:       # %bb.0:
-; RV32-NEXT:    addi a3, a1, 128
-; RV32-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV32-NEXT:    vle64.v v0, (a1)
-; RV32-NEXT:    vle64.v v24, (a3)
-; RV32-NEXT:    li a1, 32
-; RV32-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
-; RV32-NEXT:    vmv.v.i v16, 0
-; RV32-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV32-NEXT:    vmseq.vv v8, v0, v16
-; RV32-NEXT:    vmseq.vv v0, v24, v16
-; RV32-NEXT:    addi a1, a0, 128
-; RV32-NEXT:    vle64.v v16, (a1), v0.t
-; RV32-NEXT:    vmv1r.v v0, v8
-; RV32-NEXT:    vle64.v v8, (a0), v0.t
-; RV32-NEXT:    vse64.v v8, (a2)
-; RV32-NEXT:    addi a0, a2, 128
-; RV32-NEXT:    vse64.v v16, (a0)
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: masked_load_v32i64:
-; RV64:       # %bb.0:
-; RV64-NEXT:    addi a3, a1, 128
-; RV64-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV64-NEXT:    vle64.v v16, (a1)
-; RV64-NEXT:    vle64.v v24, (a3)
-; RV64-NEXT:    vmseq.vi v8, v16, 0
-; RV64-NEXT:    vmseq.vi v0, v24, 0
-; RV64-NEXT:    addi a1, a0, 128
-; RV64-NEXT:    vle64.v v16, (a1), v0.t
-; RV64-NEXT:    vmv1r.v v0, v8
-; RV64-NEXT:    vle64.v v8, (a0), v0.t
-; RV64-NEXT:    vse64.v v8, (a2)
-; RV64-NEXT:    addi a0, a2, 128
-; RV64-NEXT:    vse64.v v16, (a0)
-; RV64-NEXT:    ret
+; CHECK-LABEL: masked_load_v32i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, a1, 128
+; CHECK-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
+; CHECK-NEXT:    vle64.v v16, (a1)
+; CHECK-NEXT:    vle64.v v24, (a3)
+; CHECK-NEXT:    vmseq.vi v8, v16, 0
+; CHECK-NEXT:    vmseq.vi v0, v24, 0
+; CHECK-NEXT:    addi a1, a0, 128
+; CHECK-NEXT:    vle64.v v16, (a1), v0.t
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    vle64.v v8, (a0), v0.t
+; CHECK-NEXT:    vse64.v v8, (a2)
+; CHECK-NEXT:    addi a0, a2, 128
+; CHECK-NEXT:    vse64.v v16, (a0)
+; CHECK-NEXT:    ret
   %m = load <32 x i64>, ptr %m_ptr
   %mask = icmp eq <32 x i64> %m, zeroinitializer
   %load = call <32 x i64> @llvm.masked.load.v32i64(ptr %a, i32 8, <32 x i1> %mask, <32 x i64> undef)
@@ -547,3 +526,6 @@ define void @masked_load_v256i8(ptr %a, ptr %m_ptr, ptr %res_ptr) nounwind {
   ret void
 }
 declare <256 x i8> @llvm.masked.load.v256i8(ptr, i32, <256 x i1>, <256 x i8>)
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
index 86c28247e97ef1..90690bbc8e2085 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-store-int.ll
@@ -397,87 +397,44 @@ define void @masked_store_v32i32(ptr %val_ptr, ptr %a, ptr %m_ptr) nounwind {
 declare void @llvm.masked.store.v32i32.p0(<32 x i32>, ptr, i32, <32 x i1>)
 
 define void @masked_store_v32i64(ptr %val_ptr, ptr %a, ptr %m_ptr) nounwind {
-; RV32-LABEL: masked_store_v32i64:
-; RV32:       # %bb.0:
-; RV32-NEXT:    addi sp, sp, -16
-; RV32-NEXT:    csrr a3, vlenb
-; RV32-NEXT:    slli a3, a3, 4
-; RV32-NEXT:    sub sp, sp, a3
-; RV32-NEXT:    addi a3, a2, 128
-; RV32-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV32-NEXT:    vle64.v v24, (a2)
-; RV32-NEXT:    vle64.v v8, (a3)
-; RV32-NEXT:    csrr a2, vlenb
-; RV32-NEXT:    slli a2, a2, 3
-; RV32-NEXT:    add a2, sp, a2
-; RV32-NEXT:    addi a2, a2, 16
-; RV32-NEXT:    vs8r.v v8, (a2) # Unknown-size Folded Spill
-; RV32-NEXT:    li a2, 32
-; RV32-NEXT:    vsetvli zero, a2, e32, m8, ta, ma
-; RV32-NEXT:    vmv.v.i v8, 0
-; RV32-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV32-NEXT:    vmseq.vv v7, v24, v8
-; RV32-NEXT:    addi a2, a0, 128
-; RV32-NEXT:    vle64.v v24, (a2)
-; RV32-NEXT:    vle64.v v16, (a0)
-; RV32-NEXT:    addi a0, sp, 16
-; RV32-NEXT:    vs8r.v v16, (a0) # Unknown-size Folded Spill
-; RV32-NEXT:    csrr a0, vlenb
-; RV32-NEXT:    slli a0, a0, 3
-; RV32-NEXT:    add a0, sp, a0
-; RV32-NEXT:    addi a0, a0, 16
-; RV32-NEXT:    vl8r.v v16, (a0) # Unknown-size Folded Reload
-; RV32-NEXT:    vmseq.vv v0, v16, v8
-; RV32-NEXT:    addi a0, a1, 128
-; RV32-NEXT:    vse64.v v24, (a0), v0.t
-; RV32-NEXT:    vmv1r.v v0, v7
-; RV32-NEXT:    addi a0, sp, 16
-; RV32-NEXT:    vl8r.v v8, (a0) # Unknown-size Folded Reload
-; RV32-NEXT:    vse64.v v8, (a1), v0.t
-; RV32-NEXT:    csrr a0, vlenb
-; RV32-NEXT:    slli a0, a0, 4
-; RV32-NEXT:    add sp, sp, a0
-; RV32-NEXT:    addi sp, sp, 16
-; RV32-NEXT:    ret
-;
-; RV64-LABEL: masked_store_v32i64:
-; RV64:       # %bb.0:
-; RV64-NEXT:    addi sp, sp, -16
-; RV64-NEXT:    csrr a3, vlenb
-; RV64-NEXT:    slli a3, a3, 4
-; RV64-NEXT:    sub sp, sp, a3
-; RV64-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
-; RV64-NEXT:    vle64.v v8, (a2)
-; RV64-NEXT:    addi a2, a2, 128
-; RV64-NEXT:    vle64.v v16, (a2)
-; RV64-NEXT:    csrr a2, vlenb
-; RV64-NEXT:    slli a2, a2, 3
-; RV64-NEXT:    add a2, sp, a2
-; RV64-NEXT:    addi a2, a2, 16
-; RV64-NEXT:    vs8r.v v16, (a2) # Unknown-size Folded Spill
-; RV64-NEXT:    vmseq.vi v0, v8, 0
-; RV64-NEXT:    vle64.v v24, (a0)
-; RV64-NEXT:    addi a0, a0, 128
-; RV64-NEXT:    vle64.v v8, (a0)
-; RV64-NEXT:    addi a0, sp, 16
-; RV64-NEXT:    vs8r.v v8, (a0) # Unknown-size Folded Spill
-; RV64-NEXT:    csrr a0, vlenb
-; RV64-NEXT:    slli a0, a0, 3
-; RV64-NEXT:    add a0, sp, a0
-; RV64-NEXT:    addi a0, a0, 16
-; RV64-NEXT:    vl8r.v v16, (a0) # Unknown-size Folded Reload
-; RV64-NEXT:    vmseq.vi v8, v16, 0
-; RV64-NEXT:    vse64.v v24, (a1), v0.t
-; RV64-NEXT:    addi a0, a1, 128
-; RV64-NEXT:    vmv1r.v v0, v8
-; RV64-NEXT:    addi a1, sp, 16
-; RV64-NEXT:    vl8r.v v8, (a1) # Unknown-size Folded Reload
-; RV64-NEXT:    vse64.v v8, (a0), v0.t
-; RV64-NEXT:    csrr a0, vlenb
-; RV64-NEXT:    slli a0, a0, 4
-; RV64-NEXT:    add sp, sp, a0
-; RV64-NEXT:    addi sp, sp, 16
-; RV64-NEXT:    ret
+; CHECK-LABEL: masked_store_v32i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi sp, sp, -16
+; CHECK-NEXT:    csrr a3, vlenb
+; CHECK-NEXT:    slli a3, a3, 4
+; CHECK-NEXT:    sub sp, sp, a3
+; CHECK-NEXT:    vsetivli zero, 16, e64, m8, ta, ma
+; CHECK-NEXT:    vle64.v v8, (a2)
+; CHECK-NEXT:    addi a2, a2, 128
+; CHECK-NEXT:    vle64.v v16, (a2)
+; CHECK-NEXT:    csrr a2, vlenb
+; CHECK-NEXT:    slli a2, a2, 3
+; CHECK-NEXT:    add a2, sp, a2
+; CHECK-NEXT:    addi a2, a2, 16
+; CHECK-NEXT:    vs8r.v v16, (a2) # Unknown-size Folded Spill
+; CHECK-NEXT:    vmseq.vi v0, v8, 0
+; CHECK-NEXT:    vle64.v v24, (a0)
+; CHECK-NEXT:    addi a0, a0, 128
+; CHECK-NEXT:    vle64.v v8, (a0)
+; CHECK-NEXT:    addi a0, sp, 16
+; CHECK-NEXT:    vs8r.v v8, (a0) # Unknown-size Folded Spill
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    slli a0, a0, 3
+; CHECK-NEXT:    add a0, sp, a0
+; CHECK-NEXT:    addi a0, a0, 16
+; CHECK-NEXT:    vl8r.v v16, (a0) # Unknown-size Folded Reload
+; CHECK-NEXT:    vmseq.vi v8, v16, 0
+; CHECK-NEXT:    vse64.v v24, (a1), v0.t
+; CHECK-NEXT:    addi a0, a1, 128
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    addi a1, sp, 16
+; CHECK-NEXT:    vl8r.v v8, (a1) # Unknown-size Folded Reload
+; CHECK-NEXT:    vse64.v v8, (a0), v0.t
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    slli a0, a0, 4
+; CHECK-NEXT:    add sp, sp, a0
+; CHECK-NEXT:    addi sp, sp, 16
+; CHECK-NEXT:    ret
   %m = load <32 x i64>, ptr %m_ptr
   %mask = icmp eq <32 x i64> %m, zeroinitializer
   %val = load <32 x i64>, ptr %val_ptr
@@ -683,3 +640,6 @@ define void @masked_store_v256i8(ptr %val_ptr, ptr %a, ptr %m_ptr) nounwind {
   ret void
 }
 declare void @llvm.masked.store.v256i8.p0(<256 x i8>, ptr, i32, <256 x i1>)
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}



More information about the llvm-commits mailing list