[llvm] [RISCV][WIP] Optimize sum of absolute differences pattern. (PR #82722)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 18:02:51 PST 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/82722

This rewrites (abs (sub (zext X), (zext Y))) to
(zext (sub (zext (umax X, Y)), (zext (umin X, Y)))).

This was taken from my downstream and is somewhat overfit to a particular benchmark.
It only handles fixed-length i32 vectors with i8 inputs, and it requires that the user of
the abs be an add or reduction add so the zext created at the end can also fold into a
widening instruction.

Posting for discussion.
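
For reference, the scalar identity the rewrite relies on: for unsigned X and Y,
abs(zext(X) - zext(Y)) == zext(umax(X, Y) - umin(X, Y)), and the narrow difference
fits in i16 when the inputs are i8. Below is a minimal standalone C++ sketch (not
part of the patch) that checks the identity exhaustively for i8 inputs:

  #include <algorithm>
  #include <cassert>
  #include <cstdint>
  #include <cstdlib>

  int main() {
    for (int X = 0; X < 256; ++X) {
      for (int Y = 0; Y < 256; ++Y) {
        // Original pattern: abs of the difference of the zero-extended values.
        int32_t Wide = std::abs(int32_t(uint8_t(X)) - int32_t(uint8_t(Y)));
        // Rewritten pattern: subtract the narrow umin from the narrow umax in
        // i16 (the widening sub), then zero-extend the result to i32.
        uint16_t Narrow = uint16_t(std::max<uint8_t>(X, Y)) -
                          uint16_t(std::min<uint8_t>(X, Y));
        assert(Wide == int32_t(uint32_t(Narrow)));
      }
    }
    return 0;
  }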

From 9c48f6ccc249d7186b9c789973ae7e8d5de1250b Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 22 Feb 2024 17:59:16 -0800
Subject: [PATCH] [RISCV][WIP] Optimize sum of absolute differences pattern.

This rewrites (abs (sub (zext X), (zext Y))) to
(zext (sub (zext (umax X, Y)), (zext (umin X, Y)))).

This was taken from my downstream and has some overfitting to a
particular benchmark.

Posting for discussion.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp |  58 ++++++++++
 llvm/test/CodeGen/RISCV/rvv/sad.ll          | 120 ++++++++++++++++++++
 2 files changed, 178 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/sad.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5c67aaf6785669..6bd62b79e5a74f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13176,6 +13176,61 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
 }
 
+// Look for (abs (sub (zext X), (zext Y))).
+// Rewrite as (zext (sub (zext (umax X, Y)), (zext (umin X, Y)))) if the user
+// is an add or reduction add. The umin/umax can be done in parallel and at a
+// lower LMUL than the original code. The two zexts can be folded into a
+// widening sub and a widening add or widening redsum.
+static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i32 ||
+      !TLI.isTypeLegal(VT))
+    return SDValue();
+
+  SDValue Src = N->getOperand(0);
+  if (Src.getOpcode() != ISD::SUB || !Src.hasOneUse())
+    return SDValue();
+
+  // Make sure the user is an add or reduce add so the zext we create at the end
+  // will be folded.
+  if (!N->hasOneUse() || (N->use_begin()->getOpcode() != ISD::ADD &&
+                          N->use_begin()->getOpcode() != ISD::VECREDUCE_ADD))
+    return SDValue();
+
+  // Inputs to the subtract should be zext.
+  SDValue Op0 = Src.getOperand(0);
+  SDValue Op1 = Src.getOperand(1);
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND || !Op0.hasOneUse() ||
+      Op1.getOpcode() != ISD::ZERO_EXTEND || !Op1.hasOneUse())
+    return SDValue();
+
+  Op0 = Op0.getOperand(0);
+  Op1 = Op1.getOperand(0);
+
+  // Inputs should be i8 vectors.
+  if (Op0.getValueType().getVectorElementType() != MVT::i8 ||
+      Op1.getValueType().getVectorElementType() != MVT::i8)
+    return SDValue();
+
+  SDLoc DL(N);
+
+  SDValue Max = DAG.getNode(ISD::UMAX, DL, Op0.getValueType(), Op0, Op1);
+  SDValue Min = DAG.getNode(ISD::UMIN, DL, Op0.getValueType(), Op0, Op1);
+
+  // The intermediate VT should be i16.
+  EVT IntermediateVT =
+      EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorElementCount());
+
+  Max = DAG.getNode(ISD::ZERO_EXTEND, DL, IntermediateVT, Max);
+  Min = DAG.getNode(ISD::ZERO_EXTEND, DL, IntermediateVT, Min);
+
+  SDValue Sub = DAG.getNode(ISD::SUB, DL, IntermediateVT, Max, Min);
+
+  return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Sub);
+}
+
 static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   if (!VT.isVector())
@@ -15698,6 +15753,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                        DAG.getConstant(~SignBit, DL, VT));
   }
   case ISD::ABS: {
+    if (SDValue V = performABSCombine(N, DAG))
+      return V;
+
     EVT VT = N->getValueType(0);
     SDValue N0 = N->getOperand(0);
     // abs (sext) -> zext (abs)
diff --git a/llvm/test/CodeGen/RISCV/rvv/sad.ll b/llvm/test/CodeGen/RISCV/rvv/sad.ll
new file mode 100644
index 00000000000000..ed25431c6f45cc
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/sad.ll
@@ -0,0 +1,120 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s
+
+define signext i32 @sad(ptr %a, ptr %b) {
+; CHECK-LABEL: sad:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    vminu.vv v10, v8, v9
+; CHECK-NEXT:    vmaxu.vv v8, v8, v9
+; CHECK-NEXT:    vwsubu.vv v9, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.s.x v8, zero
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vwredsumu.vs v8, v9, v8
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %0 = load <4 x i8>, ptr %a, align 1
+  %1 = zext <4 x i8> %0 to <4 x i32>
+  %2 = load <4 x i8>, ptr %b, align 1
+  %3 = zext <4 x i8> %2 to <4 x i32>
+  %4 = sub nsw <4 x i32> %1, %3
+  %5 = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> %4, i1 true)
+  %6 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %5)
+  ret i32 %6
+}
+
+define signext i32 @sad2(ptr %a, ptr %b, i32 signext %stridea, i32 signext %strideb) {
+; CHECK-LABEL: sad2:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a1)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v10, (a0)
+; CHECK-NEXT:    vle8.v v11, (a1)
+; CHECK-NEXT:    vminu.vv v12, v8, v9
+; CHECK-NEXT:    vmaxu.vv v8, v8, v9
+; CHECK-NEXT:    vwsubu.vv v14, v8, v12
+; CHECK-NEXT:    vminu.vv v8, v10, v11
+; CHECK-NEXT:    vmaxu.vv v9, v10, v11
+; CHECK-NEXT:    vwsubu.vv v12, v9, v8
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v16, (a0)
+; CHECK-NEXT:    vle8.v v17, (a1)
+; CHECK-NEXT:    vwaddu.vv v8, v12, v14
+; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
+; CHECK-NEXT:    vminu.vv v12, v16, v17
+; CHECK-NEXT:    vmaxu.vv v13, v16, v17
+; CHECK-NEXT:    vwsubu.vv v14, v13, v12
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a1, a1, a3
+; CHECK-NEXT:    vle8.v v12, (a0)
+; CHECK-NEXT:    vle8.v v13, (a1)
+; CHECK-NEXT:    vwaddu.wv v8, v8, v14
+; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
+; CHECK-NEXT:    vminu.vv v14, v12, v13
+; CHECK-NEXT:    vmaxu.vv v12, v12, v13
+; CHECK-NEXT:    vwsubu.vv v16, v12, v14
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vwaddu.wv v8, v8, v16
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vmv.s.x v12, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v12
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %idx.ext8 = sext i32 %strideb to i64
+  %idx.ext = sext i32 %stridea to i64
+  %0 = load <16 x i8>, ptr %a, align 1
+  %1 = zext <16 x i8> %0 to <16 x i32>
+  %2 = load <16 x i8>, ptr %b, align 1
+  %3 = zext <16 x i8> %2 to <16 x i32>
+  %4 = sub nsw <16 x i32> %1, %3
+  %5 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %4, i1 true)
+  %6 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
+  %add.ptr = getelementptr inbounds i8, ptr %a, i64 %idx.ext
+  %add.ptr9 = getelementptr inbounds i8, ptr %b, i64 %idx.ext8
+  %7 = load <16 x i8>, ptr %add.ptr, align 1
+  %8 = zext <16 x i8> %7 to <16 x i32>
+  %9 = load <16 x i8>, ptr %add.ptr9, align 1
+  %10 = zext <16 x i8> %9 to <16 x i32>
+  %11 = sub nsw <16 x i32> %8, %10
+  %12 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %11, i1 true)
+  %13 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %12)
+  %op.rdx.1 = add i32 %13, %6
+  %add.ptr.1 = getelementptr inbounds i8, ptr %add.ptr, i64 %idx.ext
+  %add.ptr9.1 = getelementptr inbounds i8, ptr %add.ptr9, i64 %idx.ext8
+  %14 = load <16 x i8>, ptr %add.ptr.1, align 1
+  %15 = zext <16 x i8> %14 to <16 x i32>
+  %16 = load <16 x i8>, ptr %add.ptr9.1, align 1
+  %17 = zext <16 x i8> %16 to <16 x i32>
+  %18 = sub nsw <16 x i32> %15, %17
+  %19 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %18, i1 true)
+  %20 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %19)
+  %op.rdx.2 = add i32 %20, %op.rdx.1
+  %add.ptr.2 = getelementptr inbounds i8, ptr %add.ptr.1, i64 %idx.ext
+  %add.ptr9.2 = getelementptr inbounds i8, ptr %add.ptr9.1, i64 %idx.ext8
+  %21 = load <16 x i8>, ptr %add.ptr.2, align 1
+  %22 = zext <16 x i8> %21 to <16 x i32>
+  %23 = load <16 x i8>, ptr %add.ptr9.2, align 1
+  %24 = zext <16 x i8> %23 to <16 x i32>
+  %25 = sub nsw <16 x i32> %22, %24
+  %26 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %25, i1 true)
+  %27 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %26)
+  %op.rdx.3 = add i32 %27, %op.rdx.2
+  ret i32 %op.rdx.3
+}
+
+declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1)
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
+declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1)
+declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
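
For context, a hypothetical scalar source for the first test above (my
reconstruction, not taken from the benchmark); the loop vectorizer turns this
into the zext/sub/abs/vector.reduce.add sequence that the combine matches:

  #include <cstdint>
  #include <cstdlib>

  int sad(const uint8_t *a, const uint8_t *b) {
    int sum = 0;
    // Vectorizes to <4 x i8> loads, zext to <4 x i32>, sub, abs, reduce.add.
    for (int i = 0; i < 4; ++i)
      sum += std::abs(int(a[i]) - int(b[i]));
    return sum;
  }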


