[llvm] Extend vector.reduce.add and splat transform to scalable vectors (PR #161101)

Gábor Spaits via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 29 01:53:32 PDT 2025


https://github.com/spaits updated https://github.com/llvm/llvm-project/pull/161101

From 7d90e64e5b5ec7ba07020471057a6a4932f6dc82 Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gaborspaits1 at gmail.com>
Date: Sun, 28 Sep 2025 23:58:13 +0200
Subject: [PATCH 1/3] Extend vector.reduce.add and splat transform to scalable
 vectors

---
 .../Transforms/InstCombine/InstCombineCalls.cpp    | 14 ++++++++++----
 .../Transforms/InstCombine/vector-reductions.ll    |  7 ++++---
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index cf6d0ecab4f69..02b46b0161ad8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3785,13 +3785,19 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
 
       // vector.reduce.add.vNiM(splat(%x)) -> mul(%x, N)
       if (Value *Splat = getSplatValue(Arg)) {
-        ElementCount VecToReduceCount =
-            cast<VectorType>(Arg->getType())->getElementCount();
+        VectorType *VecToReduceTy = cast<VectorType>(Arg->getType());
+        ElementCount VecToReduceCount = VecToReduceTy->getElementCount();
+        Value *RHS;
         if (VecToReduceCount.isFixed()) {
           unsigned VectorSize = VecToReduceCount.getFixedValue();
-          return BinaryOperator::CreateMul(
-              Splat, ConstantInt::get(Splat->getType(), VectorSize));
+          RHS = ConstantInt::get(Splat->getType(), VectorSize);
         }
+
+        RHS = Builder.CreateElementCount(Type::getInt64Ty(II->getContext()),
+                                         VecToReduceCount);
+        if (Splat->getType() != RHS->getType())
+          RHS = Builder.CreateZExtOrTrunc(RHS, Splat->getType());
+        return BinaryOperator::CreateMul(Splat, RHS);
       }
     }
     [[fallthrough]];
diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index f1e0dd9bd06d7..34f0570c2698d 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -469,9 +469,10 @@ define i2 @constant_multiplied_7xi2(i2 %0) {
 
 define i32 @negative_scalable_vector(i32 %0) {
 ; CHECK-LABEL: @negative_scalable_vector(
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[TMP0:%.*]], i64 0
-; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <vscale x 4 x i32> [[TMP2]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP3]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[DOTTR:%.*]] = trunc i64 [[TMP2]] to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = shl i32 [[DOTTR]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i32 [[TMP0:%.*]], [[TMP3]]
 ; CHECK-NEXT:    ret i32 [[TMP4]]
 ;
   %2 = insertelement <vscale x 4 x i32> poison, i32 %0, i64 0
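
For reference, here is a minimal before/after sketch of the scalable case this
patch enables, written out as plain IR rather than FileCheck lines. It mirrors
the scalable-vector test above; the function and value names are illustrative
only.

  ; Before: reduce.add of a splat over <vscale x 4 x i32>
  define i32 @reduce_splat_before(i32 %x) {
    %ins = insertelement <vscale x 4 x i32> poison, i32 %x, i64 0
    %splat = shufflevector <vscale x 4 x i32> %ins, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
    %red = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %splat)
    ret i32 %red
  }

  ; After: multiply the splatted value by the runtime element count (vscale * 4)
  define i32 @reduce_splat_after(i32 %x) {
    %vs = call i64 @llvm.vscale.i64()
    %vs.tr = trunc i64 %vs to i32
    %ec = shl i32 %vs.tr, 2
    %red = mul i32 %x, %ec
    ret i32 %red
  }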

From f5189ce8500389e24894879b19c1f97208f0a36f Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gaborspaits1 at gmail.com>
Date: Mon, 29 Sep 2025 10:30:35 +0200
Subject: [PATCH 2/3] Throw out redundant branch and redundant check

---
 .../lib/Transforms/InstCombine/InstCombineCalls.cpp | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 02b46b0161ad8..3eb472f53936e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3787,16 +3787,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (Value *Splat = getSplatValue(Arg)) {
         VectorType *VecToReduceTy = cast<VectorType>(Arg->getType());
         ElementCount VecToReduceCount = VecToReduceTy->getElementCount();
-        Value *RHS;
-        if (VecToReduceCount.isFixed()) {
-          unsigned VectorSize = VecToReduceCount.getFixedValue();
-          RHS = ConstantInt::get(Splat->getType(), VectorSize);
-        }
-
-        RHS = Builder.CreateElementCount(Type::getInt64Ty(II->getContext()),
-                                         VecToReduceCount);
-        if (Splat->getType() != RHS->getType())
-          RHS = Builder.CreateZExtOrTrunc(RHS, Splat->getType());
+        Value *RHS = Builder.CreateElementCount(
+            Type::getInt64Ty(II->getContext()), VecToReduceCount);
+        RHS = Builder.CreateZExtOrTrunc(RHS, Splat->getType());
         return BinaryOperator::CreateMul(Splat, RHS);
       }
     }
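
The reason the isFixed special case can go away: CreateElementCount covers both
kinds of vectors, producing a plain constant for fixed-width element counts and
a vscale-based expression for scalable ones (consistent with the updated tests),
and CreateZExtOrTrunc is a no-op when the splat already has the produced type.
Roughly, for a splat of i32 %x (illustrative IR, value names made up):

  ; fixed <4 x i32>: the count is a constant, so the transform emits
  ;   %red = mul i32 %x, 4
  ; scalable <vscale x 4 x i32>: the count is computed at run time
  ;   %vs = call i64 @llvm.vscale.i64()
  ;   %ec = shl i64 %vs, 2              ; vscale * 4
  ;   %ec.tr = trunc i64 %ec to i32     ; ZExtOrTrunc to the splat type
  ;   %red = mul i32 %x, %ec.tr
  ; (later InstCombine folds may reorder the trunc and the shl, as in the
  ;  CHECK lines above)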

From 556e00455de2ba25411169bec7f0252bfc7d433b Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gaborspaits1 at gmail.com>
Date: Mon, 29 Sep 2025 10:46:02 +0200
Subject: [PATCH 3/3] Add more tests

---
 .../InstCombine/vector-reductions.ll          | 78 ++++++++++++++++++-
 1 file changed, 76 insertions(+), 2 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/vector-reductions.ll b/llvm/test/Transforms/InstCombine/vector-reductions.ll
index 34f0570c2698d..56b3e5726d460 100644
--- a/llvm/test/Transforms/InstCombine/vector-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-reductions.ll
@@ -467,8 +467,8 @@ define i2 @constant_multiplied_7xi2(i2 %0) {
   ret i2 %4
 }
 
-define i32 @negative_scalable_vector(i32 %0) {
-; CHECK-LABEL: @negative_scalable_vector(
+define i32 @reduce_add_splat_to_mul_vscale_4xi32(i32 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_4xi32(
 ; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[DOTTR:%.*]] = trunc i64 [[TMP2]] to i32
 ; CHECK-NEXT:    [[TMP3:%.*]] = shl i32 [[DOTTR]], 2
@@ -480,3 +480,77 @@ define i32 @negative_scalable_vector(i32 %0) {
   %4 = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %3)
   ret i32 %4
 }
+
+define i64 @reduce_add_splat_to_mul_vscale_4xi64(i64 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_4xi64(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw i64 [[TMP2]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP0:%.*]], [[TMP3]]
+; CHECK-NEXT:    ret i64 [[TMP4]]
+;
+  %2 = insertelement <vscale x 4 x i64> poison, i64 %0, i64 0
+  %3 = shufflevector <vscale x 4 x i64> %2, <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer
+  %4 = tail call i64 @llvm.vector.reduce.add.nxv4i64(<vscale x 4 x i64> %3)
+  ret i64 %4
+}
+
+define i2 @reduce_add_splat_to_mul_vscale_4xi2(i2 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_4xi2(
+; CHECK-NEXT:    ret i2 0
+;
+  %2 = insertelement <vscale x 4 x i2> poison, i2 %0, i64 0
+  %3 = shufflevector <vscale x 4 x i2> %2, <vscale x 4 x i2> poison, <vscale x 4 x i32> zeroinitializer
+  %4 = tail call i2 @llvm.vector.reduce.add.nxv4i2(<vscale x 4 x i2> %3)
+  ret i2 %4
+}
+
+define i1 @reduce_add_splat_to_mul_vscale_8xi1(i1 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_8xi1(
+; CHECK-NEXT:    ret i1 false
+;
+  %2 = insertelement <vscale x 4 x i1> poison, i1 %0, i64 0
+  %3 = shufflevector <vscale x 4 x i1> %2, <vscale x 4 x i1> poison, <vscale x 8 x i32> zeroinitializer
+  %4 = tail call i1 @llvm.vector.reduce.add.nxv8i1(<vscale x 8 x i1> %3)
+  ret i1 %4
+}
+
+define i2 @reduce_add_splat_to_mul_vscale_5xi2(i2 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_5xi2(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[TMP2]] to i2
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i2 [[TMP0:%.*]], [[TMP3]]
+; CHECK-NEXT:    ret i2 [[TMP4]]
+;
+  %2 = insertelement <vscale x 4 x i2> poison, i2 %0, i64 0
+  %3 = shufflevector <vscale x 4 x i2> %2, <vscale x 4 x i2> poison, <vscale x 5 x i32> zeroinitializer
+  %4 = tail call i2 @llvm.vector.reduce.add.nxv5i2(<vscale x 5 x i2> %3)
+  ret i2 %4
+}
+
+define i2 @reduce_add_splat_to_mul_vscale_6xi2(i2 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_6xi2(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[DOTTR:%.*]] = trunc i64 [[TMP2]] to i2
+; CHECK-NEXT:    [[TMP3:%.*]] = shl i2 [[DOTTR]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i2 [[TMP0:%.*]], [[TMP3]]
+; CHECK-NEXT:    ret i2 [[TMP4]]
+;
+  %2 = insertelement <vscale x 4 x i2> poison, i2 %0, i64 0
+  %3 = shufflevector <vscale x 4 x i2> %2, <vscale x 4 x i2> poison, <vscale x 6 x i32> zeroinitializer
+  %4 = tail call i2 @llvm.vector.reduce.add.nxv6i2(<vscale x 6 x i2> %3)
+  ret i2 %4
+}
+
+define i2 @reduce_add_splat_to_mul_vscale_7xi2(i2 %0) {
+; CHECK-LABEL: @reduce_add_splat_to_mul_vscale_7xi2(
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[TMP2]] to i2
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i2 [[TMP0:%.*]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = sub i2 0, [[TMP4]]
+; CHECK-NEXT:    ret i2 [[TMP5]]
+;
+  %2 = insertelement <vscale x 4 x i2> poison, i2 %0, i64 0
+  %3 = shufflevector <vscale x 4 x i2> %2, <vscale x 4 x i2> poison, <vscale x 7 x i32> zeroinitializer
+  %4 = tail call i2 @llvm.vector.reduce.add.nxv7i2(<vscale x 7 x i2> %3)
+  ret i2 %4
+}
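
A short worked note on why some of the new tests fold all the way to a
constant: for narrow result types the runtime element count collapses in
modular arithmetic. The intermediate form below is illustrative, not taken
from the patch:

  ; reduce_add_splat_to_mul_vscale_4xi2, before the final folds:
  define i2 @sketch_nxv4xi2(i2 %x) {
    %vs = call i64 @llvm.vscale.i64()
    %ec = shl i64 %vs, 2            ; vscale * 4, always a multiple of 4
    %ec.tr = trunc i64 %ec to i2    ; low two bits of a multiple of 4 are 0
    %red = mul i2 %x, %ec.tr        ; mul by 0, so the test returns i2 0
    ret i2 %red
  }

The nxv8i1 case is analogous (vscale * 8 truncated to i1 is always false), and
in the nxv7i2 case 7 is congruent to -1 modulo 4, which is why the expected
output is the negated product.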


