[llvm] [Matrix] Lower vector reductions using shape info (PR #142055)
Jon Roelofs via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 10 11:08:13 PDT 2025
https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/142055
From 51e04878cce3971f222714aa380828c2ff5211d8 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 15:46:45 -0700
Subject: [PATCH 1/7] [Matrix] Lower vector reductions using shape info
When the reduction operand has known shape information, this avoids a
bunch of shuffles in and out of the flattened layout.
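A sketch of the intended effect on a 4x2 add reduction (value names are
illustrative; see the new test for the autogenerated checks):

  ; before: flatten the columns, then reduce
  %flat = shufflevector <4 x i32> %col0, <4 x i32> %col1,
                        <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
  %r = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %flat)

  ; after: reduce column-wise, then horizontally
  %vred = add <4 x i32> %col0, %col1
  %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %vred)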
---
.../Scalar/LowerMatrixIntrinsics.cpp | 161 +++++++++++++-
.../LowerMatrixIntrinsics/reduce.ll | 200 ++++++++++++++++++
2 files changed, 357 insertions(+), 4 deletions(-)
create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..8287bc094fed2 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -30,6 +30,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
@@ -40,6 +41,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -1101,6 +1103,7 @@ class LowerMatrixIntrinsics {
if (!PoisonedInsts.empty()) {
// If we didn't remove all poisoned instructions, it's a hard error.
dbgs() << "Poisoned but present instructions:\n";
+ Func.dump();
for (auto *I : PoisonedInsts)
dbgs() << *I << "\n";
llvm_unreachable("Poisoned but instruction not removed");
@@ -1337,6 +1340,134 @@ class LowerMatrixIntrinsics {
return Builder.CreateAdd(Sum, Mul);
}
+ bool VisitReduce(IntrinsicInst *Inst) {
+ FastMathFlags FMF = getFastMathFlags(Inst);
+
+ if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+ return false;
+
+ Value *Start = nullptr;
+ Value *Op = nullptr;
+ switch (Inst->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_fadd:
+ case Intrinsic::vector_reduce_fmul:
+ Start = Inst->getOperand(0);
+ Op = Inst->getOperand(1);
+ break;
+ case Intrinsic::vector_reduce_fmax:
+ case Intrinsic::vector_reduce_fmaximum:
+ case Intrinsic::vector_reduce_fmin:
+ case Intrinsic::vector_reduce_fminimum:
+ case Intrinsic::vector_reduce_add:
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_mul:
+ case Intrinsic::vector_reduce_or:
+ case Intrinsic::vector_reduce_xor:
+ Op = Inst->getOperand(0);
+ break;
+ default:
+ llvm_unreachable("unexpected intrinsic");
+ }
+
+ switch (Inst->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_fadd: {
+ if (!match(Start, m_AnyZeroFP()))
+ return false;
+ } break;
+ case Intrinsic::vector_reduce_fmul: {
+ if (!match(Start, m_FPOne()))
+ return false;
+ } break;
+ default:
+ break;
+ }
+
+ auto *I = Inst2ColumnMatrix.find(Op);
+ if (I == Inst2ColumnMatrix.end())
+ return false;
+
+ IRBuilder<> Builder(Inst);
+
+ const MatrixTy &M = I->second;
+
+ auto CreateVReduce = [&](Value *LHS, Value *RHS) {
+ switch (Inst->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_add:
+ return Builder.CreateAdd(LHS, RHS);
+ case Intrinsic::vector_reduce_and:
+ return Builder.CreateAnd(LHS, RHS);
+ case Intrinsic::vector_reduce_fadd:
+ return Builder.CreateFAdd(LHS, RHS);
+ case Intrinsic::vector_reduce_fmax:
+ return Builder.CreateMaximum(LHS, RHS); // FIXME: is this correct re: nans?
+ case Intrinsic::vector_reduce_fmaximum:
+ return Builder.CreateMaximumNum(LHS, RHS); // FIXME: is this correct re: nans?
+ case Intrinsic::vector_reduce_fmin:
+ return Builder.CreateMinimum(LHS, RHS); // FIXME: is this correct re: nans?
+ case Intrinsic::vector_reduce_fminimum:
+ return Builder.CreateMinimumNum(LHS, RHS); // FIXME: is this correct re: nans?
+ case Intrinsic::vector_reduce_fmul:
+ return Builder.CreateFMul(LHS, RHS);
+ case Intrinsic::vector_reduce_mul:
+ return Builder.CreateMul(LHS, RHS);
+ case Intrinsic::vector_reduce_or:
+ return Builder.CreateOr(LHS, RHS);
+ case Intrinsic::vector_reduce_xor:
+ return Builder.CreateXor(LHS, RHS);
+ default:
+ llvm_unreachable("unexpected intrinsic");
+ }
+ };
+
+ Value *ResultV;
+ if (Inst->getIntrinsicID() == Intrinsic::vector_reduce_fadd ||
+ Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
+ ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()),
+ Start);
+ for (unsigned VI = 0, VE = M.getNumVectors(); VI != VE; VI++)
+ ResultV = CreateVReduce(ResultV, M.getVector(VI));
+ } else {
+ ResultV = M.getVector(0);
+ for (unsigned VI = 1, VE = M.getNumVectors(); VI != VE; VI++)
+ ResultV = CreateVReduce(ResultV, M.getVector(VI));
+ }
+
+ auto CreateHReduce = [&](Value *V) {
+ switch (Inst->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_add:
+ return Builder.CreateAddReduce(V);
+ case Intrinsic::vector_reduce_and:
+ return Builder.CreateAndReduce(V);
+ case Intrinsic::vector_reduce_fadd:
+ return Builder.CreateFAddReduce(Start, V);
+ case Intrinsic::vector_reduce_fmax:
+ return Builder.CreateFPMaxReduce(V);
+ case Intrinsic::vector_reduce_fmaximum:
+ return Builder.CreateFPMaximumReduce(V);
+ case Intrinsic::vector_reduce_fmin:
+ return Builder.CreateFPMinReduce(V);
+ case Intrinsic::vector_reduce_fminimum:
+ return Builder.CreateFPMinimumReduce(V);
+ case Intrinsic::vector_reduce_fmul:
+ return Builder.CreateFMulReduce(Start, V);
+ case Intrinsic::vector_reduce_mul:
+ return Builder.CreateMulReduce(V);
+ case Intrinsic::vector_reduce_or:
+ return Builder.CreateOrReduce(V);
+ case Intrinsic::vector_reduce_xor:
+ return Builder.CreateXorReduce(V);
+ default:
+ llvm_unreachable("unexpected intrinsic");
+ }
+ };
+
+ Value *Result = CreateHReduce(ResultV);
+ Inst->replaceAllUsesWith(Result);
+ Result->takeName(Inst);
+ Inst->eraseFromParent();
+ return true;
+ }
+
/// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
/// users with shape information, there's nothing to do: they will use the
/// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1482,33 @@ class LowerMatrixIntrinsics {
ToRemove.push_back(Inst);
Value *Flattened = nullptr;
for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
- if (!ShapeMap.contains(U.getUser())) {
- if (!Flattened)
- Flattened = Matrix.embedInVector(Builder);
- U.set(Flattened);
+ if (ShapeMap.contains(U.getUser()))
+ continue;
+
+ if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser())) {
+ switch (Intr->getIntrinsicID()) {
+ case Intrinsic::vector_reduce_add:
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_fadd:
+ case Intrinsic::vector_reduce_fmax:
+ case Intrinsic::vector_reduce_fmaximum:
+ case Intrinsic::vector_reduce_fmin:
+ case Intrinsic::vector_reduce_fminimum:
+ case Intrinsic::vector_reduce_fmul:
+ case Intrinsic::vector_reduce_mul:
+ case Intrinsic::vector_reduce_or:
+ case Intrinsic::vector_reduce_xor:
+ if (VisitReduce(Intr))
+ continue;
+ break;
+ default:
+ break;
+ }
}
+
+ if (!Flattened)
+ Flattened = Matrix.embedInVector(Builder);
+ U.set(Flattened);
}
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
new file mode 100644
index 0000000000000..41f65e01fec79
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -0,0 +1,200 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define i32 @reduce_add(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = add <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.add(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define i32 @reduce_and(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_and(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.and(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define i32 @reduce_or(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_or(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = or <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.or(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define i32 @reduce_mul(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_mul(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = mul <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.mul(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define i32 @reduce_xor(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_xor(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = xor <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[REDUCE]]
+;
+ %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call i32 @llvm.vector.reduce.xor(<8 x i32> %inv)
+ ret i32 %reduce
+}
+
+define float @reduce_fadd(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fadd_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = fadd <4 x float> zeroinitializer, [[COL_LOAD]]
+; CHECK-NEXT: [[TMP2:%.*]] = fadd <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fadd_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_weirdstart(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v8f32(float 1.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fadd(float 1., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmul_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[COL_LOAD]]
+; CHECK-NEXT: [[TMP2:%.*]] = fmul <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmul.v4f32(float 1.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fmul(float 1., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmul_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_weirdstart(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmul.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fmul(float 0., <8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmax_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fmax(<8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmaximum_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.maximumnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmaximum.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fmaximum(<8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fmin_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmin_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.minimum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmin.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fmin(<8 x float> %inv)
+ ret float %reduce
+}
+
+define float @reduce_fminimum_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fminimum_reassoc(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x float> @llvm.minimumnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT: [[REDUCE:%.*]] = call float @llvm.vector.reduce.fminimum.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT: ret float [[REDUCE]]
+;
+ %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+ %reduce = call reassoc float @llvm.vector.reduce.fminimum(<8 x float> %inv)
+ ret float %reduce
+}
From c343ce326d574a24d14a3a5e329e2b2fd6386d94 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Fri, 30 May 2025 13:53:14 -0700
Subject: [PATCH 2/7] refactor as tryLower
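Rename VisitReduce to tryLowerIntrinsic and have it return false on
intrinsics it doesn't handle instead of asserting, so the caller no
longer needs its own switch over the reduction intrinsic IDs.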
---
.../Scalar/LowerMatrixIntrinsics.cpp | 36 +++++--------------
1 file changed, 9 insertions(+), 27 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 8287bc094fed2..2df748d6c7fca 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1340,12 +1340,7 @@ class LowerMatrixIntrinsics {
return Builder.CreateAdd(Sum, Mul);
}
- bool VisitReduce(IntrinsicInst *Inst) {
- FastMathFlags FMF = getFastMathFlags(Inst);
-
- if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
- return false;
-
+ bool tryLowerIntrinsic(IntrinsicInst *Inst) {
Value *Start = nullptr;
Value *Op = nullptr;
switch (Inst->getIntrinsicID()) {
@@ -1366,7 +1361,7 @@ class LowerMatrixIntrinsics {
Op = Inst->getOperand(0);
break;
default:
- llvm_unreachable("unexpected intrinsic");
+ return false;
}
switch (Inst->getIntrinsicID()) {
@@ -1382,6 +1377,10 @@ class LowerMatrixIntrinsics {
break;
}
+ FastMathFlags FMF = getFastMathFlags(Inst);
+ if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+ return false;
+
auto *I = Inst2ColumnMatrix.find(Op);
if (I == Inst2ColumnMatrix.end())
return false;
@@ -1485,26 +1484,9 @@ class LowerMatrixIntrinsics {
if (ShapeMap.contains(U.getUser()))
continue;
- if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser())) {
- switch (Intr->getIntrinsicID()) {
- case Intrinsic::vector_reduce_add:
- case Intrinsic::vector_reduce_and:
- case Intrinsic::vector_reduce_fadd:
- case Intrinsic::vector_reduce_fmax:
- case Intrinsic::vector_reduce_fmaximum:
- case Intrinsic::vector_reduce_fmin:
- case Intrinsic::vector_reduce_fminimum:
- case Intrinsic::vector_reduce_fmul:
- case Intrinsic::vector_reduce_mul:
- case Intrinsic::vector_reduce_or:
- case Intrinsic::vector_reduce_xor:
- if (VisitReduce(Intr))
- continue;
- break;
- default:
- break;
- }
- }
+ if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser()))
+ if (tryLowerIntrinsic(Intr))
+ continue;
if (!Flattened)
Flattened = Matrix.embedInVector(Builder);
From 1f10a2ccd5482d4c1c693b8cf2559b1a311d69de Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 9 Jun 2025 16:50:59 -0700
Subject: [PATCH 3/7] reassoc only matters on fadd/fmul
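vector.reduce.fadd and vector.reduce.fmul are the only ordered
reductions handled here: splitting them across columns reassociates the
accumulation, which is only sound under the reassoc flag. The remaining
reductions produce the same result for any grouping of the lanes, so
they don't need the flag. Illustrative IR (not from the patch):

  ; strict: lanes must be accumulated in order; not lowered column-wise
  %s = call float @llvm.vector.reduce.fadd.v8f32(float 0., <8 x float> %v)
  ; reassoc: any grouping is acceptable
  %r = call reassoc float @llvm.vector.reduce.fadd.v8f32(float 0., <8 x float> %v)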
---
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 83a83d88e842b..11bda92f47da5 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1380,10 +1380,13 @@ class LowerMatrixIntrinsics {
Value *Op = nullptr;
switch (Inst->getIntrinsicID()) {
case Intrinsic::vector_reduce_fadd:
- case Intrinsic::vector_reduce_fmul:
+ case Intrinsic::vector_reduce_fmul: {
+ FastMathFlags FMF = getFastMathFlags(Inst);
+ if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+ return false;
Start = Inst->getOperand(0);
Op = Inst->getOperand(1);
- break;
+ } break;
case Intrinsic::vector_reduce_fmax:
case Intrinsic::vector_reduce_fmaximum:
case Intrinsic::vector_reduce_fmin:
@@ -1412,10 +1415,6 @@ class LowerMatrixIntrinsics {
break;
}
- FastMathFlags FMF = getFastMathFlags(Inst);
- if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
- return false;
-
auto *I = Inst2ColumnMatrix.find(Op);
if (I == Inst2ColumnMatrix.end())
return false;
From a16e807bec9b9cb09fc636c56d4a4510fd611347 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 9 Jun 2025 16:51:28 -0700
Subject: [PATCH 4/7] looks correct w.r.t. nans, AFAICT
---
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 11bda92f47da5..342aedbd98d3b 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1432,13 +1432,13 @@ class LowerMatrixIntrinsics {
case Intrinsic::vector_reduce_fadd:
return Builder.CreateFAdd(LHS, RHS);
case Intrinsic::vector_reduce_fmax:
- return Builder.CreateMaximum(LHS, RHS); // FIXME: is this correct re: nans?
+ return Builder.CreateMaximum(LHS, RHS);
case Intrinsic::vector_reduce_fmaximum:
- return Builder.CreateMaximumNum(LHS, RHS); // FIXME: is this correct re: nans?
+ return Builder.CreateMaximumNum(LHS, RHS);
case Intrinsic::vector_reduce_fmin:
- return Builder.CreateMinimum(LHS, RHS); // FIXME: is this correct re: nans?
+ return Builder.CreateMinimum(LHS, RHS);
case Intrinsic::vector_reduce_fminimum:
- return Builder.CreateMinimumNum(LHS, RHS); // FIXME: is this correct re: nans?
+ return Builder.CreateMinimumNum(LHS, RHS);
case Intrinsic::vector_reduce_fmul:
return Builder.CreateFMul(LHS, RHS);
case Intrinsic::vector_reduce_mul:
From e85328ff8140ddecfbc1a46d376e7edb7727dad3 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 9 Jun 2025 16:58:25 -0700
Subject: [PATCH 5/7] simplify how we exclude weirdstart cases
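Reject the non-identity start values up front with m_Unless in the
fadd/fmul case, instead of in a separate switch over the intrinsic ID.
The fmin/fmax/fminimum/fmaximum tests also drop their stray reassoc
flags, since those reductions no longer require the flag.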
---
.../Scalar/LowerMatrixIntrinsics.cpp | 22 +++++++----------
.../LowerMatrixIntrinsics/reduce.ll | 24 +++++++++----------
2 files changed, 21 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 342aedbd98d3b..41ecee70f7027 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1384,6 +1384,15 @@ class LowerMatrixIntrinsics {
FastMathFlags FMF = getFastMathFlags(Inst);
if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
return false;
+
+ if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fadd>(
+ m_Unless(m_AnyZeroFP()), m_Value())))
+ return false;
+
+ if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fmul>(
+ m_Unless(m_FPOne()), m_Value())))
+ return false;
+
Start = Inst->getOperand(0);
Op = Inst->getOperand(1);
} break;
@@ -1402,19 +1411,6 @@ class LowerMatrixIntrinsics {
return false;
}
- switch (Inst->getIntrinsicID()) {
- case Intrinsic::vector_reduce_fadd: {
- if (!match(Start, m_AnyZeroFP()))
- return false;
- } break;
- case Intrinsic::vector_reduce_fmul: {
- if (!match(Start, m_FPOne()))
- return false;
- } break;
- default:
- break;
- }
-
auto *I = Inst2ColumnMatrix.find(Op);
if (I == Inst2ColumnMatrix.end())
return false;
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
index 41f65e01fec79..1d8129b715fd0 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -143,8 +143,8 @@ define float @reduce_fmul_weirdstart(ptr %in, ptr %out) {
ret float %reduce
}
-define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_fmax_reassoc(
+define float @reduce_fmax(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmax(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
@@ -153,12 +153,12 @@ define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
; CHECK-NEXT: ret float [[REDUCE]]
;
%inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
- %reduce = call reassoc float @llvm.vector.reduce.fmax(<8 x float> %inv)
+ %reduce = call float @llvm.vector.reduce.fmax(<8 x float> %inv)
ret float %reduce
}
-define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_fmaximum_reassoc(
+define float @reduce_fmaximum(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmaximum(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
@@ -167,12 +167,12 @@ define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
; CHECK-NEXT: ret float [[REDUCE]]
;
%inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
- %reduce = call reassoc float @llvm.vector.reduce.fmaximum(<8 x float> %inv)
+ %reduce = call float @llvm.vector.reduce.fmaximum(<8 x float> %inv)
ret float %reduce
}
-define float @reduce_fmin_reassoc(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_fmin_reassoc(
+define float @reduce_fmin(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmin(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
@@ -181,12 +181,12 @@ define float @reduce_fmin_reassoc(ptr %in, ptr %out) {
; CHECK-NEXT: ret float [[REDUCE]]
;
%inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
- %reduce = call reassoc float @llvm.vector.reduce.fmin(<8 x float> %inv)
+ %reduce = call float @llvm.vector.reduce.fmin(<8 x float> %inv)
ret float %reduce
}
-define float @reduce_fminimum_reassoc(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_fminimum_reassoc(
+define float @reduce_fminimum(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fminimum(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
@@ -195,6 +195,6 @@ define float @reduce_fminimum_reassoc(ptr %in, ptr %out) {
; CHECK-NEXT: ret float [[REDUCE]]
;
%inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
- %reduce = call reassoc float @llvm.vector.reduce.fminimum(<8 x float> %inv)
+ %reduce = call float @llvm.vector.reduce.fminimum(<8 x float> %inv)
ret float %reduce
}
From e9cfab5fdec66943c726116d0ec4cd55b2ec73d2 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 10 Jun 2025 10:57:29 -0700
Subject: [PATCH 6/7] remove debug print
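Drop the Func.dump() left over from debugging in the first commit.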
---
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 41ecee70f7027..2d0bc881885ce 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1113,7 +1113,6 @@ class LowerMatrixIntrinsics {
if (!PoisonedInsts.empty()) {
// If we didn't remove all poisoned instructions, it's a hard error.
dbgs() << "Poisoned but present instructions:\n";
- Func.dump();
for (auto *I : PoisonedInsts)
dbgs() << *I << "\n";
llvm_unreachable("Poisoned but instruction not removed");
From 77b59b9976a22aee4f0469f93b0f057962459a06 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 10 Jun 2025 11:07:40 -0700
Subject: [PATCH 7/7] use range-for
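Add a const overload of MatrixTy::vectors() so the reduction lowering
can walk the columns with range-for and drop_begin.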
---
.../lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 2d0bc881885ce..591875c759d93 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -22,6 +22,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
@@ -425,6 +426,10 @@ class LowerMatrixIntrinsics {
return make_range(Vectors.begin(), Vectors.end());
}
+ iterator_range<SmallVector<Value *, 8>::const_iterator> vectors() const {
+ return make_range(Vectors.begin(), Vectors.end());
+ }
+
/// Embed the vectors of the matrix into a flat vector by concatenating
/// them.
Value *embedInVector(IRBuilder<> &Builder) const {
@@ -1452,12 +1457,12 @@ class LowerMatrixIntrinsics {
Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()),
Start);
- for (unsigned VI = 0, VE = M.getNumVectors(); VI != VE; VI++)
- ResultV = CreateVReduce(ResultV, M.getVector(VI));
+ for (auto &Vector : M.vectors())
+ ResultV = CreateVReduce(ResultV, Vector);
} else {
ResultV = M.getVector(0);
- for (unsigned VI = 1, VE = M.getNumVectors(); VI != VE; VI++)
- ResultV = CreateVReduce(ResultV, M.getVector(VI));
+ for (auto &Vector : drop_begin(M.vectors()))
+ ResultV = CreateVReduce(ResultV, Vector);
}
auto CreateHReduce = [&](Value *V) {