[flang-commits] [clang-tools-extra] [lldb] [clang] [llvm] [libc] [compiler-rt] [flang] [GlobalIsel] Combine select of binops (PR #76763)
Thorsten Schütt via flang-commits
flang-commits at lists.llvm.org
Wed Jan 3 06:42:30 PST 2024
https://github.com/tschuett updated https://github.com/llvm/llvm-project/pull/76763
>From e713bb6e2c36ec16c731217f0c3be19b040a03d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Tue, 2 Jan 2024 18:00:45 +0100
Subject: [PATCH] [GlobalIsel] Combine select of binops
---
.../llvm/CodeGen/GlobalISel/CombinerHelper.h | 3 +
.../CodeGen/GlobalISel/GenericMachineInstrs.h | 103 ++++++++++++++++++
.../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 91 +++++++++++-----
.../AArch64/GlobalISel/combine-select.mir | 74 +++++++++++++
4 files changed, 243 insertions(+), 28 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index dcc1a4580b14a2..f3b68623596c46 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -910,6 +910,9 @@ class CombinerHelper {
bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);
+ /// Try to fold select(cc, binop(), binop()) with one shared operand -> binop(select(), X)
+ bool tryFoldSelectOfBinOps(GSelect *Select, BuildFnTy &MatchInfo);
+
bool isOneOrOneSplat(Register Src, bool AllowUndefs);
bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
bool isConstantSplatVector(Register Src, int64_t SplatValue,
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 6ab1d4550c51ca..21d98d30356c93 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -558,6 +558,109 @@ class GVecReduce : public GenericMachineInstr {
}
};
+// Represents a binary operation, i.e., x = y op z (integer, floating point or logical opcode).
+class GBinOp : public GenericMachineInstr {
+public:
+ Register getLHSReg() const { return getReg(1); } // y in x = y op z
+ Register getRHSReg() const { return getReg(2); } // z in x = y op z
+
+ static bool classof(const MachineInstr *MI) { // true only for the opcodes listed below
+ switch (MI->getOpcode()) {
+ // Integer arithmetic, division/remainder and min/max.
+ case TargetOpcode::G_ADD:
+ case TargetOpcode::G_SUB:
+ case TargetOpcode::G_MUL:
+ case TargetOpcode::G_SDIV:
+ case TargetOpcode::G_UDIV:
+ case TargetOpcode::G_SREM:
+ case TargetOpcode::G_UREM:
+ case TargetOpcode::G_SMIN:
+ case TargetOpcode::G_SMAX:
+ case TargetOpcode::G_UMIN:
+ case TargetOpcode::G_UMAX:
+ // Floating point min/max variants, arithmetic and pow.
+ case TargetOpcode::G_FMINNUM:
+ case TargetOpcode::G_FMAXNUM:
+ case TargetOpcode::G_FMINNUM_IEEE:
+ case TargetOpcode::G_FMAXNUM_IEEE:
+ case TargetOpcode::G_FMINIMUM:
+ case TargetOpcode::G_FMAXIMUM:
+ case TargetOpcode::G_FADD:
+ case TargetOpcode::G_FSUB:
+ case TargetOpcode::G_FMUL:
+ case TargetOpcode::G_FDIV:
+ case TargetOpcode::G_FPOW:
+ // Logical (and/or/xor).
+ case TargetOpcode::G_AND:
+ case TargetOpcode::G_OR:
+ case TargetOpcode::G_XOR:
+ return true;
+ default:
+ return false;
+ }
+ };
+};
+
+// Represents an integer binary operation: the add/sub/mul, div/rem and min/max subset of GBinOp (no logical ops).
+class GIntBinOp : public GBinOp {
+public:
+ static bool classof(const MachineInstr *MI) { // true only for the integer opcodes listed below
+ switch (MI->getOpcode()) {
+ case TargetOpcode::G_ADD:
+ case TargetOpcode::G_SUB:
+ case TargetOpcode::G_MUL:
+ case TargetOpcode::G_SDIV:
+ case TargetOpcode::G_UDIV:
+ case TargetOpcode::G_SREM:
+ case TargetOpcode::G_UREM:
+ case TargetOpcode::G_SMIN:
+ case TargetOpcode::G_SMAX:
+ case TargetOpcode::G_UMIN:
+ case TargetOpcode::G_UMAX:
+ return true;
+ default:
+ return false;
+ }
+ };
+};
+
+// Represents a floating point binary operation: the G_F* min/max, arithmetic and pow subset of GBinOp.
+class GFBinOp : public GBinOp {
+public:
+ static bool classof(const MachineInstr *MI) { // true only for the floating point opcodes listed below
+ switch (MI->getOpcode()) {
+ case TargetOpcode::G_FMINNUM:
+ case TargetOpcode::G_FMAXNUM:
+ case TargetOpcode::G_FMINNUM_IEEE:
+ case TargetOpcode::G_FMAXNUM_IEEE:
+ case TargetOpcode::G_FMINIMUM:
+ case TargetOpcode::G_FMAXIMUM:
+ case TargetOpcode::G_FADD:
+ case TargetOpcode::G_FSUB:
+ case TargetOpcode::G_FMUL:
+ case TargetOpcode::G_FDIV:
+ case TargetOpcode::G_FPOW:
+ return true;
+ default:
+ return false;
+ }
+ };
+};
+
+// Represents a logical binary operation: G_AND, G_OR or G_XOR.
+class GLogicalBinOp : public GBinOp {
+public:
+ static bool classof(const MachineInstr *MI) { // true only for the three logical opcodes below
+ switch (MI->getOpcode()) {
+ case TargetOpcode::G_AND:
+ case TargetOpcode::G_OR:
+ case TargetOpcode::G_XOR:
+ return true;
+ default:
+ return false;
+ }
+ };
+};
} // namespace llvm
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 8b15bdb0aca30b..102b49c48460b1 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -6390,8 +6390,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isZero() && FalseValue.isOne()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Inner = MRI.createGenericVirtualRegister(CondTy);
- B.buildNot(Inner, Cond);
+ auto Inner = B.buildNot(CondTy, Cond);
B.buildZExtOrTrunc(Dest, Inner);
};
return true;
@@ -6401,8 +6400,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isZero() && FalseValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Inner = MRI.createGenericVirtualRegister(CondTy);
- B.buildNot(Inner, Cond);
+ auto Inner = B.buildNot(CondTy, Cond);
B.buildSExtOrTrunc(Dest, Inner);
};
return true;
@@ -6412,8 +6410,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue - 1 == FalseValue) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Inner = MRI.createGenericVirtualRegister(TrueTy);
- B.buildZExtOrTrunc(Inner, Cond);
+ auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildAdd(Dest, Inner, False);
};
return true;
@@ -6423,8 +6420,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue + 1 == FalseValue) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Inner = MRI.createGenericVirtualRegister(TrueTy);
- B.buildSExtOrTrunc(Inner, Cond);
+ auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
B.buildAdd(Dest, Inner, False);
};
return true;
@@ -6434,8 +6430,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Inner = MRI.createGenericVirtualRegister(TrueTy);
- B.buildZExtOrTrunc(Inner, Cond);
+ auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
// The shift amount must be scalar.
LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
@@ -6447,8 +6442,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Inner = MRI.createGenericVirtualRegister(TrueTy);
- B.buildSExtOrTrunc(Inner, Cond);
+ auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
B.buildOr(Dest, Inner, False, Flags);
};
return true;
@@ -6458,10 +6452,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (FalseValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Not = MRI.createGenericVirtualRegister(CondTy);
- B.buildNot(Not, Cond);
- Register Inner = MRI.createGenericVirtualRegister(TrueTy);
- B.buildSExtOrTrunc(Inner, Not);
+ auto Not = B.buildNot(CondTy, Cond);
+ auto Inner = B.buildSExtOrTrunc(TrueTy, Not);
B.buildOr(Dest, Inner, True, Flags);
};
return true;
@@ -6496,8 +6488,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Ext = MRI.createGenericVirtualRegister(TrueTy);
- B.buildZExtOrTrunc(Ext, Cond);
+ auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildOr(DstReg, Ext, False, Flags);
};
return true;
@@ -6508,8 +6499,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
- Register Ext = MRI.createGenericVirtualRegister(TrueTy);
- B.buildZExtOrTrunc(Ext, Cond);
+ auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildAnd(DstReg, Ext, True);
};
return true;
@@ -6520,11 +6510,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
// First the not.
- Register Inner = MRI.createGenericVirtualRegister(CondTy);
- B.buildNot(Inner, Cond);
+ auto Inner = B.buildNot(CondTy, Cond);
// Then an ext to match the destination register.
- Register Ext = MRI.createGenericVirtualRegister(TrueTy);
- B.buildZExtOrTrunc(Ext, Inner);
+ auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
B.buildOr(DstReg, Ext, True, Flags);
};
return true;
@@ -6535,11 +6523,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
// First the not.
- Register Inner = MRI.createGenericVirtualRegister(CondTy);
- B.buildNot(Inner, Cond);
+ auto Inner = B.buildNot(CondTy, Cond);
// Then an ext to match the destination register.
- Register Ext = MRI.createGenericVirtualRegister(TrueTy);
- B.buildZExtOrTrunc(Ext, Inner);
+ auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
B.buildAnd(DstReg, Ext, False);
};
return true;
@@ -6548,6 +6534,52 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
return false;
}
+bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
+ BuildFnTy &MatchInfo) { // On a match, stores the rebuild lambda in MatchInfo and returns true; otherwise returns false.
+ Register DstReg = Select->getReg(0);
+ Register Cond = Select->getCondReg();
+ Register False = Select->getFalseReg();
+ Register True = Select->getTrueReg();
+ LLT DstTy = MRI.getType(DstReg);
+
+ GBinOp *LHS = getOpcodeDef<GBinOp>(True, MRI); // binop feeding the true value, or null
+ GBinOp *RHS = getOpcodeDef<GBinOp>(False, MRI); // binop feeding the false value, or null
+
+ // We need two binops of the same kind (identical opcode) on the true/false registers.
+ if (!LHS || !RHS || LHS->getOpcode() != RHS->getOpcode())
+ return false;
+
+ // Note that there are no constraints on CondTy. Only flags set on both binops are kept for the new binop.
+ unsigned Flags = LHS->getFlags() & RHS->getFlags();
+ unsigned Opcode = LHS->getOpcode();
+
+ // Fold select(cond, binop(x, y), binop(z, y)), i.e. both binops share their second operand,
+ // --> binop(select(cond, x, z), y)
+ if (LHS->getRHSReg() == RHS->getRHSReg()) {
+ MatchInfo = [=](MachineIRBuilder &B) { // LHS/RHS pointers are captured by value and read when this runs
+ B.setInstrAndDebugLoc(*Select);
+ auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg()); // NOTE(review): Select's own flags are not passed on — confirm intended
+ B.buildInstr(Opcode, {DstReg}, {Sel, LHS->getRHSReg()}, Flags);
+ };
+ return true;
+ }
+
+ // Fold select(cond, binop(x, y), binop(x, z)), i.e. both binops share their first operand,
+ // --> binop(x, select(cond, y, z))
+ if (LHS->getLHSReg() == RHS->getLHSReg()) {
+ MatchInfo = [=](MachineIRBuilder &B) { // same capture pattern as the fold above
+ B.setInstrAndDebugLoc(*Select);
+ auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg());
+ B.buildInstr(Opcode, {DstReg}, {LHS->getLHSReg(), Sel}, Flags);
+ };
+ return true;
+ }
+
+ // FIXME: use isCommutable(). NOTE(review): the binops' use counts are not checked here — confirm the fold is wanted when they have other users.
+
+ return false;
+}
+
bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
GSelect *Select = cast<GSelect>(&MI);
@@ -6557,5 +6589,8 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
if (tryFoldBoolSelectToLogic(Select, MatchInfo))
return true;
+ if (tryFoldSelectOfBinOps(Select, MatchInfo))
+ return true;
+
return false;
}
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
index be2de620fa456c..7644443f53dc68 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
@@ -544,3 +544,77 @@ body: |
%ext:_(s32) = G_ANYEXT %sel
$w0 = COPY %ext(s32)
...
+---
+# select cond, and(x, y), and(z, y) --> and(select(cond, x, z), y)
+name: select_cond_and_x_y_and_z_y_and_select_x_z_y
+body: |
+ bb.1:
+ liveins: $x0, $x1, $x2
+ ; CHECK-LABEL: name: select_cond_and_x_y_and_z_y_and_select_x_z_y
+ ; CHECK: liveins: $x0, $x1, $x2
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+ ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s64) = COPY $x1
+ ; CHECK-NEXT: [[COPY2:%[0-9]+]]:_(s64) = COPY $x2
+ ; CHECK-NEXT: [[COPY3:%[0-9]+]]:_(s64) = COPY $x3
+ ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+ ; CHECK-NEXT: %a:_(s8) = G_TRUNC [[COPY1]](s64)
+ ; CHECK-NEXT: %b:_(s8) = G_TRUNC [[COPY2]](s64)
+ ; CHECK-NEXT: %d:_(s8) = G_TRUNC [[COPY3]](s64)
+ ; CHECK-NEXT: [[SELECT:%[0-9]+]]:_(s8) = G_SELECT %c(s1), %a, %d
+ ; CHECK-NEXT: %sel:_(s8) = G_AND [[SELECT]], %b
+ ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+ ; CHECK-NEXT: $w0 = COPY %ext(s32)
+ %0:_(s64) = COPY $x0
+ %1:_(s64) = COPY $x1
+ %2:_(s64) = COPY $x2
+ %3:_(s64) = COPY $x3
+ %4:_(s64) = COPY $x4
+ %c:_(s1) = G_TRUNC %0
+ %a:_(s8) = G_TRUNC %1
+ %b:_(s8) = G_TRUNC %2
+ %d:_(s8) = G_TRUNC %3
+ %e:_(s8) = G_TRUNC %4
+ %and1:_(s8) = G_AND %a, %b
+ %and2:_(s8) = G_AND %d, %b
+ %sel:_(s8) = G_SELECT %c, %and1, %and2
+ %ext:_(s32) = G_ANYEXT %sel
+ $w0 = COPY %ext(s32)
+...
+---
+# select cond, xor(x, y), xor(x, z) --> xor(x, select(cond, y, z))
+name: select_cond_xor_x_y_xor_x_z_xor_x__select_x_y
+body: |
+ bb.1:
+ liveins: $x0, $x1, $x2
+ ; CHECK-LABEL: name: select_cond_xor_x_y_xor_x_z_xor_x__select_x_y
+ ; CHECK: liveins: $x0, $x1, $x2
+ ; CHECK-NEXT: {{ $}}
+ ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+ ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s64) = COPY $x1
+ ; CHECK-NEXT: [[COPY2:%[0-9]+]]:_(s64) = COPY $x3
+ ; CHECK-NEXT: [[COPY3:%[0-9]+]]:_(s64) = COPY $x4
+ ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+ ; CHECK-NEXT: %a:_(s8) = G_TRUNC [[COPY1]](s64)
+ ; CHECK-NEXT: %d:_(s8) = G_TRUNC [[COPY2]](s64)
+ ; CHECK-NEXT: %e:_(s8) = G_TRUNC [[COPY3]](s64)
+ ; CHECK-NEXT: [[SELECT:%[0-9]+]]:_(s8) = G_SELECT %c(s1), %e, %d
+ ; CHECK-NEXT: %sel:_(s8) = G_XOR %a, [[SELECT]]
+ ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+ ; CHECK-NEXT: $w0 = COPY %ext(s32)
+ %0:_(s64) = COPY $x0
+ %1:_(s64) = COPY $x1
+ %2:_(s64) = COPY $x2
+ %3:_(s64) = COPY $x3
+ %4:_(s64) = COPY $x4
+ %c:_(s1) = G_TRUNC %0
+ %a:_(s8) = G_TRUNC %1
+ %b:_(s8) = G_TRUNC %2
+ %d:_(s8) = G_TRUNC %3
+ %e:_(s8) = G_TRUNC %4
+ %xor1:_(s8) = G_XOR %a, %e
+ %xor2:_(s8) = G_XOR %a, %d
+ %sel:_(s8) = G_SELECT %c, %xor1, %xor2
+ %ext:_(s32) = G_ANYEXT %sel
+ $w0 = COPY %ext(s32)
+...
More information about the flang-commits
mailing list