[llvm] [AArch64] Refactor creation of a shuffle mask for TBL (NFC) (PR #92529)
Momchil Velikov via llvm-commits
llvm-commits at lists.llvm.org
Tue May 21 08:06:40 PDT 2024
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/92529
From 27e9e12db3626f681433a423f434adcee5323f5b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 21 May 2024 15:07:17 +0100
Subject: [PATCH 1/2] [AArch64] Add patterns for conversions using fixed-point
scvtf
Change-Id: If19131b160484aba942dbbef042fb67f0b98561d
---
.../Target/AArch64/AArch64ISelLowering.cpp | 2 +-
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 25 +++++
.../AArch64/fixed-point-conv-vec-pat.ll | 103 ++++++++++++++++++
3 files changed, 129 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/AArch64/fixed-point-conv-vec-pat.ll
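
For context: the patterns added below fold "sitofp (ashr exact X, C)" into a
single fixed-point scvtf with C fractional bits. The following is a minimal
standalone C++ sketch of the underlying arithmetic (an illustration, not code
from this patch): when the shift is exact, converting the shifted integer and
scaling the unshifted value by 2^-C produce the same float.

  // Illustration only; not part of the patch.
  #include <cassert>
  #include <cmath>
  #include <cstdint>

  int main() {
    // Low 24 bits are zero, so x >> 24 drops no set bits (the "exact" case).
    int32_t x = 123 << 24;
    int s = 24;
    // What sitofp(ashr exact x, 24) computes.
    float via_shift = static_cast<float>(x >> s);
    // What scvtf with #24 fractional bits computes: x scaled by 2^-24.
    float via_scvtf = static_cast<float>(x) * std::ldexp(1.0f, -s);
    assert(via_shift == via_scvtf);
    return 0;
  }
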
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e31a27e9428e8..dd422214d45f4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14328,7 +14328,7 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
unsigned Opc =
(Op.getOpcode() == ISD::SRA) ? AArch64ISD::VASHR : AArch64ISD::VLSHR;
return DAG.getNode(Opc, DL, VT, Op.getOperand(0),
- DAG.getConstant(Cnt, DL, MVT::i32));
+ DAG.getConstant(Cnt, DL, MVT::i32), Op->getFlags());
}
// Right shift register. Note, there is not a shift right register
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index a39e3b7be76dc..291f553776752 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -735,6 +735,12 @@ def AArch64rev64 : SDNode<"AArch64ISD::REV64", SDT_AArch64UnaryVec>;
def AArch64ext : SDNode<"AArch64ISD::EXT", SDT_AArch64ExtVec>;
def AArch64vashr : SDNode<"AArch64ISD::VASHR", SDT_AArch64vshift>;
+
+def AArch64vashr_exact : PatFrag<(ops node:$lhs, node:$rhs),
+ (AArch64vashr node:$lhs, node:$rhs), [{
+ return N->getFlags().hasExact();
+}]>;
+
def AArch64vlshr : SDNode<"AArch64ISD::VLSHR", SDT_AArch64vshift>;
def AArch64vshl : SDNode<"AArch64ISD::VSHL", SDT_AArch64vshift>;
def AArch64sqshli : SDNode<"AArch64ISD::SQSHL_I", SDT_AArch64vshift>;
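
The AArch64vashr_exact PatFrag above matches only VASHR nodes carrying the
IR-level "exact" flag, which the AArch64ISelLowering.cpp hunk now propagates.
A short standalone C++ demo (an illustration, not patch code) of why the flag
is required: when the shift discards set bits, the fixed-point conversion of
the unshifted value disagrees with converting the shifted integer, so the fold
must not fire.

  // Illustration only; not part of the patch.
  #include <cassert>
  #include <cstdint>

  int main() {
    int32_t x = 3;                                 // bit 0 set: shift not exact
    float via_shift = static_cast<float>(x >> 1);  // sitofp(ashr x, 1) -> 1.0f
    float via_scvtf = static_cast<float>(x) / 2;   // scvtf(x, #1)      -> 1.5f
    assert(via_shift != via_scvtf);                // folding here would be wrong
    return 0;
  }
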
@@ -7712,6 +7718,25 @@ defm SCVTF: SIMDVectorRShiftToFP<0, 0b11100, "scvtf",
defm RSHRN : SIMDVectorRShiftNarrowBHS<0, 0b10001, "rshrn", AArch64rshrn>;
defm SHL : SIMDVectorLShiftBHSD<0, 0b01010, "shl", AArch64vshl>;
+let Predicates = [HasNEON] in {
+def : Pat<(v2f32 (sint_to_fp (v2i32 (AArch64vashr_exact v2i32:$Vn, i32:$shift)))),
+ (SCVTFv2i32_shift $Vn, vecshiftR32:$shift)>;
+
+def : Pat<(v4f32 (sint_to_fp (v4i32 (AArch64vashr_exact v4i32:$Vn, i32:$shift)))),
+ (SCVTFv4i32_shift $Vn, vecshiftR32:$shift)>;
+
+def : Pat<(v2f64 (sint_to_fp (v2i64 (AArch64vashr_exact v2i64:$Vn, i32:$shift)))),
+ (SCVTFv2i64_shift $Vn, vecshiftR64:$shift)>;
+}
+
+let Predicates = [HasNEON, HasFullFP16] in {
+def : Pat<(v4f16 (sint_to_fp (v4i16 (AArch64vashr_exact v4i16:$Vn, i32:$shift)))),
+ (SCVTFv4i16_shift $Vn, vecshiftR16:$shift)>;
+
+def : Pat<(v8f16 (sint_to_fp (v8i16 (AArch64vashr_exact v8i16:$Vn, i32:$shift)))),
+ (SCVTFv8i16_shift $Vn, vecshiftR16:$shift)>;
+}
+
// X << 1 ==> X + X
class SHLToADDPat<ValueType ty, RegisterClass regtype>
: Pat<(ty (AArch64vshl (ty regtype:$Rn), (i32 1))),
diff --git a/llvm/test/CodeGen/AArch64/fixed-point-conv-vec-pat.ll b/llvm/test/CodeGen/AArch64/fixed-point-conv-vec-pat.ll
new file mode 100644
index 0000000000000..7141b5b03a1ac
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fixed-point-conv-vec-pat.ll
@@ -0,0 +1,103 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+; First some corner cases
+define <4 x float> @f_v4_s0(<4 x i32> %u) {
+; CHECK-LABEL: f_v4_s0:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf v0.4s, v0.4s
+; CHECK-NEXT: ret
+ %s = ashr exact <4 x i32> %u, <i32 0, i32 0, i32 0, i32 0>
+ %v = sitofp <4 x i32> %s to <4 x float>
+ ret <4 x float> %v
+}
+
+define <4 x float> @f_v4_s1(<4 x i32> %u) {
+; CHECK-LABEL: f_v4_s1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf v0.4s, v0.4s, #1
+; CHECK-NEXT: ret
+ %s = ashr exact <4 x i32> %u, <i32 1, i32 1, i32 1, i32 1>
+ %v = sitofp <4 x i32> %s to <4 x float>
+ ret <4 x float> %v
+}
+
+define <4 x float> @f_v4_s24_inexact(<4 x i32> %u) {
+; CHECK-LABEL: f_v4_s24_inexact:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sshr v0.4s, v0.4s, #24
+; CHECK-NEXT: scvtf v0.4s, v0.4s
+; CHECK-NEXT: ret
+ %s = ashr <4 x i32> %u, <i32 24, i32 24, i32 24, i32 24>
+ %v = sitofp <4 x i32> %s to <4 x float>
+ ret <4 x float> %v
+}
+
+define <4 x float> @f_v4_s32(<4 x i32> %u) {
+; CHECK-LABEL: f_v4_s32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v0.2d, #0000000000000000
+; CHECK-NEXT: ret
+ %s = ashr <4 x i32> %u, <i32 32, i32 32, i32 32, i32 32>
+ %v = sitofp <4 x i32> %s to <4 x float>
+ ret <4 x float> %v
+}
+
+; Common cases for conversion from signed integer to floating point types
+define <2 x float> @f_v2_s24(<2 x i32> %u) {
+; CHECK-LABEL: f_v2_s24:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf v0.2s, v0.2s, #24
+; CHECK-NEXT: ret
+ %s = ashr exact <2 x i32> %u, <i32 24, i32 24>
+ %v = sitofp <2 x i32> %s to <2 x float>
+ ret <2 x float> %v
+}
+
+define <4 x float> @f_v4_s24(<4 x i32> %u) {
+; CHECK-LABEL: f_v4_s24:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf v0.4s, v0.4s, #24
+; CHECK-NEXT: ret
+ %s = ashr exact <4 x i32> %u, <i32 24, i32 24, i32 24, i32 24>
+ %v = sitofp <4 x i32> %s to <4 x float>
+ ret <4 x float> %v
+}
+
+; Check that legalisation to v2f64 does not get in the way
+define <8 x double> @d_v8_s64(<8 x i64> %u) {
+; CHECK-LABEL: d_v8_s64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf v0.2d, v0.2d, #56
+; CHECK-NEXT: scvtf v1.2d, v1.2d, #56
+; CHECK-NEXT: scvtf v2.2d, v2.2d, #56
+; CHECK-NEXT: scvtf v3.2d, v3.2d, #56
+; CHECK-NEXT: ret
+ %s = ashr exact <8 x i64> %u, <i64 56, i64 56, i64 56, i64 56, i64 56, i64 56, i64 56, i64 56>
+ %v = sitofp <8 x i64> %s to <8 x double>
+ ret <8 x double> %v
+}
+
+define <4 x half> @h_v4_s8(<4 x i16> %u) #0 {
+; CHECK-LABEL: h_v4_s8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf v0.4h, v0.4h, #8
+; CHECK-NEXT: ret
+ %s = ashr exact <4 x i16> %u, <i16 8, i16 8, i16 8, i16 8>
+ %v = sitofp <4 x i16> %s to <4 x half>
+ ret <4 x half> %v
+}
+
+define <8 x half> @h_v8_s8(<8 x i16> %u) #0 {
+; CHECK-LABEL: h_v8_s8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: scvtf v0.8h, v0.8h, #8
+; CHECK-NEXT: ret
+ %s = ashr exact <8 x i16> %u, <i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8, i16 8>
+ %v = sitofp <8 x i16> %s to <8 x half>
+ ret <8 x half> %v
+}
+
+attributes #0 = { "target-features"="+fullfp16"}
From bb2a1e3bcd693f70282dacf1280b54b7f99fef4f Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 17 May 2024 11:50:31 +0100
Subject: [PATCH 2/2] [AArch64] Refactor creation of a shuffle mask for TBL
(NFC)
---
.../Target/AArch64/AArch64ISelLowering.cpp | 88 +++++++++++--------
1 file changed, 50 insertions(+), 38 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dd422214d45f4..e91493fbf6dbe 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15710,48 +15710,51 @@ bool AArch64TargetLowering::shouldSinkOperands(
return false;
}
-static bool createTblShuffleForZExt(ZExtInst *ZExt, FixedVectorType *DstTy,
- bool IsLittleEndian) {
- Value *Op = ZExt->getOperand(0);
- auto *SrcTy = cast<FixedVectorType>(Op->getType());
- auto SrcWidth = cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
- auto DstWidth = cast<IntegerType>(DstTy->getElementType())->getBitWidth();
+static bool createTblShuffleMask(unsigned SrcWidth, unsigned DstWidth,
+ unsigned NumElts, bool IsLittleEndian,
+ SmallVectorImpl<int> &Mask) {
if (DstWidth % 8 != 0 || DstWidth <= 16 || DstWidth >= 64)
return false;
- assert(DstWidth % SrcWidth == 0 &&
- "TBL lowering is not supported for a ZExt instruction with this "
- "source & destination element type.");
- unsigned ZExtFactor = DstWidth / SrcWidth;
+ if (DstWidth % SrcWidth != 0)
+ return false;
+
+ unsigned Factor = DstWidth / SrcWidth;
+ unsigned MaskLen = NumElts * Factor;
+
+ Mask.clear();
+ Mask.resize(MaskLen, NumElts);
+
+ unsigned SrcIndex = 0;
+ for (unsigned I = 0; I < MaskLen; I += Factor)
+ Mask[I] = SrcIndex++;
+
+ if (!IsLittleEndian)
+ std::rotate(Mask.rbegin(), Mask.rbegin() + Factor - 1, Mask.rend());
+
+ return true;
+}
+
+static Value *createTblShuffleForZExt(IRBuilderBase &Builder, Value *Op,
+ FixedVectorType *ZExtTy,
+ FixedVectorType *DstTy,
+ bool IsLittleEndian) {
+ auto *SrcTy = cast<FixedVectorType>(Op->getType());
unsigned NumElts = SrcTy->getNumElements();
- IRBuilder<> Builder(ZExt);
+ auto SrcWidth = cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
+ auto DstWidth = cast<IntegerType>(DstTy->getElementType())->getBitWidth();
+
SmallVector<int> Mask;
- // Create a mask that selects <0,...,Op[i]> for each lane of the destination
- // vector to replace the original ZExt. This can later be lowered to a set of
- // tbl instructions.
- for (unsigned i = 0; i < NumElts * ZExtFactor; i++) {
- if (IsLittleEndian) {
- if (i % ZExtFactor == 0)
- Mask.push_back(i / ZExtFactor);
- else
- Mask.push_back(NumElts);
- } else {
- if ((i + 1) % ZExtFactor == 0)
- Mask.push_back((i - ZExtFactor + 1) / ZExtFactor);
- else
- Mask.push_back(NumElts);
- }
- }
+ if (!createTblShuffleMask(SrcWidth, DstWidth, NumElts, IsLittleEndian, Mask))
+ return nullptr;
auto *FirstEltZero = Builder.CreateInsertElement(
PoisonValue::get(SrcTy), Builder.getInt8(0), uint64_t(0));
Value *Result = Builder.CreateShuffleVector(Op, FirstEltZero, Mask);
Result = Builder.CreateBitCast(Result, DstTy);
- if (DstTy != ZExt->getType())
- Result = Builder.CreateZExt(Result, ZExt->getType());
- ZExt->replaceAllUsesWith(Result);
- ZExt->eraseFromParent();
- return true;
+ if (DstTy != ZExtTy)
+ Result = Builder.CreateZExt(Result, ZExtTy);
+ return Result;
}
static void createTblForTrunc(TruncInst *TI, bool IsLittleEndian) {
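
To make the new helper's output concrete, here is a standalone C++ sketch that
mirrors (as an assumption, restated outside LLVM) the logic of
createTblShuffleMask above, printing the mask for zero-extending four i8 lanes
to i32, including the big-endian rotation:

  // Illustration only; a re-creation of the helper's logic, not patch code.
  #include <algorithm>
  #include <cstdio>
  #include <vector>

  int main() {
    unsigned NumElts = 4, Factor = 4;  // e.g. zext v4i8 -> v4i32, Factor = 32/8
    unsigned MaskLen = NumElts * Factor;
    // The sentinel index NumElts selects the zero byte inserted into the
    // second shuffle operand.
    std::vector<int> Mask(MaskLen, (int)NumElts);
    unsigned SrcIndex = 0;
    for (unsigned I = 0; I < MaskLen; I += Factor)
      Mask[I] = SrcIndex++;            // little endian: data byte in the LSB
    // Big endian: rotate right by Factor - 1 so each data byte lands in the
    // MSB position of its group, e.g. 0 4 4 4 1 4 4 4 -> 4 4 4 0 4 4 4 1.
    bool IsLittleEndian = false;
    if (!IsLittleEndian)
      std::rotate(Mask.rbegin(), Mask.rbegin() + (Factor - 1), Mask.rend());
    for (int M : Mask)
      std::printf("%d ", M);
    std::printf("\n");
    return 0;
  }
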
@@ -15916,21 +15919,30 @@ bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(
DstTy = TruncDstType;
}
-
- return createTblShuffleForZExt(ZExt, DstTy, Subtarget->isLittleEndian());
+ IRBuilder<> Builder(ZExt);
+ Value *Result = createTblShuffleForZExt(
+ Builder, ZExt->getOperand(0), cast<FixedVectorType>(ZExt->getType()),
+ DstTy, Subtarget->isLittleEndian());
+ if (!Result)
+ return false;
+ ZExt->replaceAllUsesWith(Result);
+ ZExt->eraseFromParent();
+ return true;
}
auto *UIToFP = dyn_cast<UIToFPInst>(I);
if (UIToFP && SrcTy->getElementType()->isIntegerTy(8) &&
DstTy->getElementType()->isFloatTy()) {
IRBuilder<> Builder(I);
- auto *ZExt = cast<ZExtInst>(
- Builder.CreateZExt(I->getOperand(0), VectorType::getInteger(DstTy)));
+ Value *ZExt = createTblShuffleForZExt(
+ Builder, I->getOperand(0), FixedVectorType::getInteger(DstTy),
+ FixedVectorType::getInteger(DstTy), Subtarget->isLittleEndian());
+ if (!ZExt)
+ return false;
auto *UI = Builder.CreateUIToFP(ZExt, DstTy);
I->replaceAllUsesWith(UI);
I->eraseFromParent();
- return createTblShuffleForZExt(ZExt, cast<FixedVectorType>(ZExt->getType()),
- Subtarget->isLittleEndian());
+ return true;
}
// Convert 'fptoui <(8|16) x float> to <(8|16) x i8>' to a wide fptoui