[llvm] [GlobalIsel] combine extract vector element (PR #91922)

Thorsten Schütt via llvm-commits llvm-commits at lists.llvm.org
Mon May 13 00:03:22 PDT 2024


https://github.com/tschuett updated https://github.com/llvm/llvm-project/pull/91922

>From 1c594a400929363677fb7850f1055929ed7951c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Sun, 12 May 2024 23:08:25 +0200
Subject: [PATCH 1/2] [GlobalIsel] combine extract vector element

scalarize compares

extelt (cmp X, Y), Index --> cmp (extelt X, Index),
                                 (extelt Y, Index)
---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |  23 +++
 .../include/llvm/Target/GlobalISel/Combine.td |  19 ++-
 .../GlobalISel/CombinerHelperVectorOps.cpp    | 159 ++++++++++++++++++
 .../CodeGen/AArch64/extract-vector-elt.ll     | 128 ++++++++++++++
 4 files changed, 328 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index ecaece8b68342..6edb3f9cd2e89 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -1,3 +1,4 @@
+
 //===-- llvm/CodeGen/GlobalISel/CombinerHelper.h --------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -866,6 +867,16 @@ class CombinerHelper {
   /// Combine insert vector element OOB.
   bool matchInsertVectorElementOOB(MachineInstr &MI, BuildFnTy &MatchInfo);
 
+  /// Combine extract vector element with an integer compare (G_ICMP) on the
+  /// vector register.
+  bool matchExtractVectorElementWithICmp(const MachineOperand &MO,
+                                         BuildFnTy &MatchInfo);
+
+  /// Combine extract vector element with a floating-point compare (G_FCMP) on
+  /// the vector register.
+  bool matchExtractVectorElementWithFCmp(const MachineOperand &MO,
+                                         BuildFnTy &MatchInfo);
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -981,6 +992,18 @@ class CombinerHelper {
 
   // Simplify (cmp cc0 x, y) (&& or ||) (cmp cc1 x, y) -> cmp cc2 x, y.
   bool tryFoldLogicOfFCmps(GLogicalBinOp *Logic, BuildFnTy &MatchInfo);
+
+  /// Return true if the register \p Src is cheaper to scalarize than it is to
+  /// leave as a vector operation. If the extract index \p Index is a constant
+  /// integer then some operations may be cheap to scalarize. The depth \p Depth
+  /// prevents arbitrary recursion.
+  bool isCheapToScalarize(Register Src, const std::optional<APInt> &Index,
+                          unsigned Depth = 0);
+
+  /// Return true if \p Src is def'd by an operation of type vector that is
+  /// constant at offset \p Index. \p Depth limits recursion when looking
+  /// through vector operations.
+  bool isConstantAtOffset(Register Src, const APInt &Index, unsigned Depth = 0);
 };
 } // namespace llvm
 
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 98d266c8c0b4f..3c71c2a25b2d9 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1591,6 +1591,20 @@ def insert_vector_elt_oob : GICombineRule<
          [{ return Helper.matchInsertVectorElementOOB(*${root}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
+def extract_vector_element_icmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+   (match (G_ICMP $src, $pred, $lhs, $rhs),
+          (G_EXTRACT_VECTOR_ELT $root, $src, $idx),
+   [{ return Helper.matchExtractVectorElementWithICmp(${root}, ${matchinfo}); }]),
+   (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+def extract_vector_element_fcmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_FCMP $fsrc, $fpred, $flhs, $frhs),
+         (G_EXTRACT_VECTOR_ELT $root, $fsrc, $fidx),
+  [{ return Helper.matchExtractVectorElementWithFCmp(${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
 // match_extract_of_element and insert_vector_elt_oob must be the first!
 def vector_ops_combines: GICombineGroup<[
 match_extract_of_element_undef_vector,
@@ -1624,6 +1638,8 @@ extract_vector_element_build_vector_trunc7,
 extract_vector_element_build_vector_trunc8,
 extract_vector_element_freeze,
 extract_vector_element_shuffle_vector,
+extract_vector_element_icmp,
+extract_vector_element_fcmp,
 insert_vector_element_extract_vector_element
 ]>;
 
@@ -1706,7 +1722,8 @@ def all_combines : GICombineGroup<[trivial_combines, vector_ops_combines,
     sub_add_reg, select_to_minmax, redundant_binop_in_equality,
     fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
     combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
-    sext_trunc, zext_trunc, combine_shuffle_concat]>;
+    sext_trunc, zext_trunc, combine_shuffle_concat
+]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
 // this group have been selected based on experiments to balance code size and
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
index 21b1eb2628174..64b39e3f82e65 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
@@ -453,3 +453,162 @@ bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
 
   return false;
 }
+
+bool CombinerHelper::isConstantAtOffset(Register Src, const APInt &Index,
+                                        unsigned Depth) {
+  assert(MRI.getType(Src).isVector() && "expected a vector as input");
+  if (Depth == 2)
+    return false;
+  // Depth cap reached above; nothing below recurses yet (see the FIXMEs).
+  // We use the look-through variant for a higher hit rate and to increase the
+  // likelihood of constant folding. The actual value is ignored. We only test
+  // *whether* there is a constant.
+
+  MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a build vector, check for a constant at the offset.
+  if (auto *Build = dyn_cast<GBuildVector>(SrcMI))
+    return getAnyConstantVRegValWithLookThrough(
+               Build->getSourceReg(Index.getZExtValue()), MRI)
+        .has_value();
+
+  // For concat and shuffle vectors, we could recurse.
+  // FIXME concat vectors
+  // FIXME shuffle vectors
+  // FIXME unary ops
+  // FIXME insert vector element
+  // FIXME subvector
+
+  return false;
+}
+
+bool CombinerHelper::isCheapToScalarize(Register Src,
+                                        const std::optional<APInt> &Index,
+                                        unsigned Depth) {
+  assert(MRI.getType(Src).isVector() && "expected a vector as input");
+  // Cap the recursion through operands (see the Depth + 1 calls below).
+  if (Depth >= 2)
+    return false;
+
+  MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a binary operator with a single non-debug use, then
+  // scalarizing the op is cheap when at least one of its operands is cheap
+  // to scalarize.
+  if (auto *BinOp = dyn_cast<GBinOp>(SrcMI))
+    if (MRI.hasOneNonDBGUse(BinOp->getReg(0)))
+      if (isCheapToScalarize(BinOp->getLHSReg(), Index, Depth + 1) ||
+          isCheapToScalarize(BinOp->getRHSReg(), Index, Depth + 1))
+        return true;
+
+  // If Src is def'd by a compare with a single non-debug use, then
+  // scalarizing the cmp is cheap when at least one of its operands is cheap
+  // to scalarize.
+  if (auto *Cmp = dyn_cast<GAnyCmp>(SrcMI))
+    if (MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+      if (isCheapToScalarize(Cmp->getLHSReg(), Index, Depth + 1) ||
+          isCheapToScalarize(Cmp->getRHSReg(), Index, Depth + 1))
+        return true;
+
+  // FIXME: unary operator
+  // FIXME: casts
+  // FIXME: loads
+  // FIXME: subvector
+
+  if (Index)
+    // If Index is constant, then Src is cheap to scalarize when it is constant
+    // at offset Index.
+    return isConstantAtOffset(Src, *Index, Depth);
+
+  return false;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithICmp(const MachineOperand &MO,
+                                                       BuildFnTy &MatchInfo) {
+  GExtractVectorElement *Extract =
+      cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+  // MO's register is def'd by the extract; its vector operand by a G_ICMP.
+  Register Vector = Extract->getVectorReg();
+
+  GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Vector));
+  // The index is only forwarded to the cost check when it is a constant.
+  std::optional<ValueAndVReg> MaybeIndex =
+      getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+  std::optional<APInt> IndexC = std::nullopt;
+
+  if (MaybeIndex)
+    IndexC = MaybeIndex->Value;
+  // Bail out unless scalarizing the compared vector is considered cheap.
+  if (!isCheapToScalarize(Vector, IndexC))
+    return false;
+  // The vector compare must have exactly one non-debug use.
+  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+    return false;
+
+  Register Dst = Extract->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Extract->getIndexReg());
+  LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+  LLT ExtractDstTy = VectorTy.getScalarType();
+  // Both the scalar compare and the operand extracts must pass this check.
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_ICMP, {DstTy, ExtractDstTy}}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+                                 {ExtractDstTy, VectorTy, IdxTy}}))
+    return false;
+  // extelt (icmp X, Y), Idx --> icmp (extelt X, Idx), (extelt Y, Idx)
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+                                           Extract->getIndexReg());
+    auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+                                           Extract->getIndexReg());
+    B.buildICmp(Cmp->getCond(), Dst, LHS, RHS);
+  };
+
+  return true;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithFCmp(const MachineOperand &MO,
+                                                       BuildFnTy &MatchInfo) {
+  GExtractVectorElement *Extract =
+      cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+  // MO's register is def'd by the extract; its vector operand by a G_FCMP.
+  Register Vector = Extract->getVectorReg();
+
+  GFCmp *Cmp = cast<GFCmp>(MRI.getVRegDef(Vector));
+  // The index is only forwarded to the cost check when it is a constant.
+  std::optional<ValueAndVReg> MaybeIndex =
+      getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+  std::optional<APInt> IndexC = std::nullopt;
+
+  if (MaybeIndex)
+    IndexC = MaybeIndex->Value;
+  // Bail out unless scalarizing the compared vector is considered cheap.
+  if (!isCheapToScalarize(Vector, IndexC))
+    return false;
+  // The vector compare must have exactly one non-debug use.
+  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+    return false;
+
+  Register Dst = Extract->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Extract->getIndexReg());
+  LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+  LLT ExtractDstTy = VectorTy.getScalarType();
+  // Both the scalar compare and the operand extracts must pass this check.
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_FCMP, {DstTy, ExtractDstTy}}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+                                 {ExtractDstTy, VectorTy, IdxTy}}))
+    return false;
+  // extelt (fcmp X, Y), Idx --> fcmp (extelt X, Idx), (extelt Y, Idx)
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+                                           Extract->getIndexReg());
+    auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+                                           Extract->getIndexReg());
+    B.buildFCmp(Cmp->getCond(), Dst, LHS, RHS, Cmp->getFlags());
+  };
+
+  return true;
+}
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
index 0481d997d24fa..42fe5e82cb7de 100644
--- a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
+++ b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
@@ -1100,4 +1100,132 @@ ret:
   ret i32 %3
 }
 
+define i32 @extract_v4float_fcmp_const_no_zext(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT:    mvn v0.16b, v0.16b
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    umov w8, v0.h[1]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov s0, v0.s[1]
+; CHECK-GI-NEXT:    fmov s1, #1.00000000
+; CHECK-GI-NEXT:    fcmp s0, s1
+; CHECK-GI-NEXT:    cset w0, vs
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = fcmp uno <4 x float> %a, <float 1.0, float 1.0, float 1.0, float 1.0>
+  %d = extractelement <4 x i1> %vector, i32 1
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
 
+define i32 @extract_v4i32_icmp_const_no_zext(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    adrp x8, .LCPI43_0
+; CHECK-SD-NEXT:    ldr q1, [x8, :lo12:.LCPI43_0]
+; CHECK-SD-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    umov w8, v0.h[1]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov s0, v0.s[1]
+; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    cmp w8, #8
+; CHECK-GI-NEXT:    cset w0, le
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
+  %d = extractelement <4 x i1> %vector, i32 1
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4float_fcmp_const_no_zext_fail(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT:    add x8, sp, #8
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT:    bfi x8, x0, #1, #2
+; CHECK-SD-NEXT:    mvn v0.16b, v0.16b
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    str d0, [sp, #8]
+; CHECK-SD-NEXT:    ldrh w8, [x8]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    sub sp, sp, #16
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT:    fmov v1.4s, #1.00000000
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    mov x9, sp
+; CHECK-GI-NEXT:    and x8, x8, #0x3
+; CHECK-GI-NEXT:    fcmge v2.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT:    fcmgt v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    orr v0.16b, v0.16b, v2.16b
+; CHECK-GI-NEXT:    mvn v0.16b, v0.16b
+; CHECK-GI-NEXT:    str q0, [sp]
+; CHECK-GI-NEXT:    ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT:    and w0, w8, #0x1
+; CHECK-GI-NEXT:    add sp, sp, #16
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = fcmp uno <4 x float> %a, <float 1.0, float 1.0, float 1.0, float 1.0>
+  %d = extractelement <4 x i1> %vector, i32 %c
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4i32_icmp_const_no_zext_fail(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    adrp x8, .LCPI45_0
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT:    ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-SD-NEXT:    add x8, sp, #8
+; CHECK-SD-NEXT:    bfi x8, x0, #1, #2
+; CHECK-SD-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    str d0, [sp, #8]
+; CHECK-SD-NEXT:    ldrh w8, [x8]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    sub sp, sp, #16
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT:    adrp x8, .LCPI45_0
+; CHECK-GI-NEXT:    mov x9, sp
+; CHECK-GI-NEXT:    ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    and x8, x8, #0x3
+; CHECK-GI-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    str q0, [sp]
+; CHECK-GI-NEXT:    ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT:    and w0, w8, #0x1
+; CHECK-GI-NEXT:    add sp, sp, #16
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
+  %d = extractelement <4 x i1> %vector, i32 %c
+  %z = zext i1 %d to i32
+  ret i32 %z
+}

>From 27189e9beb57e8a55db32f11511f0ad2670b1dc0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Mon, 13 May 2024 08:59:18 +0200
Subject: [PATCH 2/2] fixups

---
 llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h | 1 -
 llvm/include/llvm/Target/GlobalISel/Combine.td        | 3 +--
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 6edb3f9cd2e89..c4af944d814f4 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -1,4 +1,3 @@
-
 //===-- llvm/CodeGen/GlobalISel/CombinerHelper.h --------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 3c71c2a25b2d9..d7aa0267fb449 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1722,8 +1722,7 @@ def all_combines : GICombineGroup<[trivial_combines, vector_ops_combines,
     sub_add_reg, select_to_minmax, redundant_binop_in_equality,
     fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
     combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
-    sext_trunc, zext_trunc, combine_shuffle_concat
-]>;
+    sext_trunc, zext_trunc, combine_shuffle_concat]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
 // this group have been selected based on experiments to balance code size and



More information about the llvm-commits mailing list