[clang-tools-extra] [lldb] [clang] [llvm] [libc] [compiler-rt] [flang] [GlobalIsel] Combine select of binops (PR #76763)

Thorsten Schütt via cfe-commits cfe-commits at lists.llvm.org
Wed Jan 3 06:57:03 PST 2024


https://github.com/tschuett updated https://github.com/llvm/llvm-project/pull/76763

>From e713bb6e2c36ec16c731217f0c3be19b040a03d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Tue, 2 Jan 2024 18:00:45 +0100
Subject: [PATCH 1/4] [GlobalIsel] Combine select of binops

---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |   3 +
 .../CodeGen/GlobalISel/GenericMachineInstrs.h | 103 ++++++++++++++++++
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp |  91 +++++++++++-----
 .../AArch64/GlobalISel/combine-select.mir     |  74 +++++++++++++
 4 files changed, 243 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index dcc1a4580b14a2..f3b68623596c46 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -910,6 +910,9 @@ class CombinerHelper {
 
   bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);
 
+  /// Try to fold select(cc, binop(), binop()) -> binop(select(), X)
+  bool tryFoldSelectOfBinOps(GSelect *Select, BuildFnTy &MatchInfo);
+
   bool isOneOrOneSplat(Register Src, bool AllowUndefs);
   bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
   bool isConstantSplatVector(Register Src, int64_t SplatValue,
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 6ab1d4550c51ca..21d98d30356c93 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -558,6 +558,109 @@ class GVecReduce : public GenericMachineInstr {
   }
 };
 
+// Represents a binary operation, i.e., x = y op z.
+class GBinOp : public GenericMachineInstr {
+public:
+  Register getLHSReg() const { return getReg(1); }
+  Register getRHSReg() const { return getReg(2); }
+
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    // Integer.
+    case TargetOpcode::G_ADD:
+    case TargetOpcode::G_SUB:
+    case TargetOpcode::G_MUL:
+    case TargetOpcode::G_SDIV:
+    case TargetOpcode::G_UDIV:
+    case TargetOpcode::G_SREM:
+    case TargetOpcode::G_UREM:
+    case TargetOpcode::G_SMIN:
+    case TargetOpcode::G_SMAX:
+    case TargetOpcode::G_UMIN:
+    case TargetOpcode::G_UMAX:
+    // Floating point.
+    case TargetOpcode::G_FMINNUM:
+    case TargetOpcode::G_FMAXNUM:
+    case TargetOpcode::G_FMINNUM_IEEE:
+    case TargetOpcode::G_FMAXNUM_IEEE:
+    case TargetOpcode::G_FMINIMUM:
+    case TargetOpcode::G_FMAXIMUM:
+    case TargetOpcode::G_FADD:
+    case TargetOpcode::G_FSUB:
+    case TargetOpcode::G_FMUL:
+    case TargetOpcode::G_FDIV:
+    case TargetOpcode::G_FPOW:
+    // Logical.
+    case TargetOpcode::G_AND:
+    case TargetOpcode::G_OR:
+    case TargetOpcode::G_XOR:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
+
+// Represents an integer binary operation.
+class GIntBinOp : public GBinOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_ADD:
+    case TargetOpcode::G_SUB:
+    case TargetOpcode::G_MUL:
+    case TargetOpcode::G_SDIV:
+    case TargetOpcode::G_UDIV:
+    case TargetOpcode::G_SREM:
+    case TargetOpcode::G_UREM:
+    case TargetOpcode::G_SMIN:
+    case TargetOpcode::G_SMAX:
+    case TargetOpcode::G_UMIN:
+    case TargetOpcode::G_UMAX:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
+
+// Represents a floating point binary operation.
+class GFBinOp : public GBinOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_FMINNUM:
+    case TargetOpcode::G_FMAXNUM:
+    case TargetOpcode::G_FMINNUM_IEEE:
+    case TargetOpcode::G_FMAXNUM_IEEE:
+    case TargetOpcode::G_FMINIMUM:
+    case TargetOpcode::G_FMAXIMUM:
+    case TargetOpcode::G_FADD:
+    case TargetOpcode::G_FSUB:
+    case TargetOpcode::G_FMUL:
+    case TargetOpcode::G_FDIV:
+    case TargetOpcode::G_FPOW:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
+
+// Represents a logical binary operation.
+class GLogicalBinOp : public GBinOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_AND:
+    case TargetOpcode::G_OR:
+    case TargetOpcode::G_XOR:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
 
 } // namespace llvm
 
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 8b15bdb0aca30b..102b49c48460b1 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -6390,8 +6390,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue.isZero() && FalseValue.isOne()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Inner, Cond);
+      auto Inner = B.buildNot(CondTy, Cond);
       B.buildZExtOrTrunc(Dest, Inner);
     };
     return true;
@@ -6401,8 +6400,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue.isZero() && FalseValue.isAllOnes()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Inner, Cond);
+      auto Inner = B.buildNot(CondTy, Cond);
       B.buildSExtOrTrunc(Dest, Inner);
     };
     return true;
@@ -6412,8 +6410,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue - 1 == FalseValue) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Inner, Cond);
+      auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
       B.buildAdd(Dest, Inner, False);
     };
     return true;
@@ -6423,8 +6420,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue + 1 == FalseValue) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildSExtOrTrunc(Inner, Cond);
+      auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
       B.buildAdd(Dest, Inner, False);
     };
     return true;
@@ -6434,8 +6430,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Inner, Cond);
+      auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
       // The shift amount must be scalar.
       LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
       auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
@@ -6447,8 +6442,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue.isAllOnes()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildSExtOrTrunc(Inner, Cond);
+      auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
       B.buildOr(Dest, Inner, False, Flags);
     };
     return true;
@@ -6458,10 +6452,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (FalseValue.isAllOnes()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Not = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Not, Cond);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildSExtOrTrunc(Inner, Not);
+      auto Not = B.buildNot(CondTy, Cond);
+      auto Inner = B.buildSExtOrTrunc(TrueTy, Not);
       B.buildOr(Dest, Inner, True, Flags);
     };
     return true;
@@ -6496,8 +6488,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
   if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Ext, Cond);
+      auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
       B.buildOr(DstReg, Ext, False, Flags);
     };
     return true;
@@ -6508,8 +6499,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
   if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Ext, Cond);
+      auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
       B.buildAnd(DstReg, Ext, True);
     };
     return true;
@@ -6520,11 +6510,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
       // First the not.
-      Register Inner = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Inner, Cond);
+      auto Inner = B.buildNot(CondTy, Cond);
       // Then an ext to match the destination register.
-      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Ext, Inner);
+      auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
       B.buildOr(DstReg, Ext, True, Flags);
     };
     return true;
@@ -6535,11 +6523,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
       // First the not.
-      Register Inner = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Inner, Cond);
+      auto Inner = B.buildNot(CondTy, Cond);
       // Then an ext to match the destination register.
-      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Ext, Inner);
+      auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
       B.buildAnd(DstReg, Ext, False);
     };
     return true;
@@ -6548,6 +6534,52 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
   return false;
 }
 
+bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
+                                           BuildFnTy &MatchInfo) {
+  Register DstReg = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register False = Select->getFalseReg();
+  Register True = Select->getTrueReg();
+  LLT DstTy = MRI.getType(DstReg);
+
+  GBinOp *LHS = getOpcodeDef<GBinOp>(True, MRI);
+  GBinOp *RHS = getOpcodeDef<GBinOp>(False, MRI);
+
+  // We need two binops of the same kind on the true/false registers.
+  if (!LHS || !RHS || LHS->getOpcode() != RHS->getOpcode())
+    return false;
+
+  // Note that there are no constraints on CondTy.
+  unsigned Flags = LHS->getFlags() & RHS->getFlags();
+  unsigned Opcode = LHS->getOpcode();
+
+  // Fold select(cond, binop(x, y), binop(z, y))
+  //  --> binop(select(cond, x, z), y)
+  if (LHS->getRHSReg() == RHS->getRHSReg()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg());
+      B.buildInstr(Opcode, {DstReg}, {Sel, LHS->getRHSReg()}, Flags);
+    };
+    return true;
+  }
+
+  // Fold select(cond, binop(x, y), binop(x, z))
+  //  --> binop(x, select(cond, y, z))
+  if (LHS->getLHSReg() == RHS->getLHSReg()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg());
+      B.buildInstr(Opcode, {DstReg}, {LHS->getLHSReg(), Sel}, Flags);
+    };
+    return true;
+  }
+
+  // FIXME: use isCommutable().
+
+  return false;
+}
+
 bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
   GSelect *Select = cast<GSelect>(&MI);
 
@@ -6557,5 +6589,8 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
   if (tryFoldBoolSelectToLogic(Select, MatchInfo))
     return true;
 
+  if (tryFoldSelectOfBinOps(Select, MatchInfo))
+    return true;
+
   return false;
 }
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
index be2de620fa456c..7644443f53dc68 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
@@ -544,3 +544,77 @@ body:             |
     %ext:_(s32) = G_ANYEXT %sel
     $w0 = COPY %ext(s32)
 ...
+---
+# select cond, and(x, y), and(z, y) --> and (select, x, z), y
+name:            select_cond_and_x_y_and_z_y_and_select_x_z_y
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_and_x_y_and_z_y_and_select_x_z_y
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s64) = COPY $x1
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:_(s64) = COPY $x2
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:_(s64) = COPY $x3
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %a:_(s8) = G_TRUNC [[COPY1]](s64)
+    ; CHECK-NEXT: %b:_(s8) = G_TRUNC [[COPY2]](s64)
+    ; CHECK-NEXT: %d:_(s8) = G_TRUNC [[COPY3]](s64)
+    ; CHECK-NEXT: [[SELECT:%[0-9]+]]:_(s8) = G_SELECT %c(s1), %a, %d
+    ; CHECK-NEXT: %sel:_(s8) = G_AND [[SELECT]], %b
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %3:_(s64) = COPY $x3
+    %4:_(s64) = COPY $x4
+    %c:_(s1) = G_TRUNC %0
+    %a:_(s8) = G_TRUNC %1
+    %b:_(s8) = G_TRUNC %2
+    %d:_(s8) = G_TRUNC %3
+    %e:_(s8) = G_TRUNC %4
+    %and1:_(s8) = G_AND %a, %b
+    %and2:_(s8) = G_AND %d, %b
+    %sel:_(s8) = G_SELECT %c, %and1, %and2
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
+# select cond, xor(x, y), xor(x, z) --> xor x, (select cond, y, z)
+name:            select_cond_xor_x_y_xor_x_z_xor_x__select_x_y
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_xor_x_y_xor_x_z_xor_x__select_x_y
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s64) = COPY $x1
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:_(s64) = COPY $x3
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:_(s64) = COPY $x4
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %a:_(s8) = G_TRUNC [[COPY1]](s64)
+    ; CHECK-NEXT: %d:_(s8) = G_TRUNC [[COPY2]](s64)
+    ; CHECK-NEXT: %e:_(s8) = G_TRUNC [[COPY3]](s64)
+    ; CHECK-NEXT: [[SELECT:%[0-9]+]]:_(s8) = G_SELECT %c(s1), %e, %d
+    ; CHECK-NEXT: %sel:_(s8) = G_XOR %a, [[SELECT]]
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %3:_(s64) = COPY $x3
+    %4:_(s64) = COPY $x4
+    %c:_(s1) = G_TRUNC %0
+    %a:_(s8) = G_TRUNC %1
+    %b:_(s8) = G_TRUNC %2
+    %d:_(s8) = G_TRUNC %3
+    %e:_(s8) = G_TRUNC %4
+    %xor1:_(s8) = G_XOR %a, %e
+    %xor2:_(s8) = G_XOR %a, %d
+    %sel:_(s8) = G_SELECT %c, %xor1, %xor2
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...

>From 4eab2545871644fb110933187488e5cee0fcc37b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Tue, 2 Jan 2024 18:00:45 +0100
Subject: [PATCH 2/4] [GlobalIsel] Combine select of binops

---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |   3 +
 .../CodeGen/GlobalISel/GenericMachineInstrs.h | 103 ++++++++++++++++++
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp |  91 +++++++++++-----
 .../AArch64/GlobalISel/combine-select.mir     |  74 +++++++++++++
 4 files changed, 243 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index dcc1a4580b14a2..f3b68623596c46 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -910,6 +910,9 @@ class CombinerHelper {
 
   bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);
 
+  /// Try to fold select(cc, binop(), binop()) -> binop(select(), X)
+  bool tryFoldSelectOfBinOps(GSelect *Select, BuildFnTy &MatchInfo);
+
   bool isOneOrOneSplat(Register Src, bool AllowUndefs);
   bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
   bool isConstantSplatVector(Register Src, int64_t SplatValue,
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
index 6ab1d4550c51ca..21d98d30356c93 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -558,6 +558,109 @@ class GVecReduce : public GenericMachineInstr {
   }
 };
 
+// Represents a binary operation, i.e., x = y op z.
+class GBinOp : public GenericMachineInstr {
+public:
+  Register getLHSReg() const { return getReg(1); }
+  Register getRHSReg() const { return getReg(2); }
+
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    // Integer.
+    case TargetOpcode::G_ADD:
+    case TargetOpcode::G_SUB:
+    case TargetOpcode::G_MUL:
+    case TargetOpcode::G_SDIV:
+    case TargetOpcode::G_UDIV:
+    case TargetOpcode::G_SREM:
+    case TargetOpcode::G_UREM:
+    case TargetOpcode::G_SMIN:
+    case TargetOpcode::G_SMAX:
+    case TargetOpcode::G_UMIN:
+    case TargetOpcode::G_UMAX:
+    // Floating point.
+    case TargetOpcode::G_FMINNUM:
+    case TargetOpcode::G_FMAXNUM:
+    case TargetOpcode::G_FMINNUM_IEEE:
+    case TargetOpcode::G_FMAXNUM_IEEE:
+    case TargetOpcode::G_FMINIMUM:
+    case TargetOpcode::G_FMAXIMUM:
+    case TargetOpcode::G_FADD:
+    case TargetOpcode::G_FSUB:
+    case TargetOpcode::G_FMUL:
+    case TargetOpcode::G_FDIV:
+    case TargetOpcode::G_FPOW:
+    // Logical.
+    case TargetOpcode::G_AND:
+    case TargetOpcode::G_OR:
+    case TargetOpcode::G_XOR:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
+
+// Represents an integer binary operation.
+class GIntBinOp : public GBinOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_ADD:
+    case TargetOpcode::G_SUB:
+    case TargetOpcode::G_MUL:
+    case TargetOpcode::G_SDIV:
+    case TargetOpcode::G_UDIV:
+    case TargetOpcode::G_SREM:
+    case TargetOpcode::G_UREM:
+    case TargetOpcode::G_SMIN:
+    case TargetOpcode::G_SMAX:
+    case TargetOpcode::G_UMIN:
+    case TargetOpcode::G_UMAX:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
+
+// Represents a floating point binary operation.
+class GFBinOp : public GBinOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_FMINNUM:
+    case TargetOpcode::G_FMAXNUM:
+    case TargetOpcode::G_FMINNUM_IEEE:
+    case TargetOpcode::G_FMAXNUM_IEEE:
+    case TargetOpcode::G_FMINIMUM:
+    case TargetOpcode::G_FMAXIMUM:
+    case TargetOpcode::G_FADD:
+    case TargetOpcode::G_FSUB:
+    case TargetOpcode::G_FMUL:
+    case TargetOpcode::G_FDIV:
+    case TargetOpcode::G_FPOW:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
+
+// Represents a logical binary operation.
+class GLogicalBinOp : public GBinOp {
+public:
+  static bool classof(const MachineInstr *MI) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::G_AND:
+    case TargetOpcode::G_OR:
+    case TargetOpcode::G_XOR:
+      return true;
+    default:
+      return false;
+    }
+  };
+};
 
 } // namespace llvm
 
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 8b15bdb0aca30b..102b49c48460b1 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -6390,8 +6390,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue.isZero() && FalseValue.isOne()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Inner, Cond);
+      auto Inner = B.buildNot(CondTy, Cond);
       B.buildZExtOrTrunc(Dest, Inner);
     };
     return true;
@@ -6401,8 +6400,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue.isZero() && FalseValue.isAllOnes()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Inner, Cond);
+      auto Inner = B.buildNot(CondTy, Cond);
       B.buildSExtOrTrunc(Dest, Inner);
     };
     return true;
@@ -6412,8 +6410,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue - 1 == FalseValue) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Inner, Cond);
+      auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
       B.buildAdd(Dest, Inner, False);
     };
     return true;
@@ -6423,8 +6420,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue + 1 == FalseValue) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildSExtOrTrunc(Inner, Cond);
+      auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
       B.buildAdd(Dest, Inner, False);
     };
     return true;
@@ -6434,8 +6430,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Inner, Cond);
+      auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
       // The shift amount must be scalar.
       LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
       auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
@@ -6447,8 +6442,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (TrueValue.isAllOnes()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildSExtOrTrunc(Inner, Cond);
+      auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
       B.buildOr(Dest, Inner, False, Flags);
     };
     return true;
@@ -6458,10 +6452,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
   if (FalseValue.isAllOnes()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Not = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Not, Cond);
-      Register Inner = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildSExtOrTrunc(Inner, Not);
+      auto Not = B.buildNot(CondTy, Cond);
+      auto Inner = B.buildSExtOrTrunc(TrueTy, Not);
       B.buildOr(Dest, Inner, True, Flags);
     };
     return true;
@@ -6496,8 +6488,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
   if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Ext, Cond);
+      auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
       B.buildOr(DstReg, Ext, False, Flags);
     };
     return true;
@@ -6508,8 +6499,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
   if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Ext, Cond);
+      auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
       B.buildAnd(DstReg, Ext, True);
     };
     return true;
@@ -6520,11 +6510,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
       // First the not.
-      Register Inner = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Inner, Cond);
+      auto Inner = B.buildNot(CondTy, Cond);
       // Then an ext to match the destination register.
-      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Ext, Inner);
+      auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
       B.buildOr(DstReg, Ext, True, Flags);
     };
     return true;
@@ -6535,11 +6523,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
       // First the not.
-      Register Inner = MRI.createGenericVirtualRegister(CondTy);
-      B.buildNot(Inner, Cond);
+      auto Inner = B.buildNot(CondTy, Cond);
       // Then an ext to match the destination register.
-      Register Ext = MRI.createGenericVirtualRegister(TrueTy);
-      B.buildZExtOrTrunc(Ext, Inner);
+      auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
       B.buildAnd(DstReg, Ext, False);
     };
     return true;
@@ -6548,6 +6534,52 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
   return false;
 }
 
+bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
+                                           BuildFnTy &MatchInfo) {
+  Register DstReg = Select->getReg(0);
+  Register Cond = Select->getCondReg();
+  Register False = Select->getFalseReg();
+  Register True = Select->getTrueReg();
+  LLT DstTy = MRI.getType(DstReg);
+
+  GBinOp *LHS = getOpcodeDef<GBinOp>(True, MRI);
+  GBinOp *RHS = getOpcodeDef<GBinOp>(False, MRI);
+
+  // We need two binops of the same kind on the true/false registers.
+  if (!LHS || !RHS || LHS->getOpcode() != RHS->getOpcode())
+    return false;
+
+  // Note that there are no constraints on CondTy.
+  unsigned Flags = LHS->getFlags() & RHS->getFlags();
+  unsigned Opcode = LHS->getOpcode();
+
+  // Fold select(cond, binop(x, y), binop(z, y))
+  //  --> binop(select(cond, x, z), y)
+  if (LHS->getRHSReg() == RHS->getRHSReg()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg());
+      B.buildInstr(Opcode, {DstReg}, {Sel, LHS->getRHSReg()}, Flags);
+    };
+    return true;
+  }
+
+  // Fold select(cond, binop(x, y), binop(x, z))
+  //  --> binop(x, select(cond, y, z))
+  if (LHS->getLHSReg() == RHS->getLHSReg()) {
+    MatchInfo = [=](MachineIRBuilder &B) {
+      B.setInstrAndDebugLoc(*Select);
+      auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg());
+      B.buildInstr(Opcode, {DstReg}, {LHS->getLHSReg(), Sel}, Flags);
+    };
+    return true;
+  }
+
+  // FIXME: use isCommutable().
+
+  return false;
+}
+
 bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
   GSelect *Select = cast<GSelect>(&MI);
 
@@ -6557,5 +6589,8 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
   if (tryFoldBoolSelectToLogic(Select, MatchInfo))
     return true;
 
+  if (tryFoldSelectOfBinOps(Select, MatchInfo))
+    return true;
+
   return false;
 }
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
index be2de620fa456c..7644443f53dc68 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
@@ -544,3 +544,77 @@ body:             |
     %ext:_(s32) = G_ANYEXT %sel
     $w0 = COPY %ext(s32)
 ...
+---
+# select cond, and(x, y), and(z, y) --> and (select, x, z), y
+name:            select_cond_and_x_y_and_z_y_and_select_x_z_y
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_and_x_y_and_z_y_and_select_x_z_y
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s64) = COPY $x1
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:_(s64) = COPY $x2
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:_(s64) = COPY $x3
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %a:_(s8) = G_TRUNC [[COPY1]](s64)
+    ; CHECK-NEXT: %b:_(s8) = G_TRUNC [[COPY2]](s64)
+    ; CHECK-NEXT: %d:_(s8) = G_TRUNC [[COPY3]](s64)
+    ; CHECK-NEXT: [[SELECT:%[0-9]+]]:_(s8) = G_SELECT %c(s1), %a, %d
+    ; CHECK-NEXT: %sel:_(s8) = G_AND [[SELECT]], %b
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %3:_(s64) = COPY $x3
+    %4:_(s64) = COPY $x4
+    %c:_(s1) = G_TRUNC %0
+    %a:_(s8) = G_TRUNC %1
+    %b:_(s8) = G_TRUNC %2
+    %d:_(s8) = G_TRUNC %3
+    %e:_(s8) = G_TRUNC %4
+    %and1:_(s8) = G_AND %a, %b
+    %and2:_(s8) = G_AND %d, %b
+    %sel:_(s8) = G_SELECT %c, %and1, %and2
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
+# select cond, xor(x, y), xor(x, z) --> xor x, (select cond, y, z)
+name:            select_cond_xor_x_y_xor_x_z_xor_x__select_x_y
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_xor_x_y_xor_x_z_xor_x__select_x_y
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s64) = COPY $x1
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:_(s64) = COPY $x3
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:_(s64) = COPY $x4
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %a:_(s8) = G_TRUNC [[COPY1]](s64)
+    ; CHECK-NEXT: %d:_(s8) = G_TRUNC [[COPY2]](s64)
+    ; CHECK-NEXT: %e:_(s8) = G_TRUNC [[COPY3]](s64)
+    ; CHECK-NEXT: [[SELECT:%[0-9]+]]:_(s8) = G_SELECT %c(s1), %e, %d
+    ; CHECK-NEXT: %sel:_(s8) = G_XOR %a, [[SELECT]]
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %3:_(s64) = COPY $x3
+    %4:_(s64) = COPY $x4
+    %c:_(s1) = G_TRUNC %0
+    %a:_(s8) = G_TRUNC %1
+    %b:_(s8) = G_TRUNC %2
+    %d:_(s8) = G_TRUNC %3
+    %e:_(s8) = G_TRUNC %4
+    %xor1:_(s8) = G_XOR %a, %e
+    %xor2:_(s8) = G_XOR %a, %d
+    %sel:_(s8) = G_SELECT %c, %xor1, %xor2
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...

>From 85ef024d0c6d28872516d5f8c7790ae2e86f8a84 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Wed, 3 Jan 2024 10:10:07 +0100
Subject: [PATCH 3/4] address comments

---
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp |  8 +-
 .../AArch64/GlobalISel/combine-select.mir     | 79 ++++++++++++++++++-
 2 files changed, 83 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 102b49c48460b1..5d8def4cca6668 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -6550,7 +6550,7 @@ bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
     return false;
 
   // Note that there are no constraints on CondTy.
-  unsigned Flags = LHS->getFlags() & RHS->getFlags();
+  unsigned Flags = (LHS->getFlags() & RHS->getFlags()) | Select->getFlags();
   unsigned Opcode = LHS->getOpcode();
 
   // Fold select(cond, binop(x, y), binop(z, y))
@@ -6558,7 +6558,8 @@ bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
   if (LHS->getRHSReg() == RHS->getRHSReg()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg());
+      auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg(),
+                               Select->getFlags());
       B.buildInstr(Opcode, {DstReg}, {Sel, LHS->getRHSReg()}, Flags);
     };
     return true;
@@ -6569,7 +6570,8 @@ bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
   if (LHS->getLHSReg() == RHS->getLHSReg()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg());
+      auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg(),
+                               Select->getFlags());
       B.buildInstr(Opcode, {DstReg}, {LHS->getLHSReg(), Sel}, Flags);
     };
     return true;
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
index 7644443f53dc68..c5a3490221b661 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-select.mir
@@ -545,7 +545,7 @@ body:             |
     $w0 = COPY %ext(s32)
 ...
 ---
-# select cond, and(x, y), and(z, y) --> and (select, x, z), y
+# select cond, and(x, y), and(z, y) --> and (select cond, x, z), y
 name:            select_cond_and_x_y_and_z_y_and_select_x_z_y
 body:             |
   bb.1:
@@ -618,3 +618,80 @@ body:             |
     %ext:_(s32) = G_ANYEXT %sel
     $w0 = COPY %ext(s32)
 ...
+---
+# negative test: select cond, and(x, y), or(z, a) is not folded because the two binop opcodes differ
+name:            select_cond_and_x_y_or_z_a_failed
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: select_cond_and_x_y_or_z_a_failed
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s64) = COPY $x1
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:_(s64) = COPY $x2
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:_(s64) = COPY $x3
+    ; CHECK-NEXT: [[COPY4:%[0-9]+]]:_(s64) = COPY $x4
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %a:_(s8) = G_TRUNC [[COPY1]](s64)
+    ; CHECK-NEXT: %b:_(s8) = G_TRUNC [[COPY2]](s64)
+    ; CHECK-NEXT: %d:_(s8) = G_TRUNC [[COPY3]](s64)
+    ; CHECK-NEXT: %e:_(s8) = G_TRUNC [[COPY4]](s64)
+    ; CHECK-NEXT: %and1:_(s8) = G_AND %a, %b
+    ; CHECK-NEXT: %or2:_(s8) = G_OR %e, %d
+    ; CHECK-NEXT: %sel:_(s8) = G_SELECT %c(s1), %and1, %or2
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %3:_(s64) = COPY $x3
+    %4:_(s64) = COPY $x4
+    %c:_(s1) = G_TRUNC %0
+    %a:_(s8) = G_TRUNC %1
+    %b:_(s8) = G_TRUNC %2
+    %d:_(s8) = G_TRUNC %3
+    %e:_(s8) = G_TRUNC %4
+    %and1:_(s8) = G_AND %a, %b
+    %or2:_(s8) = G_OR %e, %d
+    %sel:_(s8) = G_SELECT %c, %and1, %or2
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...
+---
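+# flags test: select cond, xor(x, y), xor(x, z) --> xor x, (select cond, y, z); the new select keeps the select's flags, the new xor gets (xor1 flags & xor2 flags) | select flags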
+name:            flags_select_cond_xor_x_y_xor_x_z_xor_x__select_cond_x_y
+body:             |
+  bb.1:
+    liveins: $x0, $x1, $x2
+    ; CHECK-LABEL: name: flags_select_cond_xor_x_y_xor_x_z_xor_x__select_cond_x_y
+    ; CHECK: liveins: $x0, $x1, $x2
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $x0
+    ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(s64) = COPY $x1
+    ; CHECK-NEXT: [[COPY2:%[0-9]+]]:_(s64) = COPY $x3
+    ; CHECK-NEXT: [[COPY3:%[0-9]+]]:_(s64) = COPY $x4
+    ; CHECK-NEXT: %c:_(s1) = G_TRUNC [[COPY]](s64)
+    ; CHECK-NEXT: %a:_(s8) = G_TRUNC [[COPY1]](s64)
+    ; CHECK-NEXT: %d:_(s8) = G_TRUNC [[COPY2]](s64)
+    ; CHECK-NEXT: %e:_(s8) = G_TRUNC [[COPY3]](s64)
+    ; CHECK-NEXT: [[SELECT:%[0-9]+]]:_(s8) = ninf exact G_SELECT %c(s1), %e, %d
+    ; CHECK-NEXT: %sel:_(s8) = ninf arcp exact G_XOR %a, [[SELECT]]
+    ; CHECK-NEXT: %ext:_(s32) = G_ANYEXT %sel(s8)
+    ; CHECK-NEXT: $w0 = COPY %ext(s32)
+    %0:_(s64) = COPY $x0
+    %1:_(s64) = COPY $x1
+    %2:_(s64) = COPY $x2
+    %3:_(s64) = COPY $x3
+    %4:_(s64) = COPY $x4
+    %c:_(s1) = G_TRUNC %0
+    %a:_(s8) = G_TRUNC %1
+    %b:_(s8) = G_TRUNC %2
+    %d:_(s8) = G_TRUNC %3
+    %e:_(s8) = G_TRUNC %4
+    %xor1:_(s8) = nsz arcp nsw G_XOR %a, %e
+    %xor2:_(s8) = nnan arcp nuw G_XOR %a, %d
+    %sel:_(s8) = ninf exact G_SELECT %c, %xor1, %xor2
+    %ext:_(s32) = G_ANYEXT %sel
+    $w0 = COPY %ext(s32)
+...

>From b6ee5a53dd452f3fdb9624f1ddffdad59c8ffeac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Wed, 3 Jan 2024 15:56:04 +0100
Subject: [PATCH 4/4] fix select flags

---
 llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 25c47bbc8df0dd..5d8def4cca6668 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -6558,7 +6558,8 @@ bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
   if (LHS->getRHSReg() == RHS->getRHSReg()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg());
+      auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg(),
+                               Select->getFlags());
       B.buildInstr(Opcode, {DstReg}, {Sel, LHS->getRHSReg()}, Flags);
     };
     return true;
@@ -6569,7 +6570,8 @@ bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
   if (LHS->getLHSReg() == RHS->getLHSReg()) {
     MatchInfo = [=](MachineIRBuilder &B) {
       B.setInstrAndDebugLoc(*Select);
-      auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg());
+      auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg(),
+                               Select->getFlags());
       B.buildInstr(Opcode, {DstReg}, {LHS->getLHSReg(), Sel}, Flags);
     };
     return true;



More information about the cfe-commits mailing list