[llvm] [DAG] Fold (mul (sext (add_nsw x, c1)), c2) -> (add (mul (sext x), c2), c1*c2) (PR #69667)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 19 19:14:23 PDT 2023
https://github.com/LiqinWeng updated https://github.com/llvm/llvm-project/pull/69667
>From 3ad9bffa0eca673459579e006b7aae44736cb61e Mon Sep 17 00:00:00 2001
From: LiqinWeng <liqin.weng at spacemit.com>
Date: Fri, 20 Oct 2023 10:13:43 +0800
Subject: [PATCH] [DAG] Fold (mul (sext (add_nsw x, c1)), c2) -> (add (mul
(sext x), c2), c1 * c2)
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 25 +++++++++++-
llvm/test/CodeGen/RISCV/mul-and.ll | 38 +++++++++++++++++++
llvm/test/CodeGen/X86/addr-mode-matcher-2.ll | 4 +-
3 files changed, 64 insertions(+), 3 deletions(-)
create mode 100644 llvm/test/CodeGen/RISCV/mul-and.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2dfdddad3cc389f..3219cbc4e510e27 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4517,6 +4517,27 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
DAG.getNode(ISD::MUL, SDLoc(N0), VT, N0.getOperand(0), N1),
DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1));
+ // fold (mul (sext (add_nsw x, c1)), c2) -> (add (mul (sext x), c2), c1*c2)
+ if (N0.getOpcode() == ISD::SIGN_EXTEND &&
+ N0.getOperand(0).getOpcode() == ISD::ADD &&
+ N0.getOperand(0)->getFlags().hasNoSignedWrap() &&
+ DAG.isConstantIntBuildVectorOrConstantInt(N1) &&
+ DAG.isConstantIntBuildVectorOrConstantInt(
+ N0.getOperand(0).getOperand(1)) &&
+ isMulAddWithConstProfitable(N, N0.getOperand(0), N1)) {
+ SDValue Add = N0.getOperand(0);
+ SDLoc DL(N0);
+ if (SDValue ExtC = DAG.FoldConstantArithmetic(N0.getOpcode(), DL, VT,
+ {Add.getOperand(1)})) {
+ if (SDValue MulC =
+ DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {ExtC, N1})) {
+ SDValue ExtX = DAG.getNode(N0.getOpcode(), DL, VT, Add.getOperand(0));
+ SDValue MulX = DAG.getNode(ISD::MUL, DL, VT, ExtX, N1);
+ return DAG.getNode(ISD::ADD, DL, VT, MulX, MulC);
+ }
+ }
+ }
+
// Fold (mul (vscale * C0), C1) to (vscale * (C0 * C1)).
ConstantSDNode *NC1 = isConstOrConstSplat(N1);
if (N0.getOpcode() == ISD::VSCALE && NC1) {
@@ -19682,7 +19703,9 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, SDValue AddNode,
if (Use->getOpcode() == ISD::MUL) { // We have another multiply use.
SDNode *OtherOp;
- SDNode *MulVar = AddNode.getOperand(0).getNode();
+ SDNode *MulVar = AddNode.getOperand(0).getOpcode() == ISD::TRUNCATE
+ ? AddNode.getOperand(0).getOperand(0).getNode()
+ : AddNode.getOperand(0).getNode();
// OtherOp is what we're multiplying against the constant.
if (Use->getOperand(0) == ConstNode)
diff --git a/llvm/test/CodeGen/RISCV/mul-and.ll b/llvm/test/CodeGen/RISCV/mul-and.ll
new file mode 100644
index 000000000000000..0c8418def4a1827
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/mul-and.ll
@@ -0,0 +1,38 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs -mattr=+m < %s \
+; RUN: | FileCheck -check-prefix=RV64 %s
+; Fold (mul (sext (add nsw x, c1)), c2) -> (add (mul (sext x), c2), sext(c1)*c2).
+
+define void @test(ptr nocapture noundef %array2, i32 noundef signext %a, i32 noundef signext %b) {
+; RV64-LABEL: test:
+; RV64: # %bb.0: # %entry
+; RV64-NEXT: slli a2, a1, 2
+; RV64-NEXT: li a3, 200
+; RV64-NEXT: mul a3, a1, a3
+; RV64-NEXT: add a0, a0, a3
+; RV64-NEXT: add a2, a0, a2
+; RV64-NEXT: lw a3, 1016(a2)
+; RV64-NEXT: addiw a1, a1, 5
+; RV64-NEXT: addi a3, a3, 1
+; RV64-NEXT: sw a3, 1016(a2)
+; RV64-NEXT: slli a2, a1, 2
+; RV64-NEXT: lui a3, 1
+; RV64-NEXT: add a2, a2, a3
+; RV64-NEXT: add a0, a0, a2
+; RV64-NEXT: sw a1, 904(a0)
+; RV64-NEXT: ret
+entry:
+ %add = add nsw i32 %a, 5 ; x = %a, c1 = 5
+ %idxprom = sext i32 %add to i64 ; outer GEP index, scaled by c2 = 200 (sizeof [50 x i32])
+ %sub = add nsw i32 %a, 4
+ %idxprom1 = sext i32 %sub to i64
+ %arrayidx2 = getelementptr inbounds [50 x i32], ptr %array2, i64 %idxprom, i64 %idxprom1
+ %0 = load i32, ptr %arrayidx2, align 4 ; c1*c2 = 1000 plus 4*4 = the 1016 displacement
+ %add3 = add nsw i32 %0, 1
+ store i32 %add3, ptr %arrayidx2, align 4
+ %1 = sext i32 %a to i64
+ %2 = getelementptr [50 x i32], ptr %array2, i64 %1
+ %arrayidx8 = getelementptr [50 x i32], ptr %2, i64 25, i64 %idxprom
+ store i32 %add, ptr %arrayidx8, align 4
+ ret void
+}
diff --git a/llvm/test/CodeGen/X86/addr-mode-matcher-2.ll b/llvm/test/CodeGen/X86/addr-mode-matcher-2.ll
index daba729bf040f26..6ab3bd9923c8a1f 100644
--- a/llvm/test/CodeGen/X86/addr-mode-matcher-2.ll
+++ b/llvm/test/CodeGen/X86/addr-mode-matcher-2.ll
@@ -52,8 +52,8 @@ define void @foo_sext_nsw(i1 zeroext, i32) nounwind {
; X64-NEXT: .p2align 4, 0x90
; X64-NEXT: .LBB0_2: # =>This Inner Loop Header: Depth=1
; X64-NEXT: cltq
-; X64-NEXT: shlq $2, %rax
-; X64-NEXT: leaq 20(%rax,%rax,4), %rdi
+; X64-NEXT: leaq (%rax,%rax,4), %rax
+; X64-NEXT: leaq 20(,%rax,4), %rdi
; X64-NEXT: callq bar@PLT
; X64-NEXT: jmp .LBB0_2
br i1 %0, label %9, label %3
More information about the llvm-commits mailing list