[llvm] [RISCV] Combine vXi32 (mul (and (lshr X, 15), 0x10001), 0xffff) -> (bitcast (sra (v2Xi16 (bitcast X)), 15)) (PR #93565)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 08:43:32 PDT 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/93565

The same combine applies for i16 and i64 elements, for both fixed and scalable vectors.

This reduces the number of vector instructions, but increases vl/vtype toggles.

This reduces code in 525.x264_r from SPEC2017. In that usage, the vectors are fixed-length with a small number of elements, so vsetivli can be used.
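
To see why the identity holds, consider a single 32-bit lane. lshr by 15 moves the sign bits of the two 16-bit halves to bit 0 and bit 16, the 0x10001 mask isolates exactly those two bits, and multiplying by 0xffff broadcasts each surviving bit across its half. That is precisely an arithmetic shift right by 15 applied to each 16-bit half independently. A minimal scalar C++ sketch of the equivalence (the function names and sampling loop are illustrative, not part of the patch):

  #include <cassert>
  #include <cstdint>
  #include <cstdio>

  // The original pattern, modeled on one 32-bit lane:
  //   (mul (and (lshr X, 15), 0x10001), 0xffff)
  static uint32_t mulPattern(uint32_t X) {
    return ((X >> 15) & 0x10001u) * 0xffffu;
  }

  // The replacement: view the lane as two i16 halves and arithmetic
  // shift each half right by 15, broadcasting its sign bit. (Right
  // shift of a negative value is well-defined arithmetic in C++20.)
  static uint32_t sraPattern(uint32_t X) {
    uint16_t Lo = static_cast<uint16_t>(X);
    uint16_t Hi = static_cast<uint16_t>(X >> 16);
    uint16_t LoSra = static_cast<uint16_t>(static_cast<int16_t>(Lo) >> 15);
    uint16_t HiSra = static_cast<uint16_t>(static_cast<int16_t>(Hi) >> 15);
    return (static_cast<uint32_t>(HiSra) << 16) | LoSra;
  }

  int main() {
    // Sparse sweep of the 32-bit space; an exhaustive sweep also passes.
    for (uint64_t X = 0; X <= 0xffffffffu; X += 0x101u)
      assert(mulPattern(static_cast<uint32_t>(X)) ==
             sraPattern(static_cast<uint32_t>(X)));
    std::puts("mul pattern == per-half sra on all sampled lanes");
  }

The same argument applies with 8-bit halves of i16 lanes (shift 7, mask 0x101, multiplier 0xff) and 32-bit halves of i64 lanes (shift 31, mask 0x100000001, multiplier 0xffffffff), which the combine in the patch handles generically via HalfSize.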

From 51183e936480b5a86002e49b91fd415eda35fa5c Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Fri, 24 May 2024 17:10:28 -0700
Subject: [PATCH] [RISCV] Combine vXi32 (mul (and (lshr X, 15), 0x10001),
 0xffff) -> (bitcast (sra (v2Xi16 (bitcast X)), 15))

The same combine applies for i16 and i64 elements, for both fixed and
scalable vectors.

This reduces the number of vector instructions, but increases
vl/vtype toggles.

This reduces code in 525.x264_r from SPEC2017. In that usage, the
vectors are fixed-length with a small number of elements, so vsetivli
can be used.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  41 +++++++
 llvm/test/CodeGen/RISCV/rvv/mul-combine.ll  | 117 ++++++++++++++++++++
 2 files changed, 158 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/mul-combine.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..0889e94585155 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13704,6 +13704,44 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Combine vXi32 (mul (and (lshr X, 15), 0x10001), 0xffff) ->
+// (bitcast (sra (v2Xi16 (bitcast X)), 15))
+// Same for other element types with the equivalent constants.
+static SDValue combineVectorMulToSraBitcast(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  // Do this for legal vectors unless they are i1 or i8 vectors.
+  if (!VT.isVector() || !TLI.isTypeLegal(VT) || VT.getScalarSizeInBits() < 16)
+    return SDValue();
+
+  if (N->getOperand(0).getOpcode() != ISD::AND ||
+      N->getOperand(0).getOperand(0).getOpcode() != ISD::SRL)
+    return SDValue();
+
+  SDValue And = N->getOperand(0);
+  SDValue Srl = And.getOperand(0);
+
+  APInt V1, V2, V3;
+  if (!ISD::isConstantSplatVector(N->getOperand(1).getNode(), V1) ||
+      !ISD::isConstantSplatVector(And.getOperand(1).getNode(), V2) ||
+      !ISD::isConstantSplatVector(Srl.getOperand(1).getNode(), V3))
+    return SDValue();
+
+  unsigned HalfSize = VT.getScalarSizeInBits() / 2;
+  if (!V1.isMask(HalfSize) || V2 != (1ULL | 1ULL << HalfSize) ||
+      V3 != (HalfSize - 1))
+    return SDValue();
+
+  EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
+                                EVT::getIntegerVT(*DAG.getContext(), HalfSize),
+                                VT.getVectorElementCount() * 2);
+  SDLoc DL(N);
+  SDValue Cast = DAG.getNode(ISD::BITCAST, DL, HalfVT, Srl.getOperand(0));
+  SDValue Sra = DAG.getNode(ISD::SRA, DL, HalfVT, Cast,
+                            DAG.getConstant(HalfSize - 1, DL, HalfVT));
+  return DAG.getNode(ISD::BITCAST, DL, VT, Sra);
+}
 
 static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
@@ -13748,6 +13786,9 @@ static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineBinOpOfZExt(N, DAG))
     return V;
 
+  if (SDValue V = combineVectorMulToSraBitcast(N, DAG))
+    return V;
+
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/mul-combine.ll b/llvm/test/CodeGen/RISCV/rvv/mul-combine.ll
new file mode 100644
index 0000000000000..6a7da925b4d43
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/mul-combine.ll
@@ -0,0 +1,117 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK-RV32
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK-RV64
+
+define <2 x i16> @test_v2i16(<2 x i16> %x) {
+; CHECK-RV32-LABEL: test_v2i16:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-RV32-NEXT:    vsra.vi v8, v8, 7
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_v2i16:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-RV64-NEXT:    vsra.vi v8, v8, 7
+; CHECK-RV64-NEXT:    ret
+  %1 = lshr <2 x i16> %x, <i16 7, i16 7>
+  %2 = and <2 x i16> %1, <i16 257, i16 257>
+  %3 = mul <2 x i16> %2, <i16 255, i16 255>
+  ret <2 x i16> %3
+}
+
+define <vscale x 2 x i16> @test_nxv2i16(<vscale x 2 x i16> %x) {
+; CHECK-RV32-LABEL: test_nxv2i16:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-RV32-NEXT:    vsrl.vi v8, v8, 7
+; CHECK-RV32-NEXT:    li a0, 257
+; CHECK-RV32-NEXT:    vand.vx v8, v8, a0
+; CHECK-RV32-NEXT:    vsll.vi v8, v8, 8
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_nxv2i16:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    vsetvli a0, zero, e16, mf2, ta, ma
+; CHECK-RV64-NEXT:    vsrl.vi v8, v8, 7
+; CHECK-RV64-NEXT:    li a0, 257
+; CHECK-RV64-NEXT:    vand.vx v8, v8, a0
+; CHECK-RV64-NEXT:    vsll.vi v8, v8, 8
+; CHECK-RV64-NEXT:    ret
+  %1 = lshr <vscale x 2 x i16> %x, splat (i16 7)
+  %2 = and <vscale x 2 x i16> %1, splat (i16 257)
+  %3 = mul <vscale x 2 x i16> %2, splat (i16 256)
+  ret <vscale x 2 x i16> %3
+}
+
+define <2 x i32> @test_v2i32(<2 x i32> %x) {
+; CHECK-RV32-LABEL: test_v2i32:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-RV32-NEXT:    vsra.vi v8, v8, 15
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_v2i32:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-RV64-NEXT:    vsra.vi v8, v8, 15
+; CHECK-RV64-NEXT:    ret
+  %1 = lshr <2 x i32> %x, <i32 15, i32 15>
+  %2 = and <2 x i32> %1, <i32 65537, i32 65537>
+  %3 = mul <2 x i32> %2, <i32 65535, i32 65535>
+  ret <2 x i32> %3
+}
+
+define <vscale x 2 x i32> @test_nxv2i32(<vscale x 2 x i32> %x) {
+; CHECK-RV32-LABEL: test_nxv2i32:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
+; CHECK-RV32-NEXT:    vsra.vi v8, v8, 15
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_nxv2i32:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
+; CHECK-RV64-NEXT:    vsra.vi v8, v8, 15
+; CHECK-RV64-NEXT:    ret
+  %1 = lshr <vscale x 2 x i32> %x, splat (i32 15)
+  %2 = and <vscale x 2 x i32> %1, splat (i32 65537)
+  %3 = mul <vscale x 2 x i32> %2, splat (i32 65535)
+  ret <vscale x 2 x i32> %3
+}
+
+define <2 x i64> @test_v2i64(<2 x i64> %x) {
+; CHECK-RV32-LABEL: test_v2i64:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-RV32-NEXT:    vsra.vi v8, v8, 31
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_v2i64:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-RV64-NEXT:    vsra.vi v8, v8, 31
+; CHECK-RV64-NEXT:    ret
+  %1 = lshr <2 x i64> %x, <i64 31, i64 31>
+  %2 = and <2 x i64> %1, <i64 4294967297, i64 4294967297>
+  %3 = mul <2 x i64> %2, <i64 4294967295, i64 4294967295>
+  ret <2 x i64> %3
+}
+
+define <vscale x 2 x i64> @test_nxv2i64(<vscale x 2 x i64> %x) {
+; CHECK-RV32-LABEL: test_nxv2i64:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; CHECK-RV32-NEXT:    vsra.vi v8, v8, 31
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: test_nxv2i64:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; CHECK-RV64-NEXT:    vsra.vi v8, v8, 31
+; CHECK-RV64-NEXT:    ret
+  %1 = lshr <vscale x 2 x i64> %x, splat (i64 31)
+  %2 = and <vscale x 2 x i64> %1, splat (i64 4294967297)
+  %3 = mul <vscale x 2 x i64> %2, splat (i64 4294967295)
+  ret <vscale x 2 x i64> %3
+}
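
The CHECK lines above were autogenerated, as the NOTE at the top of the test says. For reference, a sketch of regenerating them after editing the test, assuming a build directory named build/ (the paths are illustrative):

  llvm/utils/update_llc_test_checks.py --llc-binary=build/bin/llc \
      llvm/test/CodeGen/RISCV/rvv/mul-combine.ll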