[llvm] [GlobalIsel] combine extract vector element (PR #91922)
via llvm-commits
llvm-commits at lists.llvm.org
Sun May 12 23:59:11 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-llvm-globalisel
Author: Thorsten Schütt (tschuett)
Changes:
Scalarize compares:
extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
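For example (a sketch mirroring the included extract_v4i32_icmp_const_no_zext test below), extracting lane 1 of a vector compare against a constant vector becomes a scalar compare of the extracted lanes; the extract from the constant operand can then constant-fold:

```llvm
; Before: compare whole vectors, then extract one lane of the <4 x i1> result.
%vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
%d = extractelement <4 x i1> %vector, i32 1

; After: extract lane 1 from each operand and compare the scalars; the
; extract from the constant vector constant-folds to 8.
%lhs = extractelement <4 x i32> %a, i32 1
%rhs = extractelement <4 x i32> <i32 7, i32 8, i32 7, i32 9>, i32 1
%d = icmp sle i32 %lhs, %rhs
```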
---
Full diff: https://github.com/llvm/llvm-project/pull/91922.diff
4 Files Affected:
- (modified) llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h (+23)
- (modified) llvm/include/llvm/Target/GlobalISel/Combine.td (+18-1)
- (modified) llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp (+159)
- (modified) llvm/test/CodeGen/AArch64/extract-vector-elt.ll (+128)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index ecaece8b68342..6edb3f9cd2e89 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -1,3 +1,4 @@
+
//===-- llvm/CodeGen/GlobalISel/CombinerHelper.h --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -866,6 +867,16 @@ class CombinerHelper {
/// Combine insert vector element OOB.
bool matchInsertVectorElementOOB(MachineInstr &MI, BuildFnTy &MatchInfo);
+ /// Combine extract vector element with a compare on the vector
+ /// register.
+ bool matchExtractVectorElementWithICmp(const MachineOperand &MO,
+ BuildFnTy &MatchInfo);
+
+ /// Combine extract vector element with a compare on the vector
+ /// register.
+ bool matchExtractVectorElementWithFCmp(const MachineOperand &MO,
+ BuildFnTy &MatchInfo);
+
private:
/// Checks for legality of an indexed variant of \p LdSt.
bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -981,6 +992,18 @@ class CombinerHelper {
// Simplify (cmp cc0 x, y) (&& or ||) (cmp cc1 x, y) -> cmp cc2 x, y.
bool tryFoldLogicOfFCmps(GLogicalBinOp *Logic, BuildFnTy &MatchInfo);
+
+  /// Return true if the register \p Src is cheaper to scalarize than to
+  /// leave as a vector operation. If the extract index \p Index is a constant
+  /// integer, then some operations may be cheap to scalarize. The depth
+  /// \p Depth prevents unbounded recursion.
+ bool isCheapToScalarize(Register Src, const std::optional<APInt> &Index,
+ unsigned Depth = 0);
+
+  /// Return true if \p Src is def'd by a vector operation whose element at
+  /// offset \p Index is constant. \p Depth limits recursion when looking
+  /// through vector operations.
+ bool isConstantAtOffset(Register Src, const APInt &Index, unsigned Depth = 0);
};
} // namespace llvm
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 98d266c8c0b4f..3c71c2a25b2d9 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1591,6 +1591,20 @@ def insert_vector_elt_oob : GICombineRule<
[{ return Helper.matchInsertVectorElementOOB(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+def extract_vector_element_icmp : GICombineRule<
+ (defs root:$root, build_fn_matchinfo:$matchinfo),
+ (match (G_ICMP $src, $pred, $lhs, $rhs),
+ (G_EXTRACT_VECTOR_ELT $root, $src, $idx),
+ [{ return Helper.matchExtractVectorElementWithICmp(${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+def extract_vector_element_fcmp : GICombineRule<
+ (defs root:$root, build_fn_matchinfo:$matchinfo),
+ (match (G_FCMP $fsrc, $fpred, $flhs, $frhs),
+ (G_EXTRACT_VECTOR_ELT $root, $fsrc, $fidx),
+ [{ return Helper.matchExtractVectorElementWithFCmp(${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
// match_extract_of_element and insert_vector_elt_oob must be the first!
def vector_ops_combines: GICombineGroup<[
match_extract_of_element_undef_vector,
@@ -1624,6 +1638,8 @@ extract_vector_element_build_vector_trunc7,
extract_vector_element_build_vector_trunc8,
extract_vector_element_freeze,
extract_vector_element_shuffle_vector,
+extract_vector_element_icmp,
+extract_vector_element_fcmp,
insert_vector_element_extract_vector_element
]>;
@@ -1706,7 +1722,8 @@ def all_combines : GICombineGroup<[trivial_combines, vector_ops_combines,
sub_add_reg, select_to_minmax, redundant_binop_in_equality,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
- sext_trunc, zext_trunc, combine_shuffle_concat]>;
+ sext_trunc, zext_trunc, combine_shuffle_concat
+]>;
// A combine group used to for prelegalizer combiners at -O0. The combines in
// this group have been selected based on experiments to balance code size and
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
index 21b1eb2628174..64b39e3f82e65 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
@@ -453,3 +453,162 @@ bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
return false;
}
+
+bool CombinerHelper::isConstantAtOffset(Register Src, const APInt &Index,
+ unsigned Depth) {
+ assert(MRI.getType(Src).isVector() && "expected a vector as input");
+  if (Depth >= 2)
+ return false;
+
+  // We use the look-through variant for a higher hit rate and to increase the
+  // likelihood of constant folding. The actual value is ignored; we only test
+  // *whether* there is a constant.
+
+ MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a build vector, check constantness at the offset.
+ if (auto *Build = dyn_cast<GBuildVector>(SrcMI))
+ return getAnyConstantVRegValWithLookThrough(
+ Build->getSourceReg(Index.getZExtValue()), MRI)
+ .has_value();
+
+ // For concat and shuffle vectors, we could recurse.
+  // FIXME: concat vectors
+  // FIXME: shuffle vectors
+  // FIXME: unary ops
+  // FIXME: insert vector element
+  // FIXME: subvector
+
+ return false;
+}
+
+bool CombinerHelper::isCheapToScalarize(Register Src,
+ const std::optional<APInt> &Index,
+ unsigned Depth) {
+ assert(MRI.getType(Src).isVector() && "expected a vector as input");
+
+ if (Depth >= 2)
+ return false;
+
+ MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+ // If Src is def'd by a binary operator,
+ // then scalarizing the op is cheap when one of its operands is cheap to
+ // scalarize.
+ if (auto *BinOp = dyn_cast<GBinOp>(SrcMI))
+ if (MRI.hasOneNonDBGUse(BinOp->getReg(0)))
+ if (isCheapToScalarize(BinOp->getLHSReg(), Index, Depth + 1) ||
+ isCheapToScalarize(BinOp->getRHSReg(), Index, Depth + 1))
+ return true;
+
+ // If Src is def'd by a compare,
+ // then scalarizing the cmp is cheap when one of its operands is cheap to
+ // scalarize.
+ if (auto *Cmp = dyn_cast<GAnyCmp>(SrcMI))
+ if (MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+ if (isCheapToScalarize(Cmp->getLHSReg(), Index, Depth + 1) ||
+ isCheapToScalarize(Cmp->getRHSReg(), Index, Depth + 1))
+ return true;
+
+ // FIXME: unary operator
+ // FIXME: casts
+ // FIXME: loads
+ // FIXME: subvector
+
+ if (Index)
+ // If Index is constant, then Src is cheap to scalarize when it is constant
+ // at offset Index.
+ return isConstantAtOffset(Src, *Index, Depth);
+
+ return false;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithICmp(const MachineOperand &MO,
+ BuildFnTy &MatchInfo) {
+ GExtractVectorElement *Extract =
+ cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+
+ Register Vector = Extract->getVectorReg();
+
+ GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Vector));
+
+ std::optional<ValueAndVReg> MaybeIndex =
+ getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+ std::optional<APInt> IndexC = std::nullopt;
+
+ if (MaybeIndex)
+ IndexC = MaybeIndex->Value;
+
+ if (!isCheapToScalarize(Vector, IndexC))
+ return false;
+
+ if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+ return false;
+
+ Register Dst = Extract->getReg(0);
+ LLT DstTy = MRI.getType(Dst);
+ LLT IdxTy = MRI.getType(Extract->getIndexReg());
+ LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+ LLT ExtractDstTy = VectorTy.getScalarType();
+
+ if (!isLegalOrBeforeLegalizer(
+ {TargetOpcode::G_ICMP, {DstTy, ExtractDstTy}}) ||
+ !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+ {ExtractDstTy, VectorTy, IdxTy}}))
+ return false;
+
+ MatchInfo = [=](MachineIRBuilder &B) {
+ auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+ Extract->getIndexReg());
+ auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+ Extract->getIndexReg());
+ B.buildICmp(Cmp->getCond(), Dst, LHS, RHS);
+ };
+
+ return true;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithFCmp(const MachineOperand &MO,
+ BuildFnTy &MatchInfo) {
+ GExtractVectorElement *Extract =
+ cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+
+ Register Vector = Extract->getVectorReg();
+
+ GFCmp *Cmp = cast<GFCmp>(MRI.getVRegDef(Vector));
+
+ std::optional<ValueAndVReg> MaybeIndex =
+ getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+ std::optional<APInt> IndexC = std::nullopt;
+
+ if (MaybeIndex)
+ IndexC = MaybeIndex->Value;
+
+ if (!isCheapToScalarize(Vector, IndexC))
+ return false;
+
+ if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+ return false;
+
+ Register Dst = Extract->getReg(0);
+ LLT DstTy = MRI.getType(Dst);
+ LLT IdxTy = MRI.getType(Extract->getIndexReg());
+ LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+ LLT ExtractDstTy = VectorTy.getScalarType();
+
+ if (!isLegalOrBeforeLegalizer(
+ {TargetOpcode::G_FCMP, {DstTy, ExtractDstTy}}) ||
+ !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+ {ExtractDstTy, VectorTy, IdxTy}}))
+ return false;
+
+ MatchInfo = [=](MachineIRBuilder &B) {
+ auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+ Extract->getIndexReg());
+ auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+ Extract->getIndexReg());
+ B.buildFCmp(Cmp->getCond(), Dst, LHS, RHS, Cmp->getFlags());
+ };
+
+ return true;
+}
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
index 0481d997d24fa..42fe5e82cb7de 100644
--- a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
+++ b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
@@ -1100,4 +1100,132 @@ ret:
ret i32 %3
}
+define i32 @extract_v4float_fcmp_const_no_zext(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT: mvn v0.16b, v0.16b
+; CHECK-SD-NEXT: xtn v0.4h, v0.4s
+; CHECK-SD-NEXT: umov w8, v0.h[1]
+; CHECK-SD-NEXT: and w0, w8, #0x1
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: mov s0, v0.s[1]
+; CHECK-GI-NEXT: fmov s1, #1.00000000
+; CHECK-GI-NEXT: fcmp s0, s1
+; CHECK-GI-NEXT: cset w0, vs
+; CHECK-GI-NEXT: ret
+entry:
+ %vector = fcmp uno <4 x float> %a, <float 1.0, float 1.0, float 1.0, float 1.0>
+ %d = extractelement <4 x i1> %vector, i32 1
+ %z = zext i1 %d to i32
+ ret i32 %z
+}
+define i32 @extract_v4i32_icmp_const_no_zext(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: adrp x8, .LCPI43_0
+; CHECK-SD-NEXT: ldr q1, [x8, :lo12:.LCPI43_0]
+; CHECK-SD-NEXT: cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT: xtn v0.4h, v0.4s
+; CHECK-SD-NEXT: umov w8, v0.h[1]
+; CHECK-SD-NEXT: and w0, w8, #0x1
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: mov s0, v0.s[1]
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: cmp w8, #8
+; CHECK-GI-NEXT: cset w0, le
+; CHECK-GI-NEXT: ret
+entry:
+ %vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
+ %d = extractelement <4 x i1> %vector, i32 1
+ %z = zext i1 %d to i32
+ ret i32 %z
+}
+
+define i32 @extract_v4float_fcmp_const_no_zext_fail(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: sub sp, sp, #16
+; CHECK-SD-NEXT: .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT: fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT: add x8, sp, #8
+; CHECK-SD-NEXT: // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT: bfi x8, x0, #1, #2
+; CHECK-SD-NEXT: mvn v0.16b, v0.16b
+; CHECK-SD-NEXT: xtn v0.4h, v0.4s
+; CHECK-SD-NEXT: str d0, [sp, #8]
+; CHECK-SD-NEXT: ldrh w8, [x8]
+; CHECK-SD-NEXT: and w0, w8, #0x1
+; CHECK-SD-NEXT: add sp, sp, #16
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: sub sp, sp, #16
+; CHECK-GI-NEXT: .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT: fmov v1.4s, #1.00000000
+; CHECK-GI-NEXT: mov w8, w0
+; CHECK-GI-NEXT: mov x9, sp
+; CHECK-GI-NEXT: and x8, x8, #0x3
+; CHECK-GI-NEXT: fcmge v2.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT: fcmgt v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: orr v0.16b, v0.16b, v2.16b
+; CHECK-GI-NEXT: mvn v0.16b, v0.16b
+; CHECK-GI-NEXT: str q0, [sp]
+; CHECK-GI-NEXT: ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT: and w0, w8, #0x1
+; CHECK-GI-NEXT: add sp, sp, #16
+; CHECK-GI-NEXT: ret
+entry:
+ %vector = fcmp uno <4 x float> %a, <float 1.0, float 1.0, float 1.0, float 1.0>
+ %d = extractelement <4 x i1> %vector, i32 %c
+ %z = zext i1 %d to i32
+ ret i32 %z
+}
+
+define i32 @extract_v4i32_icmp_const_no_zext_fail(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: sub sp, sp, #16
+; CHECK-SD-NEXT: .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT: adrp x8, .LCPI45_0
+; CHECK-SD-NEXT: // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT: ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-SD-NEXT: add x8, sp, #8
+; CHECK-SD-NEXT: bfi x8, x0, #1, #2
+; CHECK-SD-NEXT: cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT: xtn v0.4h, v0.4s
+; CHECK-SD-NEXT: str d0, [sp, #8]
+; CHECK-SD-NEXT: ldrh w8, [x8]
+; CHECK-SD-NEXT: and w0, w8, #0x1
+; CHECK-SD-NEXT: add sp, sp, #16
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: sub sp, sp, #16
+; CHECK-GI-NEXT: .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT: adrp x8, .LCPI45_0
+; CHECK-GI-NEXT: mov x9, sp
+; CHECK-GI-NEXT: ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-GI-NEXT: mov w8, w0
+; CHECK-GI-NEXT: and x8, x8, #0x3
+; CHECK-GI-NEXT: cmge v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: str q0, [sp]
+; CHECK-GI-NEXT: ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT: and w0, w8, #0x1
+; CHECK-GI-NEXT: add sp, sp, #16
+; CHECK-GI-NEXT: ret
+entry:
+ %vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
+ %d = extractelement <4 x i1> %vector, i32 %c
+ %z = zext i1 %d to i32
+ ret i32 %z
+}
``````````
https://github.com/llvm/llvm-project/pull/91922