[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 17 09:42:59 PDT 2025
davemgreen wrote:
Hi - sorry for the delay. My plan was to wait until after the release and then get this committed and see how it does. I'm still not 100% sure it won't accidentally mishandle a node that happens to have a constant operand.
I was running some extra tests, though, and ran into this case. Here the constant gets lowered to a movi, I think, but the operand is an i32 (due to type legalization) while the fp type is fp16. It just needs to use `FVal = APFloat(APFloat::IEEEhalf(), Imm.trunc(16));` for fp16 values.
```
define <4 x half> @test_v4f16_div_const_0xH1c04(<4 x i16> %in) {
; CHECK-LABEL: test_v4f16_div_const_0xH1c04:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov w8, #7172 // =0x1c04
; CHECK-NEXT: scvtf.4h v0, v0
; CHECK-NEXT: dup.4h v1, w8
; CHECK-NEXT: fmul.4h v0, v0, v1
; CHECK-NEXT: ret
entry:
%vcvt.i = sitofp <4 x i16> %in to <4 x half>
%div.i = fmul <4 x half> %vcvt.i, <half 0xH1c04, half 0xH1c04, half 0xH1c04, half 0xH1c04>
ret <4 x half> %div.i
}
```
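To make the width problem concrete, here is a minimal standalone C++ sketch (no LLVM dependency) of the check this test exercises. It assumes the rewrite should only fire when the splatted constant is exactly 2^-n with 1 <= n <= 16 (the SCVTF fraction-bit range for fp16 elements), and it stands in for `APFloat` with a hand-written binary16 decoder. `decodeHalfBits`, `reciprocalPow2FBits` and `fbitsForLegalizedHalfSplat` are hypothetical names, not functions from the patch.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <optional>

// Decode an IEEE-754 binary16 bit pattern (sign:1, exponent:5, mantissa:10).
double decodeHalfBits(uint16_t bits) {
  const int sign = (bits >> 15) & 0x1;
  const int exp = (bits >> 10) & 0x1f;
  const int mant = bits & 0x3ff;
  double value;
  if (exp == 0)
    value = std::ldexp(static_cast<double>(mant), -24); // subnormal: mant * 2^-24
  else if (exp == 0x1f)
    value = mant ? std::numeric_limits<double>::quiet_NaN()
                 : std::numeric_limits<double>::infinity();
  else // normal: (1 + mant/2^10) * 2^(exp-15) == (1024 + mant) * 2^(exp-25)
    value = std::ldexp(static_cast<double>(1024 + mant), exp - 25);
  return sign ? -value : value;
}

// Return n if `value` is exactly 2^-n with 1 <= n <= maxFBits, i.e. if
// fmul(sitofp(x), value) could be expressed as scvtf with n fraction bits.
std::optional<unsigned> reciprocalPow2FBits(double value, unsigned maxFBits) {
  assert(maxFBits > 0 && "scvtf needs at least one fraction bit");
  if (!std::isfinite(value) || value <= 0.0)
    return std::nullopt;
  int exp = 0;
  const double frac = std::frexp(value, &exp); // value == frac * 2^exp, frac in [0.5, 1)
  if (frac != 0.5)
    return std::nullopt; // not a power of two
  const int n = 1 - exp; // 0.5 * 2^exp == 2^-(1 - exp)
  if (n < 1 || n > static_cast<int>(maxFBits))
    return std::nullopt;
  return static_cast<unsigned>(n);
}

// After type legalization the fp16 splat operand is an i32, so only its low
// 16 bits describe the half value (the role of Imm.trunc(16) above).
std::optional<unsigned> fbitsForLegalizedHalfSplat(uint32_t legalizedImm) {
  const uint16_t halfBits = static_cast<uint16_t>(legalizedImm); // truncate to 16 bits
  return reciprocalPow2FBits(decodeHalfBits(halfBits), 16);
}
```

With the legalized operand 0x00001c04 the decoded value is 1028 * 2^-18, which is not a power of two, so no fixed-point form exists and the `fmul` in the CHECK lines stays. 0x3800 (0.5) yields one fraction bit, and any upper bits of the i32 are ignored.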
I would also recommend restructuring the start of the function along the lines of the code below, so it is clear how the immediate is derived for each opcode and which opcode leads to which path.
```
static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
                                                        SDValue N,
                                                        SDValue &FixedPos,
                                                        unsigned FloatWidth,
                                                        bool IsReciprocal) {
  SDValue ImmediateNode = N;
  if (N.getOpcode() == ISD::BITCAST || N.getOpcode() == AArch64ISD::NVCAST) {
    ImmediateNode = N.getOperand(0);
    // This could have been a bitcast to a scalar.
    if (!ImmediateNode.getValueType().isVector())
      return false;
  }
  APInt Imm;
  if (ImmediateNode.getOpcode() == ISD::BUILD_VECTOR) {
    // For BUILD_VECTOR, we must explicitly check if it's a constant splat.
    BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode());
    APInt SplatUndef;
    unsigned SplatBitSize;
    bool HasAnyUndefs;
    if (!BVN->isConstantSplat(Imm, SplatUndef, SplatBitSize, HasAnyUndefs) ||
        SplatBitSize != N.getValueType().getScalarSizeInBits())
      return false;
  } else if (ImmediateNode.getOpcode() == AArch64ISD::MOVIshift) {
    EVT NodeVT = N.getValueType();
    Imm = APInt(NodeVT.getScalarSizeInBits(),
                ImmediateNode.getConstantOperandVal(0)
                    << ImmediateNode.getConstantOperandVal(1));
  } else if (ImmediateNode.getOpcode() == AArch64ISD::FMOV) {
    uint8_t EncodedU8 = ImmediateNode.getConstantOperandVal(0);
    uint64_t DecodedBits = AArch64_AM::decodeAdvSIMDModImmType11(EncodedU8);
    unsigned BitWidth = N.getValueType().getScalarSizeInBits();
    uint64_t Mask = (BitWidth == 64) ? ~0ULL : ((1ULL << BitWidth) - 1);
    uint64_t MaskedBits = DecodedBits & Mask;
    Imm = APInt(BitWidth, MaskedBits);
  } else if (ImmediateNode.getOpcode() == AArch64ISD::DUP ||
             ImmediateNode.getOpcode() == ISD::SPLAT_VECTOR) {
    // After type legalization the scalar operand can be wider than the
    // element (e.g. an i32 for an fp16 splat), so Imm may need truncating.
    auto *CI = dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(0));
    if (!CI)
      return false;
    Imm = CI->getAPIntValue();
  } else {
    return false;
  }
```
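The MOVIshift and FMOV branches above reconstruct the element bits from an encoded immediate. The following standalone C++ sketch mirrors that arithmetic for 32-bit elements so the results can be checked by hand. `expandFmovImm8` is a hypothetical re-implementation of the expansion that `AArch64_AM::decodeAdvSIMDModImmType11` performs, assuming it follows the Arm ARM `VFPExpandImm` rule for single precision and replicates the 32-bit pattern into both halves of a 64-bit value; `lowBits`, `moviShiftElement` and `f32FromBits` are likewise illustrative helpers, not LLVM APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Keep the low `width` bits of `bits` (the role of Mask in the FMOV branch).
uint64_t lowBits(uint64_t bits, unsigned width) {
  assert(width >= 1 && width <= 64);
  return width == 64 ? bits : bits & ((1ULL << width) - 1);
}

// MOVIshift: the 8-bit payload shifted left, truncated to the element width.
uint64_t moviShiftElement(uint8_t imm8, unsigned shift, unsigned elementBits) {
  return lowBits(static_cast<uint64_t>(imm8) << shift, elementBits);
}

// FMOV (assumed Type11 behaviour): expand imm8 = abcdefgh into the f32
// pattern a:NOT(b):bbbbb:cd:efgh:0{19}, replicated into both 32-bit halves.
uint64_t expandFmovImm8(uint8_t imm8) {
  uint64_t enc = 0;
  if (imm8 & 0x80)
    enc |= 0x80000000ULL;                               // sign bit a
  enc |= (imm8 & 0x40) ? 0x3e000000ULL : 0x40000000ULL; // NOT(b):bbbbb
  enc |= static_cast<uint64_t>(imm8 & 0x7f) << 19;      // b:cd:efgh at bits 25..19
  return (enc << 32) | enc;
}

// Reinterpret the low 32 bits as an IEEE single-precision value.
float f32FromBits(uint64_t bits) {
  const uint32_t narrow = static_cast<uint32_t>(bits);
  float f;
  std::memcpy(&f, &narrow, sizeof f);
  return f;
}
```

For example, imm8 0x60 expands to 0x3f000000 (0.5f), and `MOVI #0x3f, LSL #24` produces the same 32-bit element, so both branches should hand the same `Imm` to the rest of the check.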
https://github.com/llvm/llvm-project/pull/141480