[llvm] [InstCombine] Transform `vector.reduce.add (splat %0, 4)` into `shl i32 %0, 2` (PR #161020)

Gábor Spaits via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 27 13:30:09 PDT 2025


https://github.com/spaits created https://github.com/llvm/llvm-project/pull/161020

Fixes #160066

When a vector whose elements are all the same value (a splat created with `insertelement` and `shufflevector`) is summed with `vector.reduce.add`, and the number of elements in the resulting vector is a power of two, the reduction is a multiplication by a power of two, which can be replaced with a left shift.

>From fa77c2c10596acec00ee517297dc92d2bee09360 Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gaborspaits1 at gmail.com>
Date: Sat, 27 Sep 2025 22:24:16 +0200
Subject: [PATCH] [InstCombine] Transform `vector.reduce.add (splat %0, 4)`
 into `shl i32 %0, 2`

Fixes #160066

When a vector whose elements are all the same value (a splat created with
`insertelement` and `shufflevector`) is summed with `vector.reduce.add`, and
the number of elements in the resulting vector is a power of two, the
reduction is a multiplication by a power of two, which can be replaced with
a left shift.
---
 .../InstCombine/InstCombineCalls.cpp          | 33 +++++++++
 .../InstCombine/vector-reductions.ll          | 70 +++++++++++++++++++
 2 files changed, 103 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 6ad493772d170..49f6b86fa8f30 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3761,6 +3761,39 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
             return replaceInstUsesWith(CI, Res);
           }
       }
+
+      // Handle the case where a value is multiplied by a power of two.
+      // For example:
+      // %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+      // %3 = shufflevector <4 x i32> %2, poison, <4 x i32> zeroinitializer
+      // %4 = tail call i32 @llvm.vector.reduce.add.v4i32(%3)
+      // =>
+      // %2 = shl i32 %0, 2
+      Value *InputValue;
+      ArrayRef<int> Mask;
+      ConstantInt *InsertionIdx;
+      assert(Arg->getType()->isVectorTy() &&
+             "The vector.reduce.add intrinsic's argument must be a vector!");
+
+      if (match(Arg, m_Shuffle(m_InsertElt(m_Poison(), m_Value(InputValue),
+                                           m_ConstantInt(InsertionIdx)),
+                               m_Poison(), m_Mask(Mask)))) {
+        // It is only a multiplication if we add the same element over and over.
+        bool AllElementsAreTheSameInMask =
+            std::all_of(Mask.begin(), Mask.end(),
+                        [&Mask](int MaskElt) { return MaskElt == Mask[0]; });
+        unsigned ReducedVectorLength = Mask.size();
+
+        if (AllElementsAreTheSameInMask &&
+            InsertionIdx->getSExtValue() == Mask[0] &&
+            isPowerOf2_32(ReducedVectorLength)) {
+          unsigned Pow2 = Log2_32(ReducedVectorLength);
+          Value *Res = Builder.CreateShl(
+              InputValue, Constant::getIntegerValue(InputValue->getType(),
+                                                    APInt(32, Pow2)));
+          return replaceInstUsesWith(CI, Res);
+        }
+      }
     }
     [[fallthrough]];
   }
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 10f4aca72dbc7..2547403386106 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -308,3 +308,73 @@ define i32 @diff_of_sums_type_mismatch2(<8 x i32> %v0, <4 x i32> %v1) {
   %r = sub i32 %r0, %r1
   ret i32 %r
 }
+
+define i32 @constant_multiplied_at_0(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0(
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+  %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
+  %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+  ret i32 %4
+}
+
+define i32 @constant_multiplied_at_0_two_pow8(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0_two_pow8(
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 3
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+  %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <8 x i32> zeroinitializer
+  %4 = tail call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %3)
+  ret i32 %4
+}
+
+
+define i32 @constant_multiplied_at_0_two_pow16(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_0_two_pow16(
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 4
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+  %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <16 x i32> zeroinitializer
+  %4 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %3)
+  ret i32 %4
+}
+
+
+define i32 @constant_multiplied_at_1(i32 %0) {
+; CHECK-LABEL: @constant_multiplied_at_1(
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i32 [[TMP0:%.*]], 2
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %2 = insertelement <4 x i32> poison, i32 %0, i64 1
+  %3 = shufflevector <4 x i32> %2, <4 x i32> poison,
+  <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+  %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+  ret i32 %4
+}
+
+define i32 @negative_constant_multiplied_at_1(i32 %0) {
+; CHECK-LABEL: @negative_constant_multiplied_at_1(
+; CHECK-NEXT:    ret i32 poison
+;
+  %2 = insertelement <4 x i32> poison, i32 %0, i64 1
+  %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <4 x i32> zeroinitializer
+  %4 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %3)
+  ret i32 %4
+}
+
+define i32 @negative_constant_multiplied_non_power_of_2(i32 %0) {
+; CHECK-LABEL: @negative_constant_multiplied_non_power_of_2(
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x i32> poison, i32 [[TMP0:%.*]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <4 x i32> [[TMP2]], <4 x i32> poison, <6 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> [[TMP3]])
+; CHECK-NEXT:    ret i32 [[TMP4]]
+;
+  %2 = insertelement <4 x i32> poison, i32 %0, i64 0
+  %3 = shufflevector <4 x i32> %2, <4 x i32> poison, <6 x i32> zeroinitializer
+  %4 = tail call i32 @llvm.vector.reduce.add.v6i32(<6 x i32> %3)
+  ret i32 %4
+}



More information about the llvm-commits mailing list