[llvm] [LLVM][DAGCombiner] fold (shl (X * vscale(C0)), C1) -> (X * vscale(C0 << C1)). (PR #150651)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 25 09:54:07 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Paul Walker (paulwalker-arm)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/150651.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+9) 
- (modified) llvm/test/CodeGen/AArch64/sve-vscale-combine.ll (+68-37) 
- (modified) llvm/test/Transforms/LoopStrengthReduce/AArch64/vscale-fixups.ll (+4-4) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d3df43473013e..8dc0ab74d25c6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10615,6 +10615,15 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
     return DAG.getVScale(DL, VT, C0 << C1);
   }
 
+  SDValue X;
+  APInt VS0;
+
+  // fold (shl (X * vscale(VS0)), C1) -> (X * vscale(VS0 << C1))
+  if (N1C && sd_match(N0, m_Mul(m_Value(X), m_VScale(m_ConstInt(VS0))))) {
+    SDValue VScale = DAG.getVScale(DL, VT, VS0 << N1C->getAPIntValue());
+    return DAG.getNode(ISD::MUL, DL, VT, X, VScale);
+  }
+
   // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
   APInt ShlVal;
   if (N0.getOpcode() == ISD::STEP_VECTOR &&
diff --git a/llvm/test/CodeGen/AArch64/sve-vscale-combine.ll b/llvm/test/CodeGen/AArch64/sve-vscale-combine.ll
index 9306c208c3b55..7dcd56c3e7dd7 100644
--- a/llvm/test/CodeGen/AArch64/sve-vscale-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vscale-combine.ll
@@ -1,14 +1,14 @@
-; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve --asm-verbose=false < %s |FileCheck %s
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve < %s | FileCheck %s
 
-declare i32 @llvm.vscale.i32()
-declare i64 @llvm.vscale.i64()
+target triple = "aarch64-unknown-linux-gnu"
 
 ; Fold (add (vscale * C0), (vscale * C1)) to (vscale * (C0 + C1)).
 define i64 @combine_add_vscale_i64() nounwind {
 ; CHECK-LABEL: combine_add_vscale_i64:
-; CHECK-NOT:   add
-; CHECK-NEXT:  cntd  x0
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntd x0
+; CHECK-NEXT:    ret
  %vscale = call i64 @llvm.vscale.i64()
  %add = add i64 %vscale, %vscale
  ret i64 %add
@@ -16,9 +16,10 @@ define i64 @combine_add_vscale_i64() nounwind {
 
 define i32 @combine_add_vscale_i32() nounwind {
 ; CHECK-LABEL: combine_add_vscale_i32:
-; CHECK-NOT:   add
-; CHECK-NEXT:  cntd  x0
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntd x0
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
  %vscale = call i32 @llvm.vscale.i32()
  %add = add i32 %vscale, %vscale
  ret i32 %add
@@ -28,9 +29,9 @@ define i32 @combine_add_vscale_i32() nounwind {
 ; In this test, C0 = 1, C1 = 32.
 define i64 @combine_mul_vscale_i64() nounwind {
 ; CHECK-LABEL: combine_mul_vscale_i64:
-; CHECK-NOT:   mul
-; CHECK-NEXT:  rdvl  x0, #2
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x0, #2
+; CHECK-NEXT:    ret
  %vscale = call i64 @llvm.vscale.i64()
  %mul = mul i64 %vscale, 32
  ret i64 %mul
@@ -38,9 +39,10 @@ define i64 @combine_mul_vscale_i64() nounwind {
 
 define i32 @combine_mul_vscale_i32() nounwind {
 ; CHECK-LABEL: combine_mul_vscale_i32:
-; CHECK-NOT:   mul
-; CHECK-NEXT:  rdvl  x0, #3
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x0, #3
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
  %vscale = call i32 @llvm.vscale.i32()
  %mul = mul i32 %vscale, 48
  ret i32 %mul
@@ -49,11 +51,11 @@ define i32 @combine_mul_vscale_i32() nounwind {
 ; Canonicalize (sub X, (vscale * C)) to (add X,  (vscale * -C))
 define i64 @combine_sub_vscale_i64(i64 %in) nounwind {
 ; CHECK-LABEL: combine_sub_vscale_i64:
-; CHECK-NOT:   sub
-; CHECK-NEXT:  rdvl  x8, #-1
-; CHECK-NEXT:  asr   x8, x8, #4
-; CHECK-NEXT:  add   x0, x0, x8
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #-1
+; CHECK-NEXT:    asr x8, x8, #4
+; CHECK-NEXT:    add x0, x0, x8
+; CHECK-NEXT:    ret
  %vscale = call i64 @llvm.vscale.i64()
  %sub = sub i64 %in,  %vscale
  ret i64 %sub
@@ -61,11 +63,11 @@ define i64 @combine_sub_vscale_i64(i64 %in) nounwind {
 
 define i32 @combine_sub_vscale_i32(i32 %in) nounwind {
 ; CHECK-LABEL: combine_sub_vscale_i32:
-; CHECK-NOT:   sub
-; CHECK-NEXT:  rdvl  x8, #-1
-; CHECK-NEXT:  asr   x8, x8, #4
-; CHECK-NEXT:  add   w0, w0, w8
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #-1
+; CHECK-NEXT:    asr x8, x8, #4
+; CHECK-NEXT:    add w0, w0, w8
+; CHECK-NEXT:    ret
  %vscale = call i32 @llvm.vscale.i32()
  %sub = sub i32 %in, %vscale
  ret i32 %sub
@@ -75,12 +77,13 @@ define i32 @combine_sub_vscale_i32(i32 %in) nounwind {
 ; (sub X, (vscale * C)) to (add X,  (vscale * -C))
 define i64 @multiple_uses_sub_vscale_i64(i64 %x, i64 %y) nounwind {
 ; CHECK-LABEL: multiple_uses_sub_vscale_i64:
-; CHECK-NEXT:  rdvl	x8, #1
-; CHECK-NEXT:  lsr	x8, x8, #4
-; CHECK-NEXT:  sub	x9, x0, x8
-; CHECK-NEXT:  add	x8, x1, x8
-; CHECK-NEXT:  mul	x0, x9, x8
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    sub x9, x0, x8
+; CHECK-NEXT:    add x8, x1, x8
+; CHECK-NEXT:    mul x0, x9, x8
+; CHECK-NEXT:    ret
  %vscale = call i64 @llvm.vscale.i64()
  %sub = sub i64 %x, %vscale
  %add = add i64 %y, %vscale
@@ -95,9 +98,9 @@ define i64 @multiple_uses_sub_vscale_i64(i64 %x, i64 %y) nounwind {
 ; Hence, the immediate for RDVL is #1.
 define i64 @combine_shl_vscale_i64() nounwind {
 ; CHECK-LABEL: combine_shl_vscale_i64:
-; CHECK-NOT:   shl
-; CHECK-NEXT:  rdvl  x0, #1
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x0, #1
+; CHECK-NEXT:    ret
  %vscale = call i64 @llvm.vscale.i64()
  %shl = shl i64 %vscale, 4
  ret i64 %shl
@@ -105,10 +108,38 @@ define i64 @combine_shl_vscale_i64() nounwind {
 
 define i32 @combine_shl_vscale_i32() nounwind {
 ; CHECK-LABEL: combine_shl_vscale_i32:
-; CHECK-NOT:   shl
-; CHECK-NEXT:  rdvl  x0, #1
-; CHECK-NEXT:  ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x0, #1
+; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    ret
  %vscale = call i32 @llvm.vscale.i32()
  %shl = shl i32 %vscale, 4
  ret i32 %shl
 }
+
+define i64 @combine_shl_mul_vscale(i64 %a) nounwind {
+; CHECK-LABEL: combine_shl_mul_vscale:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cnth x8
+; CHECK-NEXT:    mul x0, x0, x8
+; CHECK-NEXT:    ret
+  %vscale = tail call i64 @llvm.vscale.i64()
+  %mul = mul i64 %a, %vscale
+  %shl = shl i64 %mul, 3
+  ret i64 %shl
+}
+
+define i64 @combine_shl_mul_vscale_commuted(i64 %a) nounwind {
+; CHECK-LABEL: combine_shl_mul_vscale_commuted:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cnth x8
+; CHECK-NEXT:    mul x0, x0, x8
+; CHECK-NEXT:    ret
+  %vscale = tail call i64 @llvm.vscale.i64()
+  %mul = mul i64 %vscale, %a
+  %shl = shl i64 %mul, 3
+  ret i64 %shl
+}
+
+declare i32 @llvm.vscale.i32()
+declare i64 @llvm.vscale.i64()
diff --git a/llvm/test/Transforms/LoopStrengthReduce/AArch64/vscale-fixups.ll b/llvm/test/Transforms/LoopStrengthReduce/AArch64/vscale-fixups.ll
index aa954aeb0ad07..9003072f5fcdf 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/AArch64/vscale-fixups.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/AArch64/vscale-fixups.ll
@@ -383,14 +383,14 @@ define void @vscale_squared_offset(ptr %alloc) #0 {
 ; COMMON-LABEL: vscale_squared_offset:
 ; COMMON:       // %bb.0: // %entry
 ; COMMON-NEXT:    rdvl x9, #1
+; COMMON-NEXT:    rdvl x10, #4
 ; COMMON-NEXT:    fmov z0.s, #4.00000000
-; COMMON-NEXT:    mov x8, xzr
 ; COMMON-NEXT:    lsr x9, x9, #4
 ; COMMON-NEXT:    fmov z1.s, #8.00000000
-; COMMON-NEXT:    cntw x10
+; COMMON-NEXT:    mov x8, xzr
 ; COMMON-NEXT:    ptrue p0.s, vl1
-; COMMON-NEXT:    umull x9, w9, w9
-; COMMON-NEXT:    lsl x9, x9, #6
+; COMMON-NEXT:    umull x9, w9, w10
+; COMMON-NEXT:    cntw x10
 ; COMMON-NEXT:    cmp x8, x10
 ; COMMON-NEXT:    b.ge .LBB6_2
 ; COMMON-NEXT:  .LBB6_1: // %for.body

``````````

</details>


https://github.com/llvm/llvm-project/pull/150651


More information about the llvm-commits mailing list