[llvm] ddae50d - [RISCV] Combine trunc (sra sext (x), zext (y)) to sra (x, smin (y, scalarsizeinbits(y) - 1)) (#65728)

via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 17 02:11:32 PDT 2023


Author: Vettel
Date: 2023-09-17T17:11:28+08:00
New Revision: ddae50d1e6fd71b98a4abefeeebb7ad6ac3329c5

URL: https://github.com/llvm/llvm-project/commit/ddae50d1e6fd71b98a4abefeeebb7ad6ac3329c5
DIFF: https://github.com/llvm/llvm-project/commit/ddae50d1e6fd71b98a4abefeeebb7ad6ac3329c5.diff

LOG: [RISCV] Combine trunc (sra sext (x), zext (y)) to sra (x, smin (y, scalarsizeinbits(y) - 1)) (#65728)

For RVV, when an i8 or i16 element-wise vector arithmetic right shift is
written in a C/C++ program, the value to be shifted is first sign
extended to i32 and the shift amount is zero extended to i32 so that the
vsra.vv instruction can be used, followed by a truncate to obtain the
final result. Because the RVV spec only supports 2 * SEW -> SEW
truncates, such a pattern is later expanded into a series of "vsetvli"
and "vnsrl" instructions. For vectors, however, the shift amount can
instead be computed as smin (Y, ScalarSizeInBits(Y) - 1), since the vsra
instruction only uses the low log2(SEW) bits of the shift amount (see
the C++ sketch below).

- Alive2: https://alive2.llvm.org/ce/z/u3-Zdr
- C++ Test cases : https://gcc.godbolt.org/z/q1qE7fbha
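
As a point of reference (not part of the patch; the function and array
names are made up for illustration), a minimal C++ sketch of the kind of
source that produces this pattern once it is vectorized:

  #include <cstdint>

  // Element-wise i8 arithmetic right shift. Under the usual C++ integer
  // promotions a[i] is sign extended to int and b[i] is zero extended to
  // int before the shift, and the store truncates the result back to i8,
  // i.e. the trunc (sra (sext x), (zext y)) pattern described above.
  void vsra_i8(int8_t *dst, const int8_t *a, const uint8_t *b, int n) {
    for (int i = 0; i < n; ++i)
      dst[i] = static_cast<int8_t>(a[i] >> b[i]);
  }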

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5c3ea3660e6e672..6d7e6a98f64eaab 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13749,6 +13749,56 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       }
     }
     return SDValue();
+  case RISCVISD::TRUNCATE_VECTOR_VL: {
+    // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+    // This benefits the cases where X and Y are vectors of the same
+    // low-precision value type. Since the truncate would be lowered into
+    // n levels of TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+    // restriction, such a pattern would later be expanded into a series of
+    // "vsetvli" and "vnsrl" instructions.
+    auto IsTruncNode = [](SDValue V) {
+      if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+        return false;
+      SDValue VL = V.getOperand(2);
+      auto *C = dyn_cast<ConstantSDNode>(VL);
+      // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+      bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+                             (isa<RegisterSDNode>(VL) &&
+                              cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+      return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
+             IsVLMAXForVMSET;
+    };
+
+    SDValue Op = N->getOperand(0);
+
+    // First find the innermost TRUNCATE_VECTOR_VL node, so that the
+    // pattern beneath it can be matched.
+    while (IsTruncNode(Op)) {
+      if (!Op.hasOneUse())
+        return SDValue();
+      Op = Op.getOperand(0);
+    }
+
+    if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
+      SDValue N0 = Op.getOperand(0);
+      SDValue N1 = Op.getOperand(1);
+      if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
+          N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
+        SDValue N00 = N0.getOperand(0);
+        SDValue N10 = N1.getOperand(0);
+        if (N00.getValueType().isVector() &&
+            N00.getValueType() == N10.getValueType() &&
+            N->getValueType(0) == N10.getValueType()) {
+          unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+          SDValue SMin = DAG.getNode(
+              ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+              DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+          return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+        }
+      }
+    }
+    break;
+  }
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:
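
A scalar sanity check of the equivalence this combine relies on (a
standalone model, not code from the patch; the helper names are
illustrative): shifting in the narrow type with the amount clamped to
ScalarSizeInBits - 1 matches the widen/shift/truncate sequence for every
shift amount that is defined for the wide shift. The Alive2 link in the
log proves the general case; this only demonstrates it for i8.

  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  // What the original IR computes per element: sign extend the value,
  // zero extend the shift amount, shift at i32 (arithmetic shift of a
  // negative value, as on RISC-V), then truncate back to i8.
  static int8_t shift_widened(int8_t x, uint8_t y) {
    return static_cast<int8_t>(static_cast<int32_t>(x) >>
                               static_cast<uint32_t>(y));
  }

  // What the combined node computes: sra (x, smin (y, 7)), i.e. an i8
  // shift with the amount clamped to ScalarSizeInBits(i8) - 1 = 7.
  static int8_t shift_clamped(int8_t x, uint8_t y) {
    return static_cast<int8_t>(x >> std::min<int>(y, 7));
  }

  int main() {
    // The wide shift is only defined for amounts 0..31, so that is the
    // range the transform has to preserve; both forms agree on all of it.
    for (int x = -128; x <= 127; ++x)
      for (int y = 0; y <= 31; ++y)
        assert(shift_widened(static_cast<int8_t>(x),
                             static_cast<uint8_t>(y)) ==
               shift_clamped(static_cast<int8_t>(x),
                             static_cast<uint8_t>(y)));
  }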

diff --git a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
index 743031616967754..738e9cf805b46f4 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
@@ -12,6 +12,21 @@ define <vscale x 1 x i8> @vsra_vv_nxv1i8(<vscale x 1 x i8> %va, <vscale x 1 x i8
   ret <vscale x 1 x i8> %vc
 }
 
+define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv1i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 1 x i8> %va to <vscale x 1 x i32>
+  %zexted_vb = zext <vscale x 1 x i8> %va to <vscale x 1 x i32>
+  %expand = ashr <vscale x 1 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 1 x i32> %expand to <vscale x 1 x i8>
+  ret <vscale x 1 x i8> %vc
+}
+
 define <vscale x 1 x i8> @vsra_vx_nxv1i8(<vscale x 1 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv1i8:
 ; CHECK:       # %bb.0:
@@ -46,6 +61,21 @@ define <vscale x 2 x i8> @vsra_vv_nxv2i8(<vscale x 2 x i8> %va, <vscale x 2 x i8
   ret <vscale x 2 x i8> %vc
 }
 
+define <vscale x 2 x i8> @vsra_vv_nxv2i8_sext_zext(<vscale x 2 x i8> %va, <vscale x 2 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv2i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 2 x i8> %va to <vscale x 2 x i32>
+  %zexted_vb = zext <vscale x 2 x i8> %va to <vscale x 2 x i32>
+  %expand = ashr <vscale x 2 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 2 x i32> %expand to <vscale x 2 x i8>
+  ret <vscale x 2 x i8> %vc
+}
+
 define <vscale x 2 x i8> @vsra_vx_nxv2i8(<vscale x 2 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv2i8:
 ; CHECK:       # %bb.0:
@@ -80,6 +110,21 @@ define <vscale x 4 x i8> @vsra_vv_nxv4i8(<vscale x 4 x i8> %va, <vscale x 4 x i8
   ret <vscale x 4 x i8> %vc
 }
 
+define <vscale x 4 x i8> @vsra_vv_nxv4i8_sext_zext(<vscale x 4 x i8> %va, <vscale x 4 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv4i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, mf2, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 4 x i8> %va to <vscale x 4 x i32>
+  %zexted_vb = zext <vscale x 4 x i8> %va to <vscale x 4 x i32>
+  %expand = ashr <vscale x 4 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 4 x i32> %expand to <vscale x 4 x i8>
+  ret <vscale x 4 x i8> %vc
+}
+
 define <vscale x 4 x i8> @vsra_vx_nxv4i8(<vscale x 4 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv4i8:
 ; CHECK:       # %bb.0:
@@ -114,6 +159,21 @@ define <vscale x 8 x i8> @vsra_vv_nxv8i8(<vscale x 8 x i8> %va, <vscale x 8 x i8
   ret <vscale x 8 x i8> %vc
 }
 
+define <vscale x 8 x i8> @vsra_vv_nxv8i8_sext_zext(<vscale x 8 x i8> %va, <vscale x 8 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv8i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, m1, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 8 x i8> %va to <vscale x 8 x i32>
+  %zexted_vb = zext <vscale x 8 x i8> %va to <vscale x 8 x i32>
+  %expand = ashr <vscale x 8 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 8 x i32> %expand to <vscale x 8 x i8>
+  ret <vscale x 8 x i8> %vc
+}
+
 define <vscale x 8 x i8> @vsra_vx_nxv8i8(<vscale x 8 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv8i8:
 ; CHECK:       # %bb.0:
@@ -148,6 +208,21 @@ define <vscale x 16 x i8> @vsra_vv_nxv16i8(<vscale x 16 x i8> %va, <vscale x 16
   ret <vscale x 16 x i8> %vc
 }
 
+define <vscale x 16 x i8> @vsra_vv_nxv16i8_sext_zext(<vscale x 16 x i8> %va, <vscale x 16 x i8> %vb) {
+; CHECK-LABEL: vsra_vv_nxv16i8_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 7
+; CHECK-NEXT:    vsetvli a1, zero, e8, m2, ta, ma
+; CHECK-NEXT:    vmin.vx v10, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v10
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 16 x i8> %va to <vscale x 16 x i32>
+  %zexted_vb = zext <vscale x 16 x i8> %va to <vscale x 16 x i32>
+  %expand = ashr <vscale x 16 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 16 x i32> %expand to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %vc
+}
+
 define <vscale x 16 x i8> @vsra_vx_nxv16i8(<vscale x 16 x i8> %va, i8 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv16i8:
 ; CHECK:       # %bb.0:
@@ -250,6 +325,21 @@ define <vscale x 1 x i16> @vsra_vv_nxv1i16(<vscale x 1 x i16> %va, <vscale x 1 x
   ret <vscale x 1 x i16> %vc
 }
 
+define <vscale x 1 x i16> @vsra_vv_nxv1i16_sext_zext(<vscale x 1 x i16> %va, <vscale x 1 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv1i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 1 x i16> %va to <vscale x 1 x i32>
+  %zexted_vb = zext <vscale x 1 x i16> %va to <vscale x 1 x i32>
+  %expand = ashr <vscale x 1 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 1 x i32> %expand to <vscale x 1 x i16>
+  ret <vscale x 1 x i16> %vc
+}
+
 define <vscale x 1 x i16> @vsra_vx_nxv1i16(<vscale x 1 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv1i16:
 ; CHECK:       # %bb.0:
@@ -284,6 +374,21 @@ define <vscale x 2 x i16> @vsra_vv_nxv2i16(<vscale x 2 x i16> %va, <vscale x 2 x
   ret <vscale x 2 x i16> %vc
 }
 
+define <vscale x 2 x i16> @vsra_vv_nxv2i16_sext_zext(<vscale x 2 x i16> %va, <vscale x 2 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv2i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 2 x i16> %va to <vscale x 2 x i32>
+  %zexted_vb = zext <vscale x 2 x i16> %va to <vscale x 2 x i32>
+  %expand = ashr <vscale x 2 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 2 x i32> %expand to <vscale x 2 x i16>
+  ret <vscale x 2 x i16> %vc
+}
+
 define <vscale x 2 x i16> @vsra_vx_nxv2i16(<vscale x 2 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv2i16:
 ; CHECK:       # %bb.0:
@@ -318,6 +423,21 @@ define <vscale x 4 x i16> @vsra_vv_nxv4i16(<vscale x 4 x i16> %va, <vscale x 4 x
   ret <vscale x 4 x i16> %vc
 }
 
+define <vscale x 4 x i16> @vsra_vv_nxv4i16_sext_zext(<vscale x 4 x i16> %va, <vscale x 4 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv4i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m1, ta, ma
+; CHECK-NEXT:    vmin.vx v9, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 4 x i16> %va to <vscale x 4 x i32>
+  %zexted_vb = zext <vscale x 4 x i16> %va to <vscale x 4 x i32>
+  %expand = ashr <vscale x 4 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 4 x i32> %expand to <vscale x 4 x i16>
+  ret <vscale x 4 x i16> %vc
+}
+
 define <vscale x 4 x i16> @vsra_vx_nxv4i16(<vscale x 4 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv4i16:
 ; CHECK:       # %bb.0:
@@ -352,6 +472,21 @@ define <vscale x 8 x i16> @vsra_vv_nxv8i16(<vscale x 8 x i16> %va, <vscale x 8 x
   ret <vscale x 8 x i16> %vc
 }
 
+define <vscale x 8 x i16> @vsra_vv_nxv8i16_sext_zext(<vscale x 8 x i16> %va, <vscale x 8 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv8i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vmin.vx v10, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v10
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 8 x i16> %va to <vscale x 8 x i32>
+  %zexted_vb = zext <vscale x 8 x i16> %va to <vscale x 8 x i32>
+  %expand = ashr <vscale x 8 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 8 x i32> %expand to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %vc
+}
+
 define <vscale x 8 x i16> @vsra_vx_nxv8i16(<vscale x 8 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv8i16:
 ; CHECK:       # %bb.0:
@@ -386,6 +521,21 @@ define <vscale x 16 x i16> @vsra_vv_nxv16i16(<vscale x 16 x i16> %va, <vscale x
   ret <vscale x 16 x i16> %vc
 }
 
+define <vscale x 16 x i16> @vsra_vv_nxv16i16_sext_zext(<vscale x 16 x i16> %va, <vscale x 16 x i16> %vb) {
+; CHECK-LABEL: vsra_vv_nxv16i16_sext_zext:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 15
+; CHECK-NEXT:    vsetvli a1, zero, e16, m4, ta, ma
+; CHECK-NEXT:    vmin.vx v12, v8, a0
+; CHECK-NEXT:    vsra.vv v8, v8, v12
+; CHECK-NEXT:    ret
+  %sexted_va = sext <vscale x 16 x i16> %va to <vscale x 16 x i32>
+  %zexted_vb = zext <vscale x 16 x i16> %va to <vscale x 16 x i32>
+  %expand = ashr <vscale x 16 x i32> %sexted_va, %zexted_vb
+  %vc = trunc <vscale x 16 x i32> %expand to <vscale x 16 x i16>
+  ret <vscale x 16 x i16> %vc
+}
+
 define <vscale x 16 x i16> @vsra_vx_nxv16i16(<vscale x 16 x i16> %va, i16 signext %b) {
 ; CHECK-LABEL: vsra_vx_nxv16i16:
 ; CHECK:       # %bb.0:


        


More information about the llvm-commits mailing list