[llvm] 1159266 - [SLP] Add support for fmaximum/fminimum reduction
Anna Thomas via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 12 12:22:49 PDT 2023
Author: Anna Thomas
Date: 2023-07-12T15:22:38-04:00
New Revision: 11592667344f1f4e12da599e6669a359b99bd43b
URL: https://github.com/llvm/llvm-project/commit/11592667344f1f4e12da599e6669a359b99bd43b
DIFF: https://github.com/llvm/llvm-project/commit/11592667344f1f4e12da599e6669a359b99bd43b.diff
LOG: [SLP] Add support for fmaximum/fminimum reduction
This patch adds support for vectorized reductions of the llvm.maximum/llvm.minimum
intrinsics by mapping them to the corresponding reduction kinds
(RecurKind::FMaximum/RecurKind::FMinimum).
Differential Revision: https://reviews.llvm.org/D154463
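For illustration, the transformation mirrors the tests updated below: a chain of
scalar llvm.maximum calls over consecutive loads becomes a single vector load
feeding a llvm.vector.reduce.fmaximum call. A minimal before/after sketch,
adapted from the reduction_v4f32 test in this patch:

define float @reduction_v4f32(ptr %p) {
  ; Before SLP: four scalar loads feeding a chain of @llvm.maximum.f32 calls.
  ; After SLP vectorization:
  %v = load <4 x float>, ptr %p, align 4
  %r = call float @llvm.vector.reduce.fmaximum.v4f32(<4 x float> %v)
  ret float %r
}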
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/X86/fmaximum-fminimum.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ebf9e7be260641..97b8737df83e4a 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -12701,6 +12701,9 @@ class HorizontalReduction {
return I->getFastMathFlags().noNaNs();
}
+ if (Kind == RecurKind::FMaximum || Kind == RecurKind::FMinimum)
+ return true;
+
return I->isAssociative();
}
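The unconditional "return true" added in this hunk is sound without fast-math
flags: unlike maxnum/minnum, which need the noNaNs() check just above it,
llvm.maximum/llvm.minimum fully define NaN propagation and signed-zero ordering,
so the reduction can be reassociated freely. A comment-style sketch of the
relevant LangRef semantics (not part of this diff):

; llvm.maximum propagates NaN and orders -0.0 below +0.0, so any
; association of the reduced operands yields the same result, e.g.:
;   maximum(maximum(x, NaN), y) == maximum(x, maximum(NaN, y)) == NaN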
@@ -12751,6 +12754,18 @@ class HorizontalReduction {
minnum(cast<ConstantFP>(LHS)->getValueAPF(),
cast<ConstantFP>(RHS)->getValueAPF()));
return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS);
+ case RecurKind::FMaximum:
+ if (IsConstant)
+ return ConstantFP::get(LHS->getType(),
+ maximum(cast<ConstantFP>(LHS)->getValueAPF(),
+ cast<ConstantFP>(RHS)->getValueAPF()));
+ return Builder.CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS);
+ case RecurKind::FMinimum:
+ if (IsConstant)
+ return ConstantFP::get(LHS->getType(),
+ minimum(cast<ConstantFP>(LHS)->getValueAPF(),
+ cast<ConstantFP>(RHS)->getValueAPF()));
+ return Builder.CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS);
case RecurKind::SMax:
if (IsConstant || UseSelect) {
Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name);
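In the IsConstant paths of the hunk above, pairs of constant operands fold
through APFloat's maximum/minimum helpers instead of emitting a call; these
helpers follow the same IEEE-754 semantics as the intrinsics. Example folds
(illustrative, not taken from this diff):

; Constant-fold facts for the IsConstant paths:
;   maximum(1.0, NaN)   --> NaN    (NaN propagates)
;   minimum(-0.0, +0.0) --> -0.0   (-0.0 orders below +0.0)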
@@ -12833,6 +12848,10 @@ class HorizontalReduction {
if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value())))
return RecurKind::FMin;
+ if (match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(), m_Value())))
+ return RecurKind::FMaximum;
+ if (match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(), m_Value())))
+ return RecurKind::FMinimum;
// This matches either cmp+select or intrinsics. SLP is expected to handle
// either form.
// TODO: If we are canonicalizing to intrinsics, we can remove several
@@ -13800,6 +13819,8 @@ class HorizontalReduction {
}
case RecurKind::FMax:
case RecurKind::FMin:
+ case RecurKind::FMaximum:
+ case RecurKind::FMinimum:
case RecurKind::SMax:
case RecurKind::SMin:
case RecurKind::UMax:
@@ -14131,6 +14152,10 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) {
return true;
if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(V0), m_Value(V1))))
return true;
+ if (match(I, m_Intrinsic<Intrinsic::maximum>(m_Value(V0), m_Value(V1))))
+ return true;
+ if (match(I, m_Intrinsic<Intrinsic::minimum>(m_Value(V0), m_Value(V1))))
+ return true;
if (match(I, m_Intrinsic<Intrinsic::smax>(m_Value(V0), m_Value(V1))))
return true;
if (match(I, m_Intrinsic<Intrinsic::smin>(m_Value(V0), m_Value(V1))))
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/fmaximum-fminimum.ll b/llvm/test/Transforms/SLPVectorizer/X86/fmaximum-fminimum.ll
index 9e565ea27023be..eb360948efd8e6 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/fmaximum-fminimum.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/fmaximum-fminimum.ll
@@ -175,31 +175,15 @@ define double @reduction_v2f64(ptr %p) {
define float @reduction_v4f32(ptr %p) {
; SSE-LABEL: define float @reduction_v4f32
; SSE-SAME: (ptr [[P:%.*]]) {
-; SSE-NEXT: [[G1:%.*]] = getelementptr inbounds float, ptr [[P]], i64 1
-; SSE-NEXT: [[G2:%.*]] = getelementptr inbounds float, ptr [[P]], i64 2
-; SSE-NEXT: [[G3:%.*]] = getelementptr inbounds float, ptr [[P]], i64 3
-; SSE-NEXT: [[T0:%.*]] = load float, ptr [[P]], align 4
-; SSE-NEXT: [[T1:%.*]] = load float, ptr [[G1]], align 4
-; SSE-NEXT: [[T2:%.*]] = load float, ptr [[G2]], align 4
-; SSE-NEXT: [[T3:%.*]] = load float, ptr [[G3]], align 4
-; SSE-NEXT: [[M1:%.*]] = tail call float @llvm.maximum.f32(float [[T1]], float [[T0]])
-; SSE-NEXT: [[M2:%.*]] = tail call float @llvm.maximum.f32(float [[T2]], float [[M1]])
-; SSE-NEXT: [[M3:%.*]] = tail call float @llvm.maximum.f32(float [[T3]], float [[M2]])
-; SSE-NEXT: ret float [[M3]]
+; SSE-NEXT: [[TMP1:%.*]] = load <4 x float>, ptr [[P]], align 4
+; SSE-NEXT: [[TMP2:%.*]] = call float @llvm.vector.reduce.fmaximum.v4f32(<4 x float> [[TMP1]])
+; SSE-NEXT: ret float [[TMP2]]
;
; AVX-LABEL: define float @reduction_v4f32
; AVX-SAME: (ptr [[P:%.*]]) #[[ATTR1]] {
-; AVX-NEXT: [[G1:%.*]] = getelementptr inbounds float, ptr [[P]], i64 1
-; AVX-NEXT: [[G2:%.*]] = getelementptr inbounds float, ptr [[P]], i64 2
-; AVX-NEXT: [[G3:%.*]] = getelementptr inbounds float, ptr [[P]], i64 3
-; AVX-NEXT: [[T0:%.*]] = load float, ptr [[P]], align 4
-; AVX-NEXT: [[T1:%.*]] = load float, ptr [[G1]], align 4
-; AVX-NEXT: [[T2:%.*]] = load float, ptr [[G2]], align 4
-; AVX-NEXT: [[T3:%.*]] = load float, ptr [[G3]], align 4
-; AVX-NEXT: [[M1:%.*]] = tail call float @llvm.maximum.f32(float [[T1]], float [[T0]])
-; AVX-NEXT: [[M2:%.*]] = tail call float @llvm.maximum.f32(float [[T2]], float [[M1]])
-; AVX-NEXT: [[M3:%.*]] = tail call float @llvm.maximum.f32(float [[T3]], float [[M2]])
-; AVX-NEXT: ret float [[M3]]
+; AVX-NEXT: [[TMP1:%.*]] = load <4 x float>, ptr [[P]], align 4
+; AVX-NEXT: [[TMP2:%.*]] = call float @llvm.vector.reduce.fmaximum.v4f32(<4 x float> [[TMP1]])
+; AVX-NEXT: ret float [[TMP2]]
;
%g1 = getelementptr inbounds float, ptr %p, i64 1
%g2 = getelementptr inbounds float, ptr %p, i64 2
@@ -217,31 +201,15 @@ define float @reduction_v4f32(ptr %p) {
define double @reduction_v4f64_fminimum(ptr %p) {
; SSE-LABEL: define double @reduction_v4f64_fminimum
; SSE-SAME: (ptr [[P:%.*]]) {
-; SSE-NEXT: [[G1:%.*]] = getelementptr inbounds double, ptr [[P]], i64 1
-; SSE-NEXT: [[G2:%.*]] = getelementptr inbounds double, ptr [[P]], i64 2
-; SSE-NEXT: [[G3:%.*]] = getelementptr inbounds double, ptr [[P]], i64 3
-; SSE-NEXT: [[T0:%.*]] = load double, ptr [[P]], align 4
-; SSE-NEXT: [[T1:%.*]] = load double, ptr [[G1]], align 4
-; SSE-NEXT: [[T2:%.*]] = load double, ptr [[G2]], align 4
-; SSE-NEXT: [[T3:%.*]] = load double, ptr [[G3]], align 4
-; SSE-NEXT: [[M1:%.*]] = tail call double @llvm.minimum.f64(double [[T1]], double [[T0]])
-; SSE-NEXT: [[M2:%.*]] = tail call double @llvm.minimum.f64(double [[T2]], double [[M1]])
-; SSE-NEXT: [[M3:%.*]] = tail call double @llvm.minimum.f64(double [[T3]], double [[M2]])
-; SSE-NEXT: ret double [[M3]]
+; SSE-NEXT: [[TMP1:%.*]] = load <4 x double>, ptr [[P]], align 4
+; SSE-NEXT: [[TMP2:%.*]] = call double @llvm.vector.reduce.fminimum.v4f64(<4 x double> [[TMP1]])
+; SSE-NEXT: ret double [[TMP2]]
;
; AVX-LABEL: define double @reduction_v4f64_fminimum
; AVX-SAME: (ptr [[P:%.*]]) #[[ATTR1]] {
-; AVX-NEXT: [[G1:%.*]] = getelementptr inbounds double, ptr [[P]], i64 1
-; AVX-NEXT: [[G2:%.*]] = getelementptr inbounds double, ptr [[P]], i64 2
-; AVX-NEXT: [[G3:%.*]] = getelementptr inbounds double, ptr [[P]], i64 3
-; AVX-NEXT: [[T0:%.*]] = load double, ptr [[P]], align 4
-; AVX-NEXT: [[T1:%.*]] = load double, ptr [[G1]], align 4
-; AVX-NEXT: [[T2:%.*]] = load double, ptr [[G2]], align 4
-; AVX-NEXT: [[T3:%.*]] = load double, ptr [[G3]], align 4
-; AVX-NEXT: [[M1:%.*]] = tail call double @llvm.minimum.f64(double [[T1]], double [[T0]])
-; AVX-NEXT: [[M2:%.*]] = tail call double @llvm.minimum.f64(double [[T2]], double [[M1]])
-; AVX-NEXT: [[M3:%.*]] = tail call double @llvm.minimum.f64(double [[T3]], double [[M2]])
-; AVX-NEXT: ret double [[M3]]
+; AVX-NEXT: [[TMP1:%.*]] = load <4 x double>, ptr [[P]], align 4
+; AVX-NEXT: [[TMP2:%.*]] = call double @llvm.vector.reduce.fminimum.v4f64(<4 x double> [[TMP1]])
+; AVX-NEXT: ret double [[TMP2]]
;
%g1 = getelementptr inbounds double, ptr %p, i64 1
%g2 = getelementptr inbounds double, ptr %p, i64 2
@@ -259,55 +227,15 @@ define double @reduction_v4f64_fminimum(ptr %p) {
define float @reduction_v8f32_fminimum(ptr %p) {
; SSE-LABEL: define float @reduction_v8f32_fminimum
; SSE-SAME: (ptr [[P:%.*]]) {
-; SSE-NEXT: [[G1:%.*]] = getelementptr inbounds float, ptr [[P]], i64 1
-; SSE-NEXT: [[G2:%.*]] = getelementptr inbounds float, ptr [[P]], i64 2
-; SSE-NEXT: [[G3:%.*]] = getelementptr inbounds float, ptr [[P]], i64 3
-; SSE-NEXT: [[G4:%.*]] = getelementptr inbounds float, ptr [[P]], i64 4
-; SSE-NEXT: [[G5:%.*]] = getelementptr inbounds float, ptr [[P]], i64 5
-; SSE-NEXT: [[G6:%.*]] = getelementptr inbounds float, ptr [[P]], i64 6
-; SSE-NEXT: [[G7:%.*]] = getelementptr inbounds float, ptr [[P]], i64 7
-; SSE-NEXT: [[T0:%.*]] = load float, ptr [[P]], align 4
-; SSE-NEXT: [[T1:%.*]] = load float, ptr [[G1]], align 4
-; SSE-NEXT: [[T2:%.*]] = load float, ptr [[G2]], align 4
-; SSE-NEXT: [[T3:%.*]] = load float, ptr [[G3]], align 4
-; SSE-NEXT: [[T4:%.*]] = load float, ptr [[G4]], align 4
-; SSE-NEXT: [[T5:%.*]] = load float, ptr [[G5]], align 4
-; SSE-NEXT: [[T6:%.*]] = load float, ptr [[G6]], align 4
-; SSE-NEXT: [[T7:%.*]] = load float, ptr [[G7]], align 4
-; SSE-NEXT: [[M1:%.*]] = tail call float @llvm.minimum.f32(float [[T1]], float [[T0]])
-; SSE-NEXT: [[M2:%.*]] = tail call float @llvm.minimum.f32(float [[T2]], float [[M1]])
-; SSE-NEXT: [[M3:%.*]] = tail call float @llvm.minimum.f32(float [[T3]], float [[M2]])
-; SSE-NEXT: [[M4:%.*]] = tail call float @llvm.minimum.f32(float [[T4]], float [[M3]])
-; SSE-NEXT: [[M5:%.*]] = tail call float @llvm.minimum.f32(float [[M4]], float [[T6]])
-; SSE-NEXT: [[M6:%.*]] = tail call float @llvm.minimum.f32(float [[M5]], float [[T5]])
-; SSE-NEXT: [[M7:%.*]] = tail call float @llvm.minimum.f32(float [[M6]], float [[T7]])
-; SSE-NEXT: ret float [[M7]]
+; SSE-NEXT: [[TMP1:%.*]] = load <8 x float>, ptr [[P]], align 4
+; SSE-NEXT: [[TMP2:%.*]] = call float @llvm.vector.reduce.fminimum.v8f32(<8 x float> [[TMP1]])
+; SSE-NEXT: ret float [[TMP2]]
;
; AVX-LABEL: define float @reduction_v8f32_fminimum
; AVX-SAME: (ptr [[P:%.*]]) #[[ATTR1]] {
-; AVX-NEXT: [[G1:%.*]] = getelementptr inbounds float, ptr [[P]], i64 1
-; AVX-NEXT: [[G2:%.*]] = getelementptr inbounds float, ptr [[P]], i64 2
-; AVX-NEXT: [[G3:%.*]] = getelementptr inbounds float, ptr [[P]], i64 3
-; AVX-NEXT: [[G4:%.*]] = getelementptr inbounds float, ptr [[P]], i64 4
-; AVX-NEXT: [[G5:%.*]] = getelementptr inbounds float, ptr [[P]], i64 5
-; AVX-NEXT: [[G6:%.*]] = getelementptr inbounds float, ptr [[P]], i64 6
-; AVX-NEXT: [[G7:%.*]] = getelementptr inbounds float, ptr [[P]], i64 7
-; AVX-NEXT: [[T0:%.*]] = load float, ptr [[P]], align 4
-; AVX-NEXT: [[T1:%.*]] = load float, ptr [[G1]], align 4
-; AVX-NEXT: [[T2:%.*]] = load float, ptr [[G2]], align 4
-; AVX-NEXT: [[T3:%.*]] = load float, ptr [[G3]], align 4
-; AVX-NEXT: [[T4:%.*]] = load float, ptr [[G4]], align 4
-; AVX-NEXT: [[T5:%.*]] = load float, ptr [[G5]], align 4
-; AVX-NEXT: [[T6:%.*]] = load float, ptr [[G6]], align 4
-; AVX-NEXT: [[T7:%.*]] = load float, ptr [[G7]], align 4
-; AVX-NEXT: [[M1:%.*]] = tail call float @llvm.minimum.f32(float [[T1]], float [[T0]])
-; AVX-NEXT: [[M2:%.*]] = tail call float @llvm.minimum.f32(float [[T2]], float [[M1]])
-; AVX-NEXT: [[M3:%.*]] = tail call float @llvm.minimum.f32(float [[T3]], float [[M2]])
-; AVX-NEXT: [[M4:%.*]] = tail call float @llvm.minimum.f32(float [[T4]], float [[M3]])
-; AVX-NEXT: [[M5:%.*]] = tail call float @llvm.minimum.f32(float [[M4]], float [[T6]])
-; AVX-NEXT: [[M6:%.*]] = tail call float @llvm.minimum.f32(float [[M5]], float [[T5]])
-; AVX-NEXT: [[M7:%.*]] = tail call float @llvm.minimum.f32(float [[M6]], float [[T7]])
-; AVX-NEXT: ret float [[M7]]
+; AVX-NEXT: [[TMP1:%.*]] = load <8 x float>, ptr [[P]], align 4
+; AVX-NEXT: [[TMP2:%.*]] = call float @llvm.vector.reduce.fminimum.v8f32(<8 x float> [[TMP1]])
+; AVX-NEXT: ret float [[TMP2]]
;
%g1 = getelementptr inbounds float, ptr %p, i64 1
%g2 = getelementptr inbounds float, ptr %p, i64 2
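The test file's RUN lines are not shown in this excerpt; a plausible way to
reproduce the SSE/AVX checks locally (an assumed invocation, the actual RUN
lines may differ) is:

; RUN: opt < %s -passes=slp-vectorizer -mtriple=x86_64-- -mattr=+sse2 -S | FileCheck %s --check-prefixes=SSE
; RUN: opt < %s -passes=slp-vectorizer -mtriple=x86_64-- -mattr=+avx -S | FileCheck %s --check-prefixes=AVX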