[llvm] [DAGCombiner] Freeze maybe poison operands when folding select to logic (PR #84924)

Björn Pettersson via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 12 07:42:29 PDT 2024


https://github.com/bjope created https://github.com/llvm/llvm-project/pull/84924

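Work in progress; intended to fix https://github.com/llvm/llvm-project/issues/84653. When a select is folded into AND/OR logic, the operand the select would not have chosen now flows into the result, and AND/OR propagate poison, so that operand is frozen unless it is known not to be poison.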

From 4f7c20f55b694b6f84be80d4c349347377eaffa1 Mon Sep 17 00:00:00 2001
From: Bjorn Pettersson <bjorn.a.pettersson at ericsson.com>
Date: Tue, 12 Mar 2024 15:24:33 +0100
Subject: [PATCH] [DAGCombiner] Freeze maybe poison operands when folding
 select to logic

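Work in progress; intended to fix https://github.com/llvm/llvm-project/issues/84653. When a select is folded into AND/OR logic, the operand the select would not have chosen now flows into the result, and AND/OR propagate poison, so that operand is frozen unless it is known not to be poison.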
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 30 +++++++++++++------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 5476ef87971436..46675f94642cc9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11344,28 +11344,34 @@ static SDValue foldBoolSelectToLogic(SDNode *N, SelectionDAG &DAG) {
   if (VT != Cond.getValueType() || VT.getScalarSizeInBits() != 1)
     return SDValue();
 
-  // select Cond, Cond, F --> or Cond, F
-  // select Cond, 1, F    --> or Cond, F
+  auto FreezeIfNeeded = [&](SDValue V) {
+    if (!DAG.isGuaranteedNotToBePoison(V))
+      return DAG.getFreeze(V);
+    return V;
+  };
+
+  // select Cond, Cond, F --> or Cond, freeze(F)
+  // select Cond, 1, F    --> or Cond, freeze(F)
   if (Cond == T || isOneOrOneSplat(T, /* AllowUndefs */ true))
-    return matcher.getNode(ISD::OR, SDLoc(N), VT, Cond, F);
+    return matcher.getNode(ISD::OR, SDLoc(N), VT, Cond, FreezeIfNeeded(F));
 
   // select Cond, T, Cond --> and Cond, T
   // select Cond, T, 0    --> and Cond, T
   if (Cond == F || isNullOrNullSplat(F, /* AllowUndefs */ true))
-    return matcher.getNode(ISD::AND, SDLoc(N), VT, Cond, T);
+    return matcher.getNode(ISD::AND, SDLoc(N), VT, Cond, FreezeIfNeeded(T));
 
   // select Cond, T, 1 --> or (not Cond), T
   if (isOneOrOneSplat(F, /* AllowUndefs */ true)) {
     SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
                                       DAG.getAllOnesConstant(SDLoc(N), VT));
-    return matcher.getNode(ISD::OR, SDLoc(N), VT, NotCond, T);
+    return matcher.getNode(ISD::OR, SDLoc(N), VT, NotCond, FreezeIfNeeded(T));
   }
 
   // select Cond, 0, F --> and (not Cond), F
   if (isNullOrNullSplat(T, /* AllowUndefs */ true)) {
     SDValue NotCond = matcher.getNode(ISD::XOR, SDLoc(N), VT, Cond,
                                       DAG.getAllOnesConstant(SDLoc(N), VT));
-    return matcher.getNode(ISD::AND, SDLoc(N), VT, NotCond, F);
+    return matcher.getNode(ISD::AND, SDLoc(N), VT, NotCond, FreezeIfNeeded(F));
   }
 
   return SDValue();
@@ -11394,12 +11400,18 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
   else
     return SDValue();
 
+  auto FreezeIfNeeded = [&](SDValue V) {
+    if (!DAG.isGuaranteedNotToBePoison(V))
+      return DAG.getFreeze(V);
+    return V;
+  };
+
   // (Cond0 s< 0) ? N1 : 0 --> (Cond0 s>> BW-1) & N1
   if (isNullOrNullSplat(N2)) {
     SDLoc DL(N);
     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
-    return DAG.getNode(ISD::AND, DL, VT, Sra, N1);
+    return DAG.getNode(ISD::AND, DL, VT, Sra, FreezeIfNeeded(N1));
   }
 
   // (Cond0 s< 0) ? -1 : N2 --> (Cond0 s>> BW-1) | N2
@@ -11407,7 +11419,7 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
     SDLoc DL(N);
     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
-    return DAG.getNode(ISD::OR, DL, VT, Sra, N2);
+    return DAG.getNode(ISD::OR, DL, VT, Sra, FreezeIfNeeded(N2));
   }
 
   // If we have to invert the sign bit mask, only do that transform if the
@@ -11419,7 +11431,7 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
     SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
     SDValue Sra = DAG.getNode(ISD::SRA, DL, VT, Cond0, ShiftAmt);
     SDValue Not = DAG.getNOT(DL, Sra, VT);
-    return DAG.getNode(ISD::AND, DL, VT, Not, N2);
+    return DAG.getNode(ISD::AND, DL, VT, Not, FreezeIfNeeded(N2));
   }
 
   // TODO: There's another pattern in this family, but it may require



More information about the llvm-commits mailing list