[llvm] [DAG] Remove OneUse restriction when folding (shl (add x, c1), c2) (PR #101294)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Sun Aug 4 23:16:35 PDT 2024


================
@@ -10070,17 +10070,33 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
   // Variant of version done on multiply, except mul by a power of 2 is turned
   // into a shift.
   if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
-      N0->hasOneUse() && TLI.isDesirableToCommuteWithShift(N, Level)) {
-    SDValue N01 = N0.getOperand(1);
-    if (SDValue Shl1 =
-            DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) {
-      SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
-      AddToWorklist(Shl0.getNode());
-      SDNodeFlags Flags;
-      // Preserve the disjoint flag for Or.
-      if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
-        Flags.setDisjoint(true);
-      return DAG.getNode(N0.getOpcode(), DL, VT, Shl0, Shl1, Flags);
+      TLI.isDesirableToCommuteWithShift(N, Level)) {
+    // LD/ST will optimize constant Offset extraction, so when AddNode
+    // is used by LD/ST, it can still complete the folding optimization
+    // operation performed above.
+    bool canOptAwlays = false;
+    if (!N0.hasOneUse() && N0.getOpcode() == ISD::ADD) {
+      for (SDNode *Use : N0->uses()) {
+        if (!isa<StoreSDNode>(Use) && !isa<LoadSDNode>(Use) && Use != N) {
+          canOptAwlays = false;
+          break;
+        }
+        canOptAwlays = true;
+      }
+    }
+    if (N0.hasOneUse() || canOptAwlays) {
+      SDValue N01 = N0.getOperand(1);
+      if (SDValue Shl1 =
+              DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) {
+        SDValue Shl0 =
+            DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1);
+        AddToWorklist(Shl0.getNode());
+        SDNodeFlags Flags;
+        // Preserve the disjoint flag for Or.
+        if (N0.getOpcode() == ISD::OR && N0->getFlags().hasDisjoint())
+          Flags.setDisjoint(true);
----------------
arsenm wrote:

Just pass through the flags from N0->getFlags then? 

https://github.com/llvm/llvm-project/pull/101294


More information about the llvm-commits mailing list