[llvm] [LLVM][ConstantFolding] Extend constantFoldVectorReduce to include scalable vectors. (PR #165437)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 29 04:21:46 PDT 2025


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/165437

>From 3538605ef478ad8aa81e8f37b22140ca69065582 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Mon, 27 Oct 2025 17:04:21 +0000
Subject: [PATCH 1/2] [LLVM][ConstantFolding] Extend constantFoldVectorReduce
 to include scalable vectors.

---
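As a minimal illustration (the function below is made up for this note and is
not part of the patch), the change lets reductions over scalable splat vectors
fold without having to enumerate lanes:

; example only, not part of the patch
define i32 @umax_splat_example() {
  ; umax over any number of copies of 42 is still 42, so this now folds to ret i32 42
  %r = call i32 @llvm.vector.reduce.umax.nxv4i32(<vscale x 4 x i32> splat (i32 42))
  ret i32 %r
}
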
 llvm/lib/Analysis/ConstantFolding.cpp         | 39 +++++++--
 .../InstSimplify/ConstProp/vecreduce.ll       | 84 ++++++++-----------
 2 files changed, 64 insertions(+), 59 deletions(-)

diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index e9e2e7d0316c7..bd84bd1a14113 100755
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -2163,18 +2163,39 @@ Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
 }
 
 Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
-  FixedVectorType *VT = dyn_cast<FixedVectorType>(Op->getType());
-  if (!VT)
-    return nullptr;
-
-  // This isn't strictly necessary, but handle the special/common case of zero:
-  // all integer reductions of a zero input produce zero.
-  if (isa<ConstantAggregateZero>(Op))
-    return ConstantInt::get(VT->getElementType(), 0);
+  auto *OpVT = cast<VectorType>(Op->getType());
 
   // This is the same as the underlying binops - poison propagates.
   if (isa<PoisonValue>(Op) || Op->containsPoisonElement())
-    return PoisonValue::get(VT->getElementType());
+    return PoisonValue::get(OpVT->getElementType());
+
+  // Shortcut non-accumulating reductions.
+  if (Constant *SplatVal = Op->getSplatValue()) {
+    switch (IID) {
+    case Intrinsic::vector_reduce_and:
+    case Intrinsic::vector_reduce_or:
+    case Intrinsic::vector_reduce_smin:
+    case Intrinsic::vector_reduce_smax:
+    case Intrinsic::vector_reduce_umin:
+    case Intrinsic::vector_reduce_umax:
+      return SplatVal;
+    case Intrinsic::vector_reduce_add:
+    case Intrinsic::vector_reduce_mul:
+      if (SplatVal->isZeroValue())
+        return SplatVal;
+      break;
+    case Intrinsic::vector_reduce_xor:
+      if (SplatVal->isZeroValue())
+        return SplatVal;
+      if (OpVT->getElementCount().isKnownMultipleOf(2))
+        return Constant::getNullValue(OpVT->getElementType());
+      break;
+    }
+  }
+
+  FixedVectorType *VT = dyn_cast<FixedVectorType>(OpVT);
+  if (!VT)
+    return nullptr;
 
   // TODO: Handle undef.
   auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U));
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
index 77a7f0d4e4acf..1ba8b0eff1d1a 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
@@ -12,8 +12,7 @@ define i32 @add_0() {
 
 define i32 @add_0_scalable_vector() {
 ; CHECK-LABEL: @add_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -89,8 +88,7 @@ define i32 @add_poison() {
 
 define i32 @add_poison_scalable_vector() {
 ; CHECK-LABEL: @add_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -123,8 +121,7 @@ define i32 @mul_0() {
 
 define i32 @mul_0_scalable_vector() {
 ; CHECK-LABEL: @mul_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -200,8 +197,7 @@ define i32 @mul_poison() {
 
 define i32 @mul_poison_scalable_vector() {
 ; CHECK-LABEL: @mul_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -225,8 +221,7 @@ define i32 @and_0() {
 
 define i32 @and_0_scalable_vector() {
 ; CHECK-LABEL: @and_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -242,8 +237,7 @@ define i32 @and_1() {
 
 define i32 @and_1_scalable_vector() {
 ; CHECK-LABEL: @and_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> splat (i32 1))
   ret i32 %x
@@ -302,8 +296,7 @@ define i32 @and_poison() {
 
 define i32 @and_poison_scalable_vector() {
 ; CHECK-LABEL: @and_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.and.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -327,8 +320,7 @@ define i32 @or_0() {
 
 define i32 @or_0_scalable_vector() {
 ; CHECK-LABEL: @or_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -344,8 +336,7 @@ define i32 @or_1() {
 
 define i32 @or_1_scalable_vector() {
 ; CHECK-LABEL: @or_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> splat (i32 1))
   ret i32 %x
@@ -404,8 +395,7 @@ define i32 @or_poison() {
 
 define i32 @or_poison_scalable_vector() {
 ; CHECK-LABEL: @or_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.or.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -429,8 +419,7 @@ define i32 @xor_0() {
 
 define i32 @xor_0_scalable_vector() {
 ; CHECK-LABEL: @xor_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -446,13 +435,21 @@ define i32 @xor_1() {
 
 define i32 @xor_1_scalable_vector() {
 ; CHECK-LABEL: @xor_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> splat(i32 1))
   ret i32 %x
 }
 
+define i32 @xor_1_scalable_vector_lane_count_not_known_even() {
+; CHECK-LABEL: @xor_1_scalable_vector_lane_count_not_known_even(
+; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv1i32(<vscale x 1 x i32> splat (i32 1))
+; CHECK-NEXT:    ret i32 [[X]]
+;
+  %x = call i32 @llvm.vector.reduce.xor.nxv1i32(<vscale x 1 x i32> splat(i32 1))
+  ret i32 %x
+}
+
 define i32 @xor_inc() {
 ; CHECK-LABEL: @xor_inc(
 ; CHECK-NEXT:    ret i32 10
@@ -506,8 +503,7 @@ define i32 @xor_poison() {
 
 define i32 @xor_poison_scalable_vector() {
 ; CHECK-LABEL: @xor_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.xor.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -531,8 +527,7 @@ define i32 @smin_0() {
 
 define i32 @smin_0_scalable_vector() {
 ; CHECK-LABEL: @smin_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -548,8 +543,7 @@ define i32 @smin_1() {
 
 define i32 @smin_1_scalable_vector() {
 ; CHECK-LABEL: @smin_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> splat(i32 1))
   ret i32 %x
@@ -608,8 +602,7 @@ define i32 @smin_poison() {
 
 define i32 @smin_poison_scalable_vector() {
 ; CHECK-LABEL: @smin_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.smin.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -633,8 +626,7 @@ define i32 @smax_0() {
 
 define i32 @smax_0_scalable_vector() {
 ; CHECK-LABEL: @smax_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -650,8 +642,7 @@ define i32 @smax_1() {
 
 define i32 @smax_1_scalable_vector() {
 ; CHECK-LABEL: @smax_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> splat(i32 1))
   ret i32 %x
@@ -710,8 +701,7 @@ define i32 @smax_poison() {
 
 define i32 @smax_poison_scalable_vector() {
 ; CHECK-LABEL: @smax_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.smax.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -735,8 +725,7 @@ define i32 @umin_0() {
 
 define i32 @umin_0_scalable_vector() {
 ; CHECK-LABEL: @umin_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -752,8 +741,7 @@ define i32 @umin_1() {
 
 define i32 @umin_1_scalable_vector() {
 ; CHECK-LABEL: @umin_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> splat (i32 1))
   ret i32 %x
@@ -812,8 +800,7 @@ define i32 @umin_poison() {
 
 define i32 @umin_poison_scalable_vector() {
 ; CHECK-LABEL: @umin_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.umin.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x
@@ -837,8 +824,7 @@ define i32 @umax_0() {
 
 define i32 @umax_0_scalable_vector() {
 ; CHECK-LABEL: @umax_0_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 0
 ;
   %x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> zeroinitializer)
   ret i32 %x
@@ -854,8 +840,7 @@ define i32 @umax_1() {
 
 define i32 @umax_1_scalable_vector() {
 ; CHECK-LABEL: @umax_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> splat(i32 1))
   ret i32 %x
@@ -914,8 +899,7 @@ define i32 @umax_poison() {
 
 define i32 @umax_poison_scalable_vector() {
 ; CHECK-LABEL: @umax_poison_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> poison)
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 poison
 ;
   %x = call i32 @llvm.vector.reduce.umax.nxv8i32(<vscale x 8 x i32> poison)
   ret i32 %x

>From 4da9437ac61bc762b0eb842a720366a041adb1a8 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Wed, 29 Oct 2025 11:09:46 +0000
Subject: [PATCH 2/2] Remove unnecessary isa<PoisonValue> check. Add support
 for folding vector.reduce.mul of a splat of one for scalable vectors.

The patch also uses isNullValue() rather than isZeroValue(), whose extra
floating-point negative-zero handling is not relevant for integer types.
---
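As a small sketch of the extra fold this follow-up enables (function name is
hypothetical), vector.reduce.mul of a splat of one now folds for scalable
vectors too:

; example only, not part of the patch
define i32 @mul_one_splat_example() {
  ; 1 * 1 * ... * 1 == 1 regardless of the element count, so this folds to ret i32 1
  %r = call i32 @llvm.vector.reduce.mul.nxv2i32(<vscale x 2 x i32> splat (i32 1))
  ret i32 %r
}
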
 llvm/lib/Analysis/ConstantFolding.cpp         |  9 ++++++---
 .../InstSimplify/ConstProp/vecreduce.ll       | 20 +++++++++++++++++--
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index bd84bd1a14113..da32542cf7870 100755
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -2166,7 +2166,7 @@ Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
   auto *OpVT = cast<VectorType>(Op->getType());
 
   // This is the same as the underlying binops - poison propagates.
-  if (isa<PoisonValue>(Op) || Op->containsPoisonElement())
+  if (Op->containsPoisonElement())
     return PoisonValue::get(OpVT->getElementType());
 
   // Shortcut non-accumulating reductions.
@@ -2180,12 +2180,15 @@ Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
     case Intrinsic::vector_reduce_umax:
       return SplatVal;
     case Intrinsic::vector_reduce_add:
+      if (SplatVal->isNullValue())
+        return SplatVal;
+      break;
     case Intrinsic::vector_reduce_mul:
-      if (SplatVal->isZeroValue())
+      if (SplatVal->isNullValue() || SplatVal->isOneValue())
         return SplatVal;
       break;
     case Intrinsic::vector_reduce_xor:
-      if (SplatVal->isZeroValue())
+      if (SplatVal->isNullValue())
         return SplatVal;
       if (OpVT->getElementCount().isKnownMultipleOf(2))
         return Constant::getNullValue(OpVT->getElementType());
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
index 1ba8b0eff1d1a..479b3f8ea4128 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/vecreduce.ll
@@ -137,13 +137,29 @@ define i32 @mul_1() {
 
 define i32 @mul_1_scalable_vector() {
 ; CHECK-LABEL: @mul_1_scalable_vector(
-; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> splat (i32 1))
-; CHECK-NEXT:    ret i32 [[X]]
+; CHECK-NEXT:    ret i32 1
 ;
   %x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> splat (i32 1))
   ret i32 %x
 }
 
+define i32 @mul_2() {
+; CHECK-LABEL: @mul_2(
+; CHECK-NEXT:    ret i32 256
+;
+  %x = call i32 @llvm.vector.reduce.mul.v8i32(<8 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>)
+  ret i32 %x
+}
+
+define i32 @mul_2_scalable_vector() {
+; CHECK-LABEL: @mul_2_scalable_vector(
+; CHECK-NEXT:    [[X:%.*]] = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> splat (i32 2))
+; CHECK-NEXT:    ret i32 [[X]]
+;
+  %x = call i32 @llvm.vector.reduce.mul.nxv8i32(<vscale x 8 x i32> splat (i32 2))
+  ret i32 %x
+}
+
 define i32 @mul_inc() {
 ; CHECK-LABEL: @mul_inc(
 ; CHECK-NEXT:    ret i32 40320


