[llvm] [AArch64] SimplifyDemandedBitsForTargetNode - add AArch64ISD::BICi handling (PR #76644)
Sizov Nikita via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 26 16:30:00 PDT 2024
https://github.com/snikitav updated https://github.com/llvm/llvm-project/pull/76644
>From 1e04454f2f666aeb527d52f9db78f12536a9eeb1 Mon Sep 17 00:00:00 2001
From: Sizov Nikita <s.nikita.v at gmail.com>
Date: Sun, 31 Dec 2023 04:33:21 +0300
Subject: [PATCH] Missing AArch64ISD::BICi handling
---
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 15 ++++--
.../Target/AArch64/AArch64ISelLowering.cpp | 31 ++++++++++++
.../AArch64/aarch64-known-bits-hadd.ll | 4 --
.../CodeGen/AArch64SelectionDAGTest.cpp | 49 +++++++++++++++++++
4 files changed, 90 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e8d1ac1d3a9167..e3b76b95eb86ad 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3419,13 +3419,18 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Known = KnownBits::mulhs(Known, Known2);
break;
}
- case ISD::AVGCEILU: {
+ case ISD::AVGFLOORU:
+ case ISD::AVGCEILU:
+ case ISD::AVGFLOORS:
+ case ISD::AVGCEILS: {
+ bool IsCeil = Opcode == ISD::AVGCEILU || Opcode == ISD::AVGCEILS;
+ bool IsSigned = Opcode == ISD::AVGFLOORS || Opcode == ISD::AVGCEILS;
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
- Known = Known.zext(BitWidth + 1);
- Known2 = Known2.zext(BitWidth + 1);
- KnownBits One = KnownBits::makeConstant(APInt(1, 1));
- Known = KnownBits::computeForAddCarry(Known, Known2, One);
+ Known = IsSigned ? Known.sext(BitWidth + 1) : Known.zext(BitWidth + 1);
+ Known2 = IsSigned ? Known2.sext(BitWidth + 1) : Known2.zext(BitWidth + 1);
+ KnownBits Carry = KnownBits::makeConstant(APInt(1, IsCeil ? 1 : 0));
+ Known = KnownBits::computeForAddCarry(Known, Known2, Carry);
Known = Known.extractBits(BitWidth, 1);
break;
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f552f91929201c..6fb3ad7a8d2e02 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24580,6 +24580,19 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
if (auto R = foldOverflowCheck(N, DAG, /* IsAdd */ false))
return R;
return performFlagSettingCombine(N, DCI, AArch64ISD::SBC);
+ case AArch64ISD::BICi: {
+ KnownBits Known;
+ APInt DemandedBits =
+ APInt::getAllOnes(N->getValueType(0).getScalarSizeInBits());
+ APInt DemandedElts =
+ APInt::getAllOnes(N->getValueType(0).getVectorNumElements());
+ TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
+ !DCI.isBeforeLegalizeOps());
+ if (DAG.getTargetLoweringInfo().SimplifyDemandedBits(
+ SDValue(N, 0), DemandedBits, DemandedElts, Known, TLO))
+ return TLO.New;
+ break;
+ }
case ISD::XOR:
return performXorCombine(N, DAG, DCI, Subtarget);
case ISD::MUL:
@@ -27620,6 +27633,24 @@ bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
// used - simplify to just Val.
return TLO.CombineTo(Op, ShiftR->getOperand(0));
}
+ case AArch64ISD::BICi: {
+ // Fold BICi if all destination bits already known to be zeroed
+ SDValue Op0 = Op.getOperand(0);
+ KnownBits KnownOp0 =
+ TLO.DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth + 1);
+ // Op0 &= ~(ConstantOperandVal(1) << ConstantOperandVal(2))
+ uint64_t BitsToClear = Op->getConstantOperandVal(1)
+ << Op->getConstantOperandVal(2);
+ APInt AlreadyZeroedBitsToClear = BitsToClear & KnownOp0.Zero;
+ if (APInt(Known.getBitWidth(), BitsToClear)
+ .isSubsetOf(AlreadyZeroedBitsToClear))
+ return TLO.CombineTo(Op, Op0);
+
+ Known = KnownOp0 &
+ KnownBits::makeConstant(APInt(Known.getBitWidth(), ~BitsToClear));
+
+ return false;
+ }
case ISD::INTRINSIC_WO_CHAIN: {
if (auto ElementSize = IsSVECntIntrinsic(Op)) {
unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
diff --git a/llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll b/llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll
index 017f382774892c..f36b8440fe4bfb 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll
@@ -12,7 +12,6 @@ define <8 x i16> @haddu_zext(<8 x i8> %a0, <8 x i8> %a1) {
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: uhadd v0.8h, v0.8h, v1.8h
-; CHECK-NEXT: bic v0.8h, #254, lsl #8
; CHECK-NEXT: ret
%x0 = zext <8 x i8> %a0 to <8 x i16>
%x1 = zext <8 x i8> %a1 to <8 x i16>
@@ -27,7 +26,6 @@ define <8 x i16> @rhaddu_zext(<8 x i8> %a0, <8 x i8> %a1) {
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: urhadd v0.8h, v0.8h, v1.8h
-; CHECK-NEXT: bic v0.8h, #254, lsl #8
; CHECK-NEXT: ret
%x0 = zext <8 x i8> %a0 to <8 x i16>
%x1 = zext <8 x i8> %a1 to <8 x i16>
@@ -42,7 +40,6 @@ define <8 x i16> @hadds_zext(<8 x i8> %a0, <8 x i8> %a1) {
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: shadd v0.8h, v0.8h, v1.8h
-; CHECK-NEXT: bic v0.8h, #254, lsl #8
; CHECK-NEXT: ret
%x0 = zext <8 x i8> %a0 to <8 x i16>
%x1 = zext <8 x i8> %a1 to <8 x i16>
@@ -57,7 +54,6 @@ define <8 x i16> @shaddu_zext(<8 x i8> %a0, <8 x i8> %a1) {
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: srhadd v0.8h, v0.8h, v1.8h
-; CHECK-NEXT: bic v0.8h, #254, lsl #8
; CHECK-NEXT: ret
%x0 = zext <8 x i8> %a0 to <8 x i16>
%x1 = zext <8 x i8> %a1 to <8 x i16>
diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
index e0772684e3a954..dd1b5e9c9b5a3a 100644
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -9,6 +9,7 @@
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"
@@ -796,4 +797,52 @@ TEST_F(AArch64SelectionDAGTest, computeKnownBits_extload_knownnegative) {
EXPECT_EQ(Known.One, APInt(32, 0xfffffff0));
}
+TEST_F(AArch64SelectionDAGTest,
+ computeKnownBits_AVGFLOORU_AVGFLOORS_AVGCEILU_AVGCEILS) {
+ SDLoc Loc;
+ auto Int8VT = EVT::getIntegerVT(Context, 8);
+ auto Int16VT = EVT::getIntegerVT(Context, 16);
+ auto Int8Vec8VT = EVT::getVectorVT(Context, Int8VT, 8);
+ auto Int16Vec8VT = EVT::getVectorVT(Context, Int16VT, 8);
+
+ SDValue UnknownOp0 = DAG->getRegister(0, Int8Vec8VT);
+ SDValue UnknownOp1 = DAG->getRegister(1, Int8Vec8VT);
+
+ SDValue ZextOp0 =
+ DAG->getNode(ISD::ZERO_EXTEND, Loc, Int16Vec8VT, UnknownOp0);
+ SDValue ZextOp1 =
+ DAG->getNode(ISD::ZERO_EXTEND, Loc, Int16Vec8VT, UnknownOp1);
+ // ZextOp0 = 00000000????????
+ // ZextOp1 = 00000000????????
+ // => (all AVG* variants; operands are non-negative, so signed == unsigned)
+ // Known.Zero = 1111111100000000 (0xFF00)
+ // Known.One = 0000000000000000 (0x0000)
+ auto Zeroes = APInt(16, 0xFF00);
+ auto Ones = APInt(16, 0x0000);
+
+ SDValue AVGFLOORU =
+ DAG->getNode(ISD::AVGFLOORU, Loc, Int16Vec8VT, ZextOp0, ZextOp1);
+ KnownBits KnownAVGFLOORU = DAG->computeKnownBits(AVGFLOORU);
+ EXPECT_EQ(KnownAVGFLOORU.Zero, Zeroes);
+ EXPECT_EQ(KnownAVGFLOORU.One, Ones);
+
+ SDValue AVGFLOORS =
+ DAG->getNode(ISD::AVGFLOORS, Loc, Int16Vec8VT, ZextOp0, ZextOp1);
+ KnownBits KnownAVGFLOORS = DAG->computeKnownBits(AVGFLOORS);
+ EXPECT_EQ(KnownAVGFLOORS.Zero, Zeroes);
+ EXPECT_EQ(KnownAVGFLOORS.One, Ones);
+
+ SDValue AVGCEILU =
+ DAG->getNode(ISD::AVGCEILU, Loc, Int16Vec8VT, ZextOp0, ZextOp1);
+ KnownBits KnownAVGCEILU = DAG->computeKnownBits(AVGCEILU);
+ EXPECT_EQ(KnownAVGCEILU.Zero, Zeroes);
+ EXPECT_EQ(KnownAVGCEILU.One, Ones);
+
+ SDValue AVGCEILS =
+ DAG->getNode(ISD::AVGCEILS, Loc, Int16Vec8VT, ZextOp0, ZextOp1);
+ KnownBits KnownAVGCEILS = DAG->computeKnownBits(AVGCEILS);
+ EXPECT_EQ(KnownAVGCEILS.Zero, Zeroes);
+ EXPECT_EQ(KnownAVGCEILS.One, Ones);
+}
+
} // end namespace llvm
More information about the llvm-commits
mailing list