[llvm] [DAG] Fold (mul (sext (add_nsw x, c1)), c2) -> (add (mul (sext x), c2), c1*c2) (PR #69667)

via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 19 19:14:23 PDT 2023


https://github.com/LiqinWeng updated https://github.com/llvm/llvm-project/pull/69667

>From 3ad9bffa0eca673459579e006b7aae44736cb61e Mon Sep 17 00:00:00 2001
From: LiqinWeng <liqin.weng at spacemit.com>
Date: Fri, 20 Oct 2023 10:13:43 +0800
Subject: [PATCH] [DAG] Fold (mul (sext (add_nsw x, c1)), c2) -> (add (mul
 (sext x), c2), c1 * c2)

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 25 +++++++++++-
 llvm/test/CodeGen/RISCV/mul-and.ll            | 38 +++++++++++++++++++
 llvm/test/CodeGen/X86/addr-mode-matcher-2.ll  |  4 +-
 3 files changed, 64 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/mul-and.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2dfdddad3cc389f..3219cbc4e510e27 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4517,6 +4517,27 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
         DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
         DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
 
+  // fold (mul (sext (add_nsw x, c1)), c2) -> (add (mul (sext x), c2), c1*c2)
+  if (N0.getOpcode() == ISD::SIGN_EXTEND &&
+      N0.getOperand(0).getOpcode() == ISD::ADD &&
+      N0.getOperand(0)->getFlags().hasNoSignedWrap() &&
+      DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
+      DAG.isConstantIntBuildVectorOrConstantInt(
+          N0.getOperand(0).getOperand(1)) &&
+      isMulAddWithConstProfitable(N, N0.getOperand(0), N1)) {
+    SDValue Add = N0.getOperand(0);
+    SDLoc DL(N0);
+    if (SDValue ExtC = DAG.FoldConstantArithmetic(N0.getOpcode(), DL, VT,
+                                                  {Add.getOperand(1)})) {
+      if (SDValue MulC =
+              DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {ExtC, N1})) {
+        SDValue ExtX = DAG.getNode(N0.getOpcode(), DL, VT, Add.getOperand(0));
+        SDValue MulX = DAG.getNode(ISD::MUL, DL, VT, ExtX, N1);
+        return DAG.getNode(ISD::ADD, DL, VT, MulX, MulC);
+      }
+    }
+  }
+
   // Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
   ConstantSDNode *NC1 = isConstOrConstSplat(N1);
   if (N0.getOpcode() == ISD::VSCALE && NC1) {
@@ -19682,7 +19703,9 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
 
     if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
       SDNode *OtherOp;
-      SDNode *MulVar = AddNode.getOperand(0).getNode();
+      SDNode *MulVar = AddNode.getOperand(0).getOpcode() == ISD::TRUNCATE
+                           ? AddNode.getOperand(0).getOperand(0).getNode()
+                           : AddNode.getOperand(0).getNode();
 
       // OtherOp is what we're multiplying against the constant.
       if (Use->getOperand(0) == ConstNode)
diff --git a/llvm/test/CodeGen/RISCV/mul-and.ll b/llvm/test/CodeGen/RISCV/mul-and.ll
new file mode 100644
index 000000000000000..0c8418def4a1827
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/mul-and.ll
@@ -0,0 +1,38 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs -mattr=+m < %s \
+; RUN:   | FileCheck -check-prefix=RV64 %s
+; Check that (mul (sext (add nsw x, c1)), c2) becomes (add (mul (sext x), c2), c1*c2) so the constant folds into the memory offset.
+
+define void @test(ptr nocapture noundef %array2, i32 noundef signext %a, i32 noundef signext %b) {
+; RV64-LABEL: test:
+; RV64:       # %bb.0: # %entry
+; RV64-NEXT:    slli a2, a1, 2
+; RV64-NEXT:    li a3, 200
+; RV64-NEXT:    mul a3, a1, a3
+; RV64-NEXT:    add a0, a0, a3
+; RV64-NEXT:    add a2, a0, a2
+; RV64-NEXT:    lw a3, 1016(a2)
+; RV64-NEXT:    addiw a1, a1, 5
+; RV64-NEXT:    addi a3, a3, 1
+; RV64-NEXT:    sw a3, 1016(a2)
+; RV64-NEXT:    slli a2, a1, 2
+; RV64-NEXT:    lui a3, 1
+; RV64-NEXT:    add a2, a2, a3
+; RV64-NEXT:    add a0, a0, a2
+; RV64-NEXT:    sw a1, 904(a0)
+; RV64-NEXT:    ret
+entry:
+  %add = add nsw i32 %a, 5
+  %idxprom = sext i32 %add to i64
+  %sub = add nsw i32 %a, 4
+  %idxprom1 = sext i32 %sub to i64
+  %arrayidx2 = getelementptr inbounds [50 x i32], ptr %array2, i64 %idxprom, i64 %idxprom1
+  %0 = load i32, ptr %arrayidx2, align 4
+  %add3 = add nsw i32 %0, 1
+  store i32 %add3, ptr %arrayidx2, align 4
+  %1 = sext i32 %a to i64
+  %2 = getelementptr [50 x i32], ptr %array2, i64 %1
+  %arrayidx8 = getelementptr [50 x i32], ptr %2, i64 25, i64 %idxprom
+  store i32 %add, ptr %arrayidx8, align 4
+  ret void
+}
diff --git a/llvm/test/CodeGen/X86/addr-mode-matcher-2.ll b/llvm/test/CodeGen/X86/addr-mode-matcher-2.ll
index daba729bf040f26..6ab3bd9923c8a1f 100644
--- a/llvm/test/CodeGen/X86/addr-mode-matcher-2.ll
+++ b/llvm/test/CodeGen/X86/addr-mode-matcher-2.ll
@@ -52,8 +52,8 @@ define void @foo_sext_nsw(i1 zeroext, i32) nounwind {
 ; X64-NEXT:    .p2align 4, 0x90
 ; X64-NEXT:  .LBB0_2: # =>This Inner Loop Header: Depth=1
 ; X64-NEXT:    cltq
-; X64-NEXT:    shlq $2, %rax
-; X64-NEXT:    leaq 20(%rax,%rax,4), %rdi
+; X64-NEXT:    leaq (%rax,%rax,4), %rax
+; X64-NEXT:    leaq 20(,%rax,4), %rdi
 ; X64-NEXT:    callq bar@PLT
 ; X64-NEXT:    jmp .LBB0_2
   br i1 %0, label %9, label %3



More information about the llvm-commits mailing list