[llvm] Missing AArch64ISD::BIC & AArch64ISD::BICi handling (PR #76644)

Sizov Nikita via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 30 18:12:02 PST 2023


https://github.com/snikitav created https://github.com/llvm/llvm-project/pull/76644

Fold AArch64ISD::BICi away when every bit it would clear is already known to be zero. To expose that information, computeKnownBits also learns to handle ISD::AVGFLOORU, ISD::AVGFLOORS and ISD::AVGCEILS alongside the existing ISD::AVGCEILU case.

```llvm
define <8 x i16> @haddu_known(<8 x i8> %a0, <8 x i8> %a1) {
  %x0 = zext <8 x i8> %a0 to <8 x i16>
  %x1 = zext <8 x i8> %a1 to <8 x i16>
  %hadd = call <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511>
  ret <8 x i16> %res
}
declare <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16>, <8 x i16>)
```

```
haddu_known:                            // @haddu_known
        ushll   v0.8h, v0.8b, #0
        ushll   v1.8h, v1.8b, #0
        uhadd   v0.8h, v0.8h, v1.8h
        bic     v0.8h, #254, lsl #8 <-- removed by this patch: the inputs are zero-extended, so the bits being cleared are already known to be zero
        ret
```
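The fold condition itself is small. Here is a minimal sketch of it (not the patch code itself, which lives in SimplifyDemandedBitsForTargetNode below): plain APInt values stand in for the real KnownBits query, and the helper name `bicImmIsRedundant` is made up for illustration.

```cpp
#include "llvm/ADT/APInt.h"
using namespace llvm;

// BICi clears (Imm << Shift) in every vector lane. If every bit in that mask
// is already known to be zero in the first operand, the instruction is a
// no-op and can be replaced by its input.
static bool bicImmIsRedundant(const APInt &KnownZero, const APInt &Imm,
                              unsigned Shift) {
  APInt BitsToClear = Imm.shl(Shift).zextOrTrunc(KnownZero.getBitWidth());
  return (BitsToClear & ~KnownZero).isZero(); // nothing left for BIC to clear
}
```

In the example above both uhadd operands are zero-extended from i8, so bits 15..8 of every i16 lane are known to be zero; the bic clears 0xFE00 (bits 15..9), a subset of those bits, so the instruction can be dropped.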

From 08f70b59fa396d6f0fac3ebe9c9bf71b11723873 Mon Sep 17 00:00:00 2001
From: Sizov Nikita <s.nikita.v at gmail.com>
Date: Sun, 31 Dec 2023 04:33:21 +0300
Subject: [PATCH] SimplifyDemandedBitsForTargetNode - Missing AArch64ISD::BIC &
 AArch64ISD::BICi handling

---
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |   9 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    |  24 ++++
 .../test/CodeGen/AArch64/aarch64-neon-bici.ll | 125 ++++++++++++++++++
 3 files changed, 156 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/aarch64-neon-bici.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 81facf92e55ae9..c79d3d8246daaf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3390,10 +3390,15 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
       Known = KnownBits::mulhs(Known, Known2);
     break;
   }
-  case ISD::AVGCEILU: {
+  case ISD::AVGFLOORU:
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORS:
+  case ISD::AVGCEILS: {
     Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
     Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
-    Known = Known.zext(BitWidth + 1);
+    Known = (Opcode == ISD::AVGFLOORU || Opcode == ISD::AVGCEILU)
+                ? Known.zext(BitWidth + 1)
+                : Known.sext(BitWidth + 1);
     Known2 = Known2.zext(BitWidth + 1);
     KnownBits One = KnownBits::makeConstant(APInt(1, 1));
     Known = KnownBits::computeForAddCarry(Known, Known2, One);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dffe69bdb900db..c5f906aa7473f7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -23672,6 +23672,18 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     if (auto R = foldOverflowCheck(N, DAG, /* IsAdd */ false))
       return R;
     return performFlagSettingCombine(N, DCI, AArch64ISD::SBC);
+  case AArch64ISD::BICi: {
+    KnownBits Known;
+    EVT VT = N->getValueType(0);
+    APInt DemandedBits = APInt::getAllOnes(VT.getScalarSizeInBits());
+    APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
+    TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
+                                          !DCI.isBeforeLegalizeOps());
+    if (SimplifyDemandedBitsForTargetNode(SDValue(N, 0), DemandedBits,
+                                          DemandedElts, Known, TLO, 0))
+      return TLO.New;
+    break;
+  }
   case ISD::XOR:
     return performXorCombine(N, DAG, DCI, Subtarget);
   case ISD::MUL:
@@ -26658,6 +26670,18 @@ bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
     // used - simplify to just Val.
     return TLO.CombineTo(Op, ShiftR->getOperand(0));
   }
+  case AArch64ISD::BICi: {
+    // Fold BICi if all of the bits it would clear are already known to be zero.
+    SDValue Op0 = Op.getOperand(0);
+    KnownBits KnownOp0 = TLO.DAG.computeKnownBits(Op0, 0);
+    APInt Shift = Op.getConstantOperandAPInt(2);
+    APInt Op1Val = Op.getConstantOperandAPInt(1);
+    APInt BitsToClear = Op1Val.shl(Shift).zextOrTrunc(KnownOp0.getBitWidth());
+    APInt AlreadyZeroedBitsToClear = BitsToClear & KnownOp0.Zero;
+    if (AlreadyZeroedBitsToClear == BitsToClear)
+      return TLO.CombineTo(Op, Op0);
+    return false;
+  }
   case ISD::INTRINSIC_WO_CHAIN: {
     if (auto ElementSize = IsSVECntIntrinsic(Op)) {
       unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
diff --git a/llvm/test/CodeGen/AArch64/aarch64-neon-bici.ll b/llvm/test/CodeGen/AArch64/aarch64-neon-bici.ll
new file mode 100644
index 00000000000000..55ac9c2f1b4075
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-neon-bici.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-neon < %s | FileCheck %s
+
+declare <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16>, <8 x i16>)
+declare <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16>, <8 x i16>)
+declare <8 x i16> @llvm.aarch64.neon.shadd.v8i16(<8 x i16>, <8 x i16>)
+declare <8 x i16> @llvm.aarch64.neon.srhadd.v8i16(<8 x i16>, <8 x i16>)
+
+define <8 x i16> @haddu_zext(<8 x i8> %a0, <8 x i8> %a1) {
+; CHECK-LABEL: haddu_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    uhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %x0 = zext <8 x i8> %a0 to <8 x i16>
+  %x1 = zext <8 x i8> %a1 to <8 x i16>
+  %hadd = call <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
+  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511,i16 511, i16 511, i16 511, i16 511>
+  ret <8 x i16> %res
+}
+
+define <8 x i16> @rhaddu_zext(<8 x i8> %a0, <8 x i8> %a1) {
+; CHECK-LABEL: rhaddu_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    urhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %x0 = zext <8 x i8> %a0 to <8 x i16>
+  %x1 = zext <8 x i8> %a1 to <8 x i16>
+  %hadd = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
+  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511>
+  ret <8 x i16> %res
+}
+
+define <8 x i16> @hadds_zext(<8 x i8> %a0, <8 x i8> %a1) {
+; CHECK-LABEL: hadds_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    shadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %x0 = zext <8 x i8> %a0 to <8 x i16>
+  %x1 = zext <8 x i8> %a1 to <8 x i16>
+  %hadd = call <8 x i16> @llvm.aarch64.neon.shadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
+  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511>
+  ret <8 x i16> %res
+}
+
+define <8 x i16> @shaddu_zext(<8 x i8> %a0, <8 x i8> %a1) {
+; CHECK-LABEL: shaddu_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    srhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %x0 = zext <8 x i8> %a0 to <8 x i16>
+  %x1 = zext <8 x i8> %a1 to <8 x i16>
+  %hadd = call <8 x i16> @llvm.aarch64.neon.srhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
+  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511>
+  ret <8 x i16> %res
+}
+
+; negative tests
+
+define <8 x i16> @haddu_sext(<8 x i8> %a0, <8 x i8> %a1) {
+; CHECK-LABEL: haddu_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEXT:    uhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    bic v0.8h, #254, lsl #8
+; CHECK-NEXT:    ret
+  %x0 = sext <8 x i8> %a0 to <8 x i16>
+  %x1 = sext <8 x i8> %a1 to <8 x i16>
+  %hadd = call <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
+  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511,i16 511, i16 511, i16 511, i16 511>
+  ret <8 x i16> %res
+}
+
+define <8 x i16> @urhadd_sext(<8 x i8> %a0, <8 x i8> %a1) {
+; CHECK-LABEL: urhadd_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEXT:    urhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    bic v0.8h, #254, lsl #8
+; CHECK-NEXT:    ret
+  %x0 = sext <8 x i8> %a0 to <8 x i16>
+  %x1 = sext <8 x i8> %a1 to <8 x i16>
+  %hadd = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
+  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511,i16 511, i16 511, i16 511, i16 511>
+  ret <8 x i16> %res
+}
+
+define <8 x i16> @hadds_sext(<8 x i8> %a0, <8 x i8> %a1) {
+; CHECK-LABEL: hadds_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEXT:    shadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    bic v0.8h, #254, lsl #8
+; CHECK-NEXT:    ret
+  %x0 = sext <8 x i8> %a0 to <8 x i16>
+  %x1 = sext <8 x i8> %a1 to <8 x i16>
+  %hadd = call <8 x i16> @llvm.aarch64.neon.shadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
+  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511>
+  ret <8 x i16> %res
+}
+
+define <8 x i16> @shaddu_sext(<8 x i8> %a0, <8 x i8> %a1) {
+; CHECK-LABEL: shaddu_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEXT:    srhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    bic v0.8h, #254, lsl #8
+; CHECK-NEXT:    ret
+  %x0 = sext <8 x i8> %a0 to <8 x i16>
+  %x1 = sext <8 x i8> %a1 to <8 x i16>
+  %hadd = call <8 x i16> @llvm.aarch64.neon.srhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
+  %res = and <8 x i16> %hadd, <i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511, i16 511>
+  ret <8 x i16> %res
+}


