[llvm] 534b26a - [Hexagon] Improve inserting/extracting to/from scalar predicates
Krzysztof Parzyszek via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 17 13:08:20 PST 2022
Author: Krzysztof Parzyszek
Date: 2022-11-17T13:03:45-08:00
New Revision: 534b26aa07675d7c4b579b1179e07ddf5e880d17
URL: https://github.com/llvm/llvm-project/commit/534b26aa07675d7c4b579b1179e07ddf5e880d17
DIFF: https://github.com/llvm/llvm-project/commit/534b26aa07675d7c4b579b1179e07ddf5e880d17.diff
LOG: [Hexagon] Improve inserting/extracting to/from scalar predicates
Fixes https://github.com/llvm/llvm-project/issues/59042.
Added:
llvm/test/CodeGen/Hexagon/isel-extract-pred.ll
llvm/test/CodeGen/Hexagon/isel-insert-pred.ll
Modified:
llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
llvm/lib/Target/Hexagon/HexagonISelLowering.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
index 11a57b5d2faf..5e6e0238438d 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -2641,60 +2641,13 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV,
MVT VecTy = ty(VecV);
assert(!ValTy.isVector() ||
VecTy.getVectorElementType() == ValTy.getVectorElementType());
+ if (VecTy.getVectorElementType() == MVT::i1)
+ return extractVectorPred(VecV, IdxV, dl, ValTy, ResTy, DAG);
+
unsigned VecWidth = VecTy.getSizeInBits();
unsigned ValWidth = ValTy.getSizeInBits();
unsigned ElemWidth = VecTy.getVectorElementType().getSizeInBits();
assert((VecWidth % ElemWidth) == 0);
- auto *IdxN = dyn_cast<ConstantSDNode>(IdxV);
-
- // Special case for v{8,4,2}i1 (the only boolean vectors legal in Hexagon
- // without any coprocessors).
- if (ElemWidth == 1) {
- assert(VecWidth == VecTy.getVectorNumElements() &&
- "Vector elements should equal vector width size");
- assert(VecWidth == 8 || VecWidth == 4 || VecWidth == 2);
- // Check if this is an extract of the lowest bit.
- if (IdxN) {
- // Extracting the lowest bit is a no-op, but it changes the type,
- // so it must be kept as an operation to avoid errors related to
- // type mismatches.
- if (IdxN->isZero() && ValTy.getSizeInBits() == 1)
- return DAG.getNode(HexagonISD::TYPECAST, dl, MVT::i1, VecV);
- }
-
- // If the value extracted is a single bit, use tstbit.
- if (ValWidth == 1) {
- SDValue A0 = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG);
- SDValue M0 = DAG.getConstant(8 / VecWidth, dl, MVT::i32);
- SDValue I0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, M0);
- return DAG.getNode(HexagonISD::TSTBIT, dl, MVT::i1, A0, I0);
- }
-
- // Each bool vector (v2i1, v4i1, v8i1) always occupies 8 bits in
- // a predicate register. The elements of the vector are repeated
- // in the register (if necessary) so that the total number is 8.
- // The extracted subvector will need to be expanded in such a way.
- unsigned Scale = VecWidth / ValWidth;
-
- // Generate (p2d VecV) >> 8*Idx to move the interesting bytes to
- // position 0.
- assert(ty(IdxV) == MVT::i32);
- unsigned VecRep = 8 / VecWidth;
- SDValue S0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV,
- DAG.getConstant(8*VecRep, dl, MVT::i32));
- SDValue T0 = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV);
- SDValue T1 = DAG.getNode(ISD::SRL, dl, MVT::i64, T0, S0);
- while (Scale > 1) {
- // The longest possible subvector is at most 32 bits, so it is always
- // contained in the low subregister.
- T1 = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, T1);
- T1 = expandPredicate(T1, dl, DAG);
- Scale /= 2;
- }
-
- return DAG.getNode(HexagonISD::D2P, dl, ResTy, T1);
- }
-
assert(VecWidth == 32 || VecWidth == 64);
// Cast everything to scalar integer types.
@@ -2704,7 +2657,7 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV,
SDValue WidthV = DAG.getConstant(ValWidth, dl, MVT::i32);
SDValue ExtV;
- if (IdxN) {
+ if (auto *IdxN = dyn_cast<ConstantSDNode>(IdxV)) {
unsigned Off = IdxN->getZExtValue() * ElemWidth;
if (VecWidth == 64 && ValWidth == 32) {
assert(Off == 0 || Off == 32);
@@ -2735,36 +2688,68 @@ HexagonTargetLowering::extractVector(SDValue VecV, SDValue IdxV,
}
SDValue
-HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
- const SDLoc &dl, MVT ValTy,
- SelectionDAG &DAG) const {
+HexagonTargetLowering::extractVectorPred(SDValue VecV, SDValue IdxV,
+ const SDLoc &dl, MVT ValTy, MVT ResTy,
+ SelectionDAG &DAG) const {
+ // Special case for v{8,4,2}i1 (the only boolean vectors legal in Hexagon
+ // without any coprocessors).
MVT VecTy = ty(VecV);
- if (VecTy.getVectorElementType() == MVT::i1) {
- MVT ValTy = ty(ValV);
- assert(ValTy.getVectorElementType() == MVT::i1);
- SDValue ValR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, ValV);
- unsigned VecLen = VecTy.getVectorNumElements();
- unsigned Scale = VecLen / ValTy.getVectorNumElements();
- assert(Scale > 1);
-
- for (unsigned R = Scale; R > 1; R /= 2) {
- ValR = contractPredicate(ValR, dl, DAG);
- ValR = getCombine(DAG.getUNDEF(MVT::i32), ValR, dl, MVT::i64, DAG);
- }
+ unsigned VecWidth = VecTy.getSizeInBits();
+ unsigned ValWidth = ValTy.getSizeInBits();
+ assert(VecWidth == VecTy.getVectorNumElements() &&
+ "Vector elements should equal vector width size");
+ assert(VecWidth == 8 || VecWidth == 4 || VecWidth == 2);
+
+ // Check if this is an extract of the lowest bit.
+ if (auto *IdxN = dyn_cast<ConstantSDNode>(IdxV)) {
+ // Extracting the lowest bit is a no-op, but it changes the type,
+ // so it must be kept as an operation to avoid errors related to
+ // type mismatches.
+ if (IdxN->isZero() && ValTy.getSizeInBits() == 1)
+ return DAG.getNode(HexagonISD::TYPECAST, dl, MVT::i1, VecV);
+ }
+
+ // If the value extracted is a single bit, use tstbit.
+ if (ValWidth == 1) {
+ SDValue A0 = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG);
+ SDValue M0 = DAG.getConstant(8 / VecWidth, dl, MVT::i32);
+ SDValue I0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, M0);
+ return DAG.getNode(HexagonISD::TSTBIT, dl, MVT::i1, A0, I0);
+ }
+
+ // Each bool vector (v2i1, v4i1, v8i1) always occupies 8 bits in
+ // a predicate register. The elements of the vector are repeated
+ // in the register (if necessary) so that the total number is 8.
+ // The extracted subvector will need to be expanded in such a way.
+ unsigned Scale = VecWidth / ValWidth;
+
+ // Generate (p2d VecV) >> 8*Idx to move the interesting bytes to
+ // position 0.
+ assert(ty(IdxV) == MVT::i32);
+ unsigned VecRep = 8 / VecWidth;
+ SDValue S0 = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV,
+ DAG.getConstant(8*VecRep, dl, MVT::i32));
+ SDValue T0 = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV);
+ SDValue T1 = DAG.getNode(ISD::SRL, dl, MVT::i64, T0, S0);
+ while (Scale > 1) {
// The longest possible subvector is at most 32 bits, so it is always
// contained in the low subregister.
- ValR = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, ValR);
-
- unsigned ValBytes = 64 / Scale;
- SDValue Width = DAG.getConstant(ValBytes*8, dl, MVT::i32);
- SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV,
- DAG.getConstant(8, dl, MVT::i32));
- SDValue VecR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV);
- SDValue Ins = DAG.getNode(HexagonISD::INSERT, dl, MVT::i32,
- {VecR, ValR, Width, Idx});
- return DAG.getNode(HexagonISD::D2P, dl, VecTy, Ins);
+ T1 = DAG.getTargetExtractSubreg(Hexagon::isub_lo, dl, MVT::i32, T1);
+ T1 = expandPredicate(T1, dl, DAG);
+ Scale /= 2;
}
+ return DAG.getNode(HexagonISD::D2P, dl, ResTy, T1);
+}
+
+SDValue
+HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
+ const SDLoc &dl, MVT ValTy,
+ SelectionDAG &DAG) const {
+ MVT VecTy = ty(VecV);
+ if (VecTy.getVectorElementType() == MVT::i1)
+ return insertVectorPred(VecV, ValV, IdxV, dl, ValTy, DAG);
+
unsigned VecWidth = VecTy.getSizeInBits();
unsigned ValWidth = ValTy.getSizeInBits();
assert(VecWidth == 32 || VecWidth == 64);
@@ -2799,13 +2784,53 @@ HexagonTargetLowering::insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
return DAG.getNode(ISD::BITCAST, dl, VecTy, InsV);
}
+SDValue
+HexagonTargetLowering::insertVectorPred(SDValue VecV, SDValue ValV,
+ SDValue IdxV, const SDLoc &dl,
+ MVT ValTy, SelectionDAG &DAG) const {
+ MVT VecTy = ty(VecV);
+ unsigned VecLen = VecTy.getVectorNumElements();
+
+ if (ValTy == MVT::i1) {
+ SDValue ToReg = getInstr(Hexagon::C2_tfrpr, dl, MVT::i32, {VecV}, DAG);
+ SDValue Ext = DAG.getSExtOrTrunc(ValV, dl, MVT::i32);
+ SDValue Width = DAG.getConstant(8 / VecLen, dl, MVT::i32);
+ SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, Width);
+ SDValue Ins =
+ DAG.getNode(HexagonISD::INSERT, dl, MVT::i32, {ToReg, Ext, Width, Idx});
+ return getInstr(Hexagon::C2_tfrrp, dl, VecTy, {Ins}, DAG);
+ }
+
+ assert(ValTy.getVectorElementType() == MVT::i1);
+ SDValue ValR = ValTy.isVector()
+ ? DAG.getNode(HexagonISD::P2D, dl, MVT::i64, ValV)
+ : DAG.getSExtOrTrunc(ValV, dl, MVT::i64);
+
+ unsigned Scale = VecLen / ValTy.getVectorNumElements();
+ assert(Scale > 1);
+
+ for (unsigned R = Scale; R > 1; R /= 2) {
+ ValR = contractPredicate(ValR, dl, DAG);
+ ValR = getCombine(DAG.getUNDEF(MVT::i32), ValR, dl, MVT::i64, DAG);
+ }
+
+ SDValue Width = DAG.getConstant(64 / Scale, dl, MVT::i32);
+ SDValue Idx = DAG.getNode(ISD::MUL, dl, MVT::i32, IdxV, Width);
+ SDValue VecR = DAG.getNode(HexagonISD::P2D, dl, MVT::i64, VecV);
+ SDValue Ins =
+ DAG.getNode(HexagonISD::INSERT, dl, MVT::i64, {VecR, ValR, Width, Idx});
+ return DAG.getNode(HexagonISD::D2P, dl, VecTy, Ins);
+}
+
SDValue
HexagonTargetLowering::expandPredicate(SDValue Vec32, const SDLoc &dl,
SelectionDAG &DAG) const {
assert(ty(Vec32).getSizeInBits() == 32);
if (isUndef(Vec32))
return DAG.getUNDEF(MVT::i64);
- return getInstr(Hexagon::S2_vsxtbh, dl, MVT::i64, {Vec32}, DAG);
+ SDValue P = DAG.getBitcast(MVT::v4i8, Vec32);
+ SDValue X = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i16, P);
+ return DAG.getBitcast(MVT::i64, X);
}
SDValue
@@ -2814,7 +2839,12 @@ HexagonTargetLowering::contractPredicate(SDValue Vec64, const SDLoc &dl,
assert(ty(Vec64).getSizeInBits() == 64);
if (isUndef(Vec64))
return DAG.getUNDEF(MVT::i32);
- return getInstr(Hexagon::S2_vtrunehb, dl, MVT::i32, {Vec64}, DAG);
+ // Collect even bytes:
+ SDValue A = DAG.getBitcast(MVT::v8i8, Vec64);
+ SDValue S = DAG.getVectorShuffle(MVT::v8i8, dl, A, DAG.getUNDEF(MVT::v8i8),
+ {0, 2, 4, 6, 1, 3, 5, 7});
+ return extractVector(S, DAG.getConstant(0, dl, MVT::i32), dl, MVT::v4i8,
+ MVT::i32, DAG);
}
SDValue
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.h b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
index 5db849d94c65..1387f0c1b355 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.h
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
@@ -378,8 +378,12 @@ class HexagonTargetLowering : public TargetLowering {
SelectionDAG &DAG) const;
SDValue extractVector(SDValue VecV, SDValue IdxV, const SDLoc &dl,
MVT ValTy, MVT ResTy, SelectionDAG &DAG) const;
+ SDValue extractVectorPred(SDValue VecV, SDValue IdxV, const SDLoc &dl,
+ MVT ValTy, MVT ResTy, SelectionDAG &DAG) const;
SDValue insertVector(SDValue VecV, SDValue ValV, SDValue IdxV,
const SDLoc &dl, MVT ValTy, SelectionDAG &DAG) const;
+ SDValue insertVectorPred(SDValue VecV, SDValue ValV, SDValue IdxV,
+ const SDLoc &dl, MVT ValTy, SelectionDAG &DAG) const;
SDValue expandPredicate(SDValue Vec32, const SDLoc &dl,
SelectionDAG &DAG) const;
SDValue contractPredicate(SDValue Vec64, const SDLoc &dl,
diff --git a/llvm/test/CodeGen/Hexagon/isel-extract-pred.ll b/llvm/test/CodeGen/Hexagon/isel-extract-pred.ll
new file mode 100644
index 000000000000..d05686524023
--- /dev/null
+++ b/llvm/test/CodeGen/Hexagon/isel-extract-pred.ll
@@ -0,0 +1,73 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -march=hexagon < %s | FileCheck %s
+
+define i32 @f0(ptr %a0, i32 %a1) #0 {
+; CHECK-LABEL: f0:
+; CHECK: // %bb.0:
+; CHECK-NEXT: {
+; CHECK-NEXT: r0 = memub(r0+#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r1 = asl(r1,#2)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: p0 = tstbit(r0,r1)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r0 = mux(p0,#-1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: jumpr r31
+; CHECK-NEXT: }
+ %v0 = load <2 x i1>, ptr %a0
+ %v1 = extractelement <2 x i1> %v0, i32 %a1
+ %v2 = sext i1 %v1 to i32
+ ret i32 %v2
+}
+
+define i32 @f1(ptr %a0, i32 %a1) #0 {
+; CHECK-LABEL: f1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: {
+; CHECK-NEXT: r0 = memub(r0+#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r1 = asl(r1,#1)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: p0 = tstbit(r0,r1)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r0 = mux(p0,#-1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: jumpr r31
+; CHECK-NEXT: }
+ %v0 = load <4 x i1>, ptr %a0
+ %v1 = extractelement <4 x i1> %v0, i32 %a1
+ %v2 = sext i1 %v1 to i32
+ ret i32 %v2
+}
+
+define i32 @f2(ptr %a0, i32 %a1) #0 {
+; CHECK-LABEL: f2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: {
+; CHECK-NEXT: r0 = memub(r0+#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: p0 = tstbit(r0,r1)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r0 = mux(p0,#-1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: jumpr r31
+; CHECK-NEXT: }
+ %v0 = load <8 x i1>, ptr %a0
+ %v1 = extractelement <8 x i1> %v0, i32 %a1
+ %v2 = sext i1 %v1 to i32
+ ret i32 %v2
+}
+
+attributes #0 = { nounwind "target-features"="-packets" }
diff --git a/llvm/test/CodeGen/Hexagon/isel-insert-pred.ll b/llvm/test/CodeGen/Hexagon/isel-insert-pred.ll
new file mode 100644
index 000000000000..2fa4a8b7bf9f
--- /dev/null
+++ b/llvm/test/CodeGen/Hexagon/isel-insert-pred.ll
@@ -0,0 +1,112 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -march=hexagon < %s | FileCheck %s
+
+define void @f0(ptr %a0, i32 %a1, i32 %a2) #0 {
+; CHECK-LABEL: f0:
+; CHECK: // %bb.0:
+; CHECK-NEXT: {
+; CHECK-NEXT: p0 = tstbit(r1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r2 = asl(r2,#2)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r4 = memub(r0+#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r3 = #4
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r1 = mux(p0,#-1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r4 = insert(r1,r3:2)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r1 = and(r4,#255)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: memb(r0+#0) = r1
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: jumpr r31
+; CHECK-NEXT: }
+ %v0 = load <2 x i1>, ptr %a0
+ %v1 = trunc i32 %a1 to i1
+ %v2 = insertelement <2 x i1> %v0, i1 %v1, i32 %a2
+ store <2 x i1> %v2, ptr %a0
+ ret void
+}
+
+define void @f1(ptr %a0, i32 %a1, i32 %a2) #0 {
+; CHECK-LABEL: f1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: {
+; CHECK-NEXT: p0 = tstbit(r1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r2 = asl(r2,#1)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r4 = memub(r0+#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r3 = #2
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r1 = mux(p0,#-1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r4 = insert(r1,r3:2)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r1 = and(r4,#255)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: memb(r0+#0) = r1
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: jumpr r31
+; CHECK-NEXT: }
+ %v0 = load <4 x i1>, ptr %a0
+ %v1 = trunc i32 %a1 to i1
+ %v2 = insertelement <4 x i1> %v0, i1 %v1, i32 %a2
+ store <4 x i1> %v2, ptr %a0
+ ret void
+}
+
+define void @f2(ptr %a0, i32 %a1, i32 %a2) #0 {
+; CHECK-LABEL: f2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: {
+; CHECK-NEXT: p0 = tstbit(r1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r6 = memub(r0+#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r3 = #1
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r4 = mux(p0,#-1,#0)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r6 = insert(r4,r3:2)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: r1 = and(r6,#255)
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: memb(r0+#0) = r1
+; CHECK-NEXT: }
+; CHECK-NEXT: {
+; CHECK-NEXT: jumpr r31
+; CHECK-NEXT: }
+ %v0 = load <8 x i1>, ptr %a0
+ %v1 = trunc i32 %a1 to i1
+ %v2 = insertelement <8 x i1> %v0, i1 %v1, i32 %a2
+ store <8 x i1> %v2, ptr %a0
+ ret void
+}
+
+attributes #0 = { nounwind "target-features"="-packets" }
More information about the llvm-commits mailing list