[llvm] [GlobalIsel] Combine G_VSCALE (PR #94096)
Thorsten Schütt via llvm-commits
llvm-commits at lists.llvm.org
Fri May 31 23:59:41 PDT 2024
https://github.com/tschuett updated https://github.com/llvm/llvm-project/pull/94096
>From 26ca8a610b8d3f4b320dd31448364d948f626f8d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Fri, 31 May 2024 19:52:06 +0200
Subject: [PATCH 1/2] [GlobalIsel] Combine G_VSCALE
These combines are needed for scalable address calculations and
for forming legal scalable addressing modes.
---
.../llvm/CodeGen/GlobalISel/CombinerHelper.h | 8 ++
.../CodeGen/GlobalISel/GenericMachineInstrs.h | 41 ++++++-
.../include/llvm/Target/GlobalISel/Combine.td | 37 +++++-
.../GlobalISel/CombinerHelperVectorOps.cpp | 86 ++++++++++++-
.../AArch64/GlobalISel/combine-vscale.mir | 113 ++++++++++++++++++
5 files changed, 282 insertions(+), 3 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/GlobalISel/combine-vscale.mir
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 2ddf20ebe7af7..5e476b9f7bf31 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -868,6 +868,14 @@ class CombinerHelper {
bool matchFreezeOfSingleMaybePoisonOperand(MachineInstr &MI,
BuildFnTy &MatchInfo);
+ // G_VSCALE combines. \p MO is the def operand of the root instruction; on a
+ // successful match \p MatchInfo is set to a callback that builds the
+ // replacement, which is then applied with applyBuildFnMO.
+
+ /// Fold (G_ADD (G_VSCALE A), (G_VSCALE B)) into (G_VSCALE A + B).
+ bool matchAddOfVScale(const MachineOperand &MO, BuildFnTy &MatchInfo);
+
+ /// Fold (G_MUL (G_VSCALE A), (G_CONSTANT K)) into (G_VSCALE A * K).
+ bool matchMulOfVScale(const MachineOperand &MO, BuildFnTy &MatchInfo);
+
+ /// Rewrite (G_SUB X, (G_VSCALE A)) into (G_ADD X, (G_VSCALE -A)).
+ bool matchSubOfVScale(const MachineOperand &MO, BuildFnTy &MatchInfo);
+
+ /// Fold (G_SHL (G_VSCALE A), (G_CONSTANT K)) into (G_VSCALE A << K).
+ bool matchShlOfVScale(const MachineOperand &MO, BuildFnTy &MatchInfo);
+
private:
/// Checks for legality of an indexed variant of \p LdSt.
bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 2b3efc3b609f0..36ae9beed8aa9 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -14,10 +14,12 @@
#ifndef LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
#define LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
-#include "llvm/IR/Instructions.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/Support/Casting.h"
namespace llvm {
@@ -856,6 +858,43 @@ class GTrunc : public GCastOp {
};
};
+/// Represents a G_VSCALE: the runtime vscale multiplied by the immediate held
+/// in operand 1.
+class GVScale : public GenericMachineInstr {
+public:
+  /// Returns the immediate multiplier of vscale (operand 1). It is returned by
+  /// const reference to the ConstantInt operand's value, so callers that only
+  /// read it do not pay for an APInt copy.
+  const APInt &getSrc() const { return getOperand(1).getCImm()->getValue(); }
+
+  static bool classof(const MachineInstr *MI) {
+    return MI->getOpcode() == TargetOpcode::G_VSCALE;
+  };
+};
+
+/// Represents an integer subtraction (G_SUB).
+class GSub : public GIntBinOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    const unsigned Opc = MI->getOpcode();
+    return Opc == TargetOpcode::G_SUB;
+  };
+};
+
+/// Represents an integer multiplication (G_MUL).
+class GMul : public GIntBinOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    const unsigned Opc = MI->getOpcode();
+    return Opc == TargetOpcode::G_MUL;
+  };
+};
+
+/// Represents a shift left.
+class GSHL : public GenericMachineInstr {
+public:
+ Register getSrcReg() const { return getOperand(1).getReg(); }
+ Register getShiftReg() const { return getOperand(2).getReg(); }
+ // getSrcReg() is the value being shifted (operand 1); getShiftReg() is the
+ // shift amount (operand 2).
+
+ /// Returns true if \p MI is a G_SHL.
+ static bool classof(const MachineInstr *MI) {
+ return MI->getOpcode() == TargetOpcode::G_SHL;
+ };
+};
+
} // namespace llvm
#endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 383589add7755..94abf10a033ce 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1598,6 +1598,37 @@ def insert_vector_elt_oob : GICombineRule<
[{ return Helper.matchInsertVectorElementOOB(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+// (G_ADD nsw (G_VSCALE A), (G_VSCALE B)) -> (G_VSCALE A + B).
+// matchAddOfVScale additionally requires each G_VSCALE to have a single
+// non-debug use.
+def add_of_vscale : GICombineRule<
+ (defs root:$root, build_fn_matchinfo:$matchinfo),
+ (match (G_VSCALE $left, $imm1),
+ (G_VSCALE $right, $imm2),
+ (G_ADD $root, $left, $right, (MIFlags NoSWrap)),
+ [{ return Helper.matchAddOfVScale(${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+// (G_MUL nsw (G_VSCALE A), (G_CONSTANT K)) -> (G_VSCALE A * K).
+// Only the G_VSCALE-on-the-LHS form is matched.
+def mul_of_vscale : GICombineRule<
+ (defs root:$root, build_fn_matchinfo:$matchinfo),
+ (match (G_VSCALE $left, $scale),
+ (G_CONSTANT $x, $imm1),
+ (G_MUL $root, $left, $x, (MIFlags NoSWrap)),
+ [{ return Helper.matchMulOfVScale(${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+// (G_SHL nsw (G_VSCALE A), (G_CONSTANT K)) -> (G_VSCALE A << K).
+def shl_of_vscale : GICombineRule<
+ (defs root:$root, build_fn_matchinfo:$matchinfo),
+ (match (G_VSCALE $left, $imm),
+ (G_CONSTANT $x, $imm1),
+ (G_SHL $root, $left, $x, (MIFlags NoSWrap)),
+ [{ return Helper.matchShlOfVScale(${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+// (G_SUB nsw X, (G_VSCALE A)) -> (G_ADD X, (G_VSCALE -A)).
+// matchSubOfVScale decides which wrap flags the new G_ADD carries.
+def sub_of_vscale : GICombineRule<
+ (defs root:$root, build_fn_matchinfo:$matchinfo),
+ (match (G_VSCALE $right, $imm),
+ (G_SUB $root, $x, $right, (MIFlags NoSWrap)),
+ [{ return Helper.matchSubOfVScale(${root}, ${matchinfo}); }]),
+ (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
// match_extract_of_element and insert_vector_elt_oob must be the first!
def vector_ops_combines: GICombineGroup<[
match_extract_of_element_undef_vector,
@@ -1630,7 +1661,11 @@ extract_vector_element_build_vector_trunc6,
extract_vector_element_build_vector_trunc7,
extract_vector_element_build_vector_trunc8,
extract_vector_element_shuffle_vector,
-insert_vector_element_extract_vector_element
+insert_vector_element_extract_vector_element,
+add_of_vscale,
+mul_of_vscale,
+shl_of_vscale,
+sub_of_vscale,
]>;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
index b4765fb280f9d..62ee80f49b7b6 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
@@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT.
+// This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT,
+// G_INSERT_VECTOR_ELT, and G_VSCALE.
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
@@ -400,3 +401,86 @@ bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
return false;
}
+
+/// Fold (G_ADD (G_VSCALE A), (G_VSCALE B)) into a single G_VSCALE (A + B).
+///
+/// Both G_VSCALE operands must have exactly one non-debug use (this add), so
+/// that they become dead once the add is replaced.
+bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO,
+                                      BuildFnTy &MatchInfo) {
+  auto *AddMI = cast<GAdd>(MRI.getVRegDef(MO.getReg()));
+  auto *Lhs = cast<GVScale>(MRI.getVRegDef(AddMI->getLHSReg()));
+  auto *Rhs = cast<GVScale>(MRI.getVRegDef(AddMI->getRHSReg()));
+
+  const bool BothSingleUse = MRI.hasOneNonDBGUse(Lhs->getReg(0)) &&
+                             MRI.hasOneNonDBGUse(Rhs->getReg(0));
+  if (!BothSingleUse)
+    return false;
+
+  const Register Result = AddMI->getReg(0);
+  MatchInfo = [=](MachineIRBuilder &B) {
+    const APInt Sum = Lhs->getSrc() + Rhs->getSrc();
+    B.buildVScale(Result, Sum);
+  };
+  return true;
+}
+
+/// Fold (G_MUL (G_VSCALE A), (G_CONSTANT K)) into G_VSCALE (A * K).
+///
+/// The multiplier must be a known integer constant, and the G_VSCALE must
+/// have a single non-debug use (this multiply).
+bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO,
+                                      BuildFnTy &MatchInfo) {
+  auto *MulMI = cast<GMul>(MRI.getVRegDef(MO.getReg()));
+  auto *Scale = cast<GVScale>(MRI.getVRegDef(MulMI->getLHSReg()));
+
+  std::optional<APInt> Factor = getIConstantVRegVal(MulMI->getRHSReg(), MRI);
+  if (!Factor || !MRI.hasOneNonDBGUse(Scale->getReg(0)))
+    return false;
+
+  const Register Result = MO.getReg();
+  MatchInfo = [=](MachineIRBuilder &B) {
+    B.buildVScale(Result, Scale->getSrc() * *Factor);
+  };
+  return true;
+}
+
+/// Rewrite (G_SUB X, (G_VSCALE C)) into (G_ADD X, (G_VSCALE -C)).
+///
+/// The G_VSCALE must have a single non-debug use (this subtraction), and G_ADD
+/// must be legal for the result type. Wrap flags cannot simply be copied from
+/// the G_SUB:
+///  - nuw is always dropped. `X -nuw Y` guarantees X >= Y (unsigned), but then
+///    `X + (-Y)` wraps for every Y != 0, so a copied nuw would turn a
+///    well-defined result into poison.
+///  - nsw is dropped when C is the signed minimum: -C wraps back to C, and
+///    `X +nsw vscale*(-C)` can then overflow where the subtraction did not.
+///    NOTE(review): for other C, keeping nsw assumes vscale*C is never the
+///    signed minimum of the result type.
+bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO,
+                                      BuildFnTy &MatchInfo) {
+  GSub *Sub = cast<GSub>(MRI.getVRegDef(MO.getReg()));
+  GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Sub->getRHSReg()));
+
+  Register Dst = MO.getReg();
+  LLT DstTy = MRI.getType(Dst);
+
+  if (!MRI.hasOneNonDBGUse(RHSVScale->getReg(0)) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, DstTy}))
+    return false;
+
+  const APInt Scale = RHSVScale->getSrc();
+  const Register LHS = Sub->getLHSReg();
+
+  // Never carry nuw over; keep nsw only while negating the immediate is exact.
+  uint32_t Flags = Sub->getFlags() & ~uint32_t(MachineInstr::NoUWrap);
+  if (Scale.isMinSignedValue())
+    Flags &= ~uint32_t(MachineInstr::NoSWrap);
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto VScale = B.buildVScale(DstTy, -Scale);
+    B.buildAdd(Dst, LHS, VScale, Flags);
+  };
+
+  return true;
+}
+
+bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO,
+ BuildFnTy &MatchInfo) {
+ GSHL *Shl = cast<GSHL>(MRI.getVRegDef(MO.getReg()));
+ GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Shl->getSrcReg()));
+
+ std::optional<APInt> MaybeRHS = getIConstantVRegVal(Shl->getShiftReg(), MRI);
+ // Fold (G_SHL (G_VSCALE C), K) into (G_VSCALE C << K). Only a constant
+ // shift amount K can be folded.
+ if (!MaybeRHS)
+ return false;
+
+ Register Dst = MO.getReg();
+ LLT DstTy = MRI.getType(Dst);
+
+ // The G_VSCALE must have a single non-debug use (this shift) so it becomes
+ // dead after the fold.
+ if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) ||
+ !isLegalOrBeforeLegalizer({TargetOpcode::G_VSCALE, DstTy}))
+ return false;
+
+ // Shifting the immediate multiplier left by K scales vscale*C by 2^K.
+ MatchInfo = [=](MachineIRBuilder &B) {
+ B.buildVScale(Dst, LHSVScale->getSrc().shl(*MaybeRHS));
+ };
+
+ return true;
+}
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-vscale.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-vscale.mir
new file mode 100644
index 0000000000000..9b7a44954afdb
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-vscale.mir
@@ -0,0 +1,113 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+# RUN: llc -o - -mtriple=aarch64 -run-pass=aarch64-prelegalizer-combiner -verify-machineinstrs %s | FileCheck %s
+
+...
+---
+# vscale*9 + vscale*11, both single-use: folds to a single G_VSCALE 20.
+name: sum_of_vscale
+body: |
+ bb.1:
+ liveins: $x0, $x1
+ ; CHECK-LABEL: name: sum_of_vscale
+ ; CHECK: liveins: $x0, $x1
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: %sum:_(s64) = G_VSCALE i64 20
+ ; CHECK-NEXT: $x0 = COPY %sum(s64)
+ ; CHECK-NEXT: RET_ReallyLR implicit $x0
+ %rhs:_(s64) = G_VSCALE i64 11
+ %lhs:_(s64) = G_VSCALE i64 9
+ %sum:_(s64) = nsw G_ADD %lhs(s64), %rhs(s64)
+ $x0 = COPY %sum(s64)
+ RET_ReallyLR implicit $x0
+...
+---
+# %rhs is also copied to $x1, so it has two uses and the add is not folded.
+name: sum_of_vscale_multi_use
+body: |
+ bb.1:
+ liveins: $x0, $x1
+ ; CHECK-LABEL: name: sum_of_vscale_multi_use
+ ; CHECK: liveins: $x0, $x1
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: %rhs:_(s64) = G_VSCALE i64 11
+ ; CHECK-NEXT: %lhs:_(s64) = G_VSCALE i64 9
+ ; CHECK-NEXT: %sum:_(s64) = nsw G_ADD %lhs, %rhs
+ ; CHECK-NEXT: $x0 = COPY %sum(s64)
+ ; CHECK-NEXT: $x1 = COPY %rhs(s64)
+ ; CHECK-NEXT: RET_ReallyLR implicit $x0
+ %rhs:_(s64) = G_VSCALE i64 11
+ %lhs:_(s64) = G_VSCALE i64 9
+ %sum:_(s64) = nsw G_ADD %lhs(s64), %rhs(s64)
+ $x0 = COPY %sum(s64)
+ $x1 = COPY %rhs(s64)
+ RET_ReallyLR implicit $x0
+...
+---
+# (vscale*9) * 11 folds to G_VSCALE 99.
+name: mul_of_vscale
+body: |
+ bb.1:
+ liveins: $x0, $x1
+ ; CHECK-LABEL: name: mul_of_vscale
+ ; CHECK: liveins: $x0, $x1
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: %mul:_(s64) = G_VSCALE i64 99
+ ; CHECK-NEXT: $x0 = COPY %mul(s64)
+ ; CHECK-NEXT: RET_ReallyLR implicit $x0
+ %rhs:_(s64) = G_CONSTANT i64 11
+ %lhs:_(s64) = G_VSCALE i64 9
+ %mul:_(s64) = nsw G_MUL %lhs(s64), %rhs(s64)
+ $x0 = COPY %mul(s64)
+ RET_ReallyLR implicit $x0
+...
+---
+# x - vscale*9 is rewritten to x + vscale*(-9); the nsw flag is kept.
+name: sub_of_vscale
+body: |
+ bb.1:
+ liveins: $x0, $x1
+ ; CHECK-LABEL: name: sub_of_vscale
+ ; CHECK: liveins: $x0, $x1
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: %x:_(s64) = COPY $x0
+ ; CHECK-NEXT: [[VSCALE:%[0-9]+]]:_(s64) = G_VSCALE i64 -9
+ ; CHECK-NEXT: %sub:_(s64) = nsw G_ADD %x, [[VSCALE]]
+ ; CHECK-NEXT: $x0 = COPY %sub(s64)
+ ; CHECK-NEXT: RET_ReallyLR implicit $x0
+ %x:_(s64) = COPY $x0
+ %rhs:_(s64) = G_VSCALE i64 9
+ %sub:_(s64) = nsw G_SUB %x(s64), %rhs(s64)
+ $x0 = COPY %sub(s64)
+ RET_ReallyLR implicit $x0
+...
+---
+# (vscale*11) << 2 folds to G_VSCALE 44.
+name: shl_of_vscale
+body: |
+ bb.1:
+ liveins: $x0, $x1
+ ; CHECK-LABEL: name: shl_of_vscale
+ ; CHECK: liveins: $x0, $x1
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: %shl:_(s64) = G_VSCALE i64 44
+ ; CHECK-NEXT: $x0 = COPY %shl(s64)
+ ; CHECK-NEXT: RET_ReallyLR implicit $x0
+ %rhs:_(s64) = G_CONSTANT i64 2
+ %lhs:_(s64) = G_VSCALE i64 11
+ %shl:_(s64) = nsw G_SHL %lhs(s64), %rhs(s64)
+ $x0 = COPY %shl(s64)
+ RET_ReallyLR implicit $x0
+...
+---
+# shl_of_vscale requires nsw; a G_SHL carrying only nuw is left unchanged.
+name: shl_of_vscale_wrong_flag
+body: |
+ bb.1:
+ liveins: $x0, $x1
+ ; CHECK-LABEL: name: shl_of_vscale_wrong_flag
+ ; CHECK: liveins: $x0, $x1
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: %rhs:_(s64) = G_CONSTANT i64 2
+ ; CHECK-NEXT: %lhs:_(s64) = G_VSCALE i64 11
+ ; CHECK-NEXT: %shl:_(s64) = nuw G_SHL %lhs, %rhs(s64)
+ ; CHECK-NEXT: $x0 = COPY %shl(s64)
+ ; CHECK-NEXT: RET_ReallyLR implicit $x0
+ %rhs:_(s64) = G_CONSTANT i64 2
+ %lhs:_(s64) = G_VSCALE i64 11
+ %shl:_(s64) = nuw G_SHL %lhs(s64), %rhs(s64)
+ $x0 = COPY %shl(s64)
+ RET_ReallyLR implicit $x0
>From b73b3fbb663de295304c563cb4433a86baba003b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Sat, 1 Jun 2024 08:58:56 +0200
Subject: [PATCH 2/2] address review comments
---
llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h | 2 +-
llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 36ae9beed8aa9..2273725637713 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -885,7 +885,7 @@ class GMul : public GIntBinOp {
};
/// Represents a shift left.
-class GSHL : public GenericMachineInstr {
+class GShl : public GenericMachineInstr {
public:
Register getSrcReg() const { return getOperand(1).getReg(); }
Register getShiftReg() const { return getOperand(2).getReg(); }
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
index 62ee80f49b7b6..66b1c5f8ca82c 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
@@ -464,7 +464,7 @@ bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO,
bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO,
BuildFnTy &MatchInfo) {
- GSHL *Shl = cast<GSHL>(MRI.getVRegDef(MO.getReg()));
+ GShl *Shl = cast<GShl>(MRI.getVRegDef(MO.getReg()));
GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Shl->getSrcReg()));
std::optional<APInt> MaybeRHS = getIConstantVRegVal(Shl->getShiftReg(), MRI);
More information about the llvm-commits
mailing list