[llvm] [Matrix] Lower vector reductions using shape info (PR #142055)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 16 08:07:27 PDT 2025


https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/142055

>From 51e04878cce3971f222714aa380828c2ff5211d8 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 15:46:45 -0700
Subject: [PATCH 01/10] [Matrix] Lower vector reductions using shape info

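When the reduced operand already has a lowered matrix form, combine its
column (or row) vectors element-wise and emit a single reduction of the
resulting vector. This avoids shuffling the vectors into and out of the
flattened layout.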
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 161 +++++++++++++-
 .../LowerMatrixIntrinsics/reduce.ll           | 200 ++++++++++++++++++
 2 files changed, 357 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..8287bc094fed2 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -30,6 +30,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Function.h"
@@ -40,6 +41,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
@@ -1101,6 +1103,7 @@ class LowerMatrixIntrinsics {
     if (!PoisonedInsts.empty()) {
       // If we didn't remove all poisoned instructions, it's a hard error.
       dbgs() << "Poisoned but present instructions:\n";
+      Func.dump();
       for (auto *I : PoisonedInsts)
         dbgs() << *I << "\n";
       llvm_unreachable("Poisoned but instruction not removed");
@@ -1337,6 +1340,134 @@ class LowerMatrixIntrinsics {
     return Builder.CreateAdd(Sum, Mul);
   }
 
+  bool VisitReduce(IntrinsicInst *Inst) {
+    FastMathFlags FMF = getFastMathFlags(Inst);
+
+    if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+      return false;
+
+    Value *Start = nullptr;
+    Value *Op = nullptr;
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::vector_reduce_fadd:
+    case Intrinsic::vector_reduce_fmul:
+      Start = Inst->getOperand(0);
+      Op = Inst->getOperand(1);
+      break;
+    case Intrinsic::vector_reduce_fmax:
+    case Intrinsic::vector_reduce_fmaximum:
+    case Intrinsic::vector_reduce_fmin:
+    case Intrinsic::vector_reduce_fminimum:
+    case Intrinsic::vector_reduce_add:
+    case Intrinsic::vector_reduce_and:
+    case Intrinsic::vector_reduce_mul:
+    case Intrinsic::vector_reduce_or:
+    case Intrinsic::vector_reduce_xor:
+      Op = Inst->getOperand(0);
+      break;
+    default:
+      llvm_unreachable("unexpected intrinsic");
+    }
+
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::vector_reduce_fadd: {
+      if (!match(Start, m_AnyZeroFP()))
+        return false;
+    } break;
+    case Intrinsic::vector_reduce_fmul: {
+      if (!match(Start, m_FPOne()))
+        return false;
+    } break;
+    default:
+      break;
+    }
+
+    auto *I = Inst2ColumnMatrix.find(Op);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+
+    IRBuilder<> Builder(Inst);
+
+    const MatrixTy &M = I->second;
+
+    auto CreateVReduce = [&](Value *LHS, Value *RHS) {
+      switch (Inst->getIntrinsicID()) {
+      case Intrinsic::vector_reduce_add:
+        return Builder.CreateAdd(LHS, RHS);
+      case Intrinsic::vector_reduce_and:
+        return Builder.CreateAnd(LHS, RHS);
+      case Intrinsic::vector_reduce_fadd:
+        return Builder.CreateFAdd(LHS, RHS);
+      case Intrinsic::vector_reduce_fmax:
+        return Builder.CreateMaximum(LHS, RHS); // FIXME: is this correct re: nans?
+      case Intrinsic::vector_reduce_fmaximum:
+        return Builder.CreateMaximumNum(LHS, RHS); // FIXME: is this correct re: nans?
+      case Intrinsic::vector_reduce_fmin:
+        return Builder.CreateMinimum(LHS, RHS); // FIXME: is this correct re: nans?
+      case Intrinsic::vector_reduce_fminimum:
+        return Builder.CreateMinimumNum(LHS, RHS); // FIXME: is this correct re: nans?
+      case Intrinsic::vector_reduce_fmul:
+        return Builder.CreateFMul(LHS, RHS);
+      case Intrinsic::vector_reduce_mul:
+        return Builder.CreateMul(LHS, RHS);
+      case Intrinsic::vector_reduce_or:
+        return Builder.CreateOr(LHS, RHS);
+      case Intrinsic::vector_reduce_xor:
+        return Builder.CreateXor(LHS, RHS);
+      default:
+        llvm_unreachable("unexpected intrinsic");
+      }
+    };
+
+    Value *ResultV;
+    if (Inst->getIntrinsicID() == Intrinsic::vector_reduce_fadd ||
+        Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
+      ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()),
+                                          Start);
+      for (unsigned VI = 0, VE = M.getNumVectors(); VI != VE; VI++)
+        ResultV = CreateVReduce(ResultV, M.getVector(VI));
+    } else {
+      ResultV = M.getVector(0);
+      for (unsigned VI = 1, VE = M.getNumVectors(); VI != VE; VI++)
+        ResultV = CreateVReduce(ResultV, M.getVector(VI));
+    }
+
+    auto CreateHReduce = [&](Value *V) {
+      switch (Inst->getIntrinsicID()) {
+      case Intrinsic::vector_reduce_add:
+        return Builder.CreateAddReduce(V);
+      case Intrinsic::vector_reduce_and:
+        return Builder.CreateAndReduce(V);
+      case Intrinsic::vector_reduce_fadd:
+        return Builder.CreateFAddReduce(Start, V);
+      case Intrinsic::vector_reduce_fmax:
+        return Builder.CreateFPMaxReduce(V);
+      case Intrinsic::vector_reduce_fmaximum:
+        return Builder.CreateFPMaximumReduce(V);
+      case Intrinsic::vector_reduce_fmin:
+        return Builder.CreateFPMinReduce(V);
+      case Intrinsic::vector_reduce_fminimum:
+        return Builder.CreateFPMinimumReduce(V);
+      case Intrinsic::vector_reduce_fmul:
+        return Builder.CreateFMulReduce(Start, V);
+      case Intrinsic::vector_reduce_mul:
+        return Builder.CreateMulReduce(V);
+      case Intrinsic::vector_reduce_or:
+        return Builder.CreateOrReduce(V);
+      case Intrinsic::vector_reduce_xor:
+        return Builder.CreateXorReduce(V);
+      default:
+        llvm_unreachable("unexpected intrinsic");
+      }
+    };
+
+    Value *Result = CreateHReduce(ResultV);
+    Inst->replaceAllUsesWith(Result);
+    Result->takeName(Inst);
+    Inst->eraseFromParent();
+    return true;
+  }
+
   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
   /// users with shape information, there's nothing to do: they will use the
   /// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1482,33 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
-      if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened)
-          Flattened = Matrix.embedInVector(Builder);
-        U.set(Flattened);
+      if (ShapeMap.contains(U.getUser()))
+        continue;
+
+      if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser())) {
+        switch (Intr->getIntrinsicID()) {
+        case Intrinsic::vector_reduce_add:
+        case Intrinsic::vector_reduce_and:
+        case Intrinsic::vector_reduce_fadd:
+        case Intrinsic::vector_reduce_fmax:
+        case Intrinsic::vector_reduce_fmaximum:
+        case Intrinsic::vector_reduce_fmin:
+        case Intrinsic::vector_reduce_fminimum:
+        case Intrinsic::vector_reduce_fmul:
+        case Intrinsic::vector_reduce_mul:
+        case Intrinsic::vector_reduce_or:
+        case Intrinsic::vector_reduce_xor:
+          if (VisitReduce(Intr))
+            continue;
+          break;
+        default:
+          break;
+        }
       }
+
+      if (!Flattened)
+        Flattened = Matrix.embedInVector(Builder);
+      U.set(Flattened);
     }
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
new file mode 100644
index 0000000000000..41f65e01fec79
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -0,0 +1,200 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define i32 @reduce_add(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.add(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_and(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_and(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.and(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_or(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_or(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = or <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.or(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_mul(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_mul(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.mul(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_xor(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_xor(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.xor(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define float @reduce_fadd(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fadd_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> zeroinitializer, [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fadd_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_weirdstart(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v8f32(float 1.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fadd(float 1., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmul_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmul.v4f32(float 1.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmul(float 1., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmul_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_weirdstart(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmul.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmul(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmax_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmax(<8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmaximum_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.maximumnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmaximum.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmaximum(<8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmin_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmin_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.minimum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmin.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmin(<8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fminimum_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fminimum_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.minimumnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fminimum.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fminimum(<8 x float> %inv)
+  ret float %reduce
+}

>From c343ce326d574a24d14a3a5e329e2b2fd6386d94 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Fri, 30 May 2025 13:53:14 -0700
Subject: [PATCH 02/10] refactor as tryLower

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 36 +++++--------------
 1 file changed, 9 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 8287bc094fed2..2df748d6c7fca 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1340,12 +1340,7 @@ class LowerMatrixIntrinsics {
     return Builder.CreateAdd(Sum, Mul);
   }
 
-  bool VisitReduce(IntrinsicInst *Inst) {
-    FastMathFlags FMF = getFastMathFlags(Inst);
-
-    if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
-      return false;
-
+  bool tryLowerIntrinsic(IntrinsicInst *Inst) {
     Value *Start = nullptr;
     Value *Op = nullptr;
     switch (Inst->getIntrinsicID()) {
@@ -1366,7 +1361,7 @@ class LowerMatrixIntrinsics {
       Op = Inst->getOperand(0);
       break;
     default:
-      llvm_unreachable("unexpected intrinsic");
+      return false;
     }
 
     switch (Inst->getIntrinsicID()) {
@@ -1382,6 +1377,10 @@ class LowerMatrixIntrinsics {
       break;
     }
 
+    FastMathFlags FMF = getFastMathFlags(Inst);
+    if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+      return false;
+
     auto *I = Inst2ColumnMatrix.find(Op);
     if (I == Inst2ColumnMatrix.end())
       return false;
@@ -1485,26 +1484,9 @@ class LowerMatrixIntrinsics {
       if (ShapeMap.contains(U.getUser()))
         continue;
 
-      if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser())) {
-        switch (Intr->getIntrinsicID()) {
-        case Intrinsic::vector_reduce_add:
-        case Intrinsic::vector_reduce_and:
-        case Intrinsic::vector_reduce_fadd:
-        case Intrinsic::vector_reduce_fmax:
-        case Intrinsic::vector_reduce_fmaximum:
-        case Intrinsic::vector_reduce_fmin:
-        case Intrinsic::vector_reduce_fminimum:
-        case Intrinsic::vector_reduce_fmul:
-        case Intrinsic::vector_reduce_mul:
-        case Intrinsic::vector_reduce_or:
-        case Intrinsic::vector_reduce_xor:
-          if (VisitReduce(Intr))
-            continue;
-          break;
-        default:
-          break;
-        }
-      }
+      if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser()))
+        if (tryLowerIntrinsic(Intr))
+          continue;
 
       if (!Flattened)
         Flattened = Matrix.embedInVector(Builder);

>From 1f10a2ccd5482d4c1c693b8cf2559b1a311d69de Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 9 Jun 2025 16:50:59 -0700
Subject: [PATCH 03/10] reassoc only matters on fadd/fmul

---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 83a83d88e842b..11bda92f47da5 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1380,10 +1380,13 @@ class LowerMatrixIntrinsics {
     Value *Op = nullptr;
     switch (Inst->getIntrinsicID()) {
     case Intrinsic::vector_reduce_fadd:
-    case Intrinsic::vector_reduce_fmul:
+    case Intrinsic::vector_reduce_fmul: {
+      FastMathFlags FMF = getFastMathFlags(Inst);
+      if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+        return false;
       Start = Inst->getOperand(0);
       Op = Inst->getOperand(1);
-      break;
+    } break;
     case Intrinsic::vector_reduce_fmax:
     case Intrinsic::vector_reduce_fmaximum:
     case Intrinsic::vector_reduce_fmin:
@@ -1412,10 +1415,6 @@ class LowerMatrixIntrinsics {
       break;
     }
 
-    FastMathFlags FMF = getFastMathFlags(Inst);
-    if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
-      return false;
-
     auto *I = Inst2ColumnMatrix.find(Op);
     if (I == Inst2ColumnMatrix.end())
       return false;

>From a16e807bec9b9cb09fc636c56d4a4510fd611347 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 9 Jun 2025 16:51:28 -0700
Subject: [PATCH 04/10] match LangRef NaN semantics for fmax/fmin reductions

llvm.vector.reduce.fmax/fmin have the comparison semantics of
llvm.maxnum/minnum (NaN-ignoring), while llvm.vector.reduce.fmaximum/
fminimum have the semantics of llvm.maximum/minimum (NaN-propagating).
Combine the lowered vectors with the matching element-wise intrinsic so
the split reduction handles NaNs exactly like the original one, and
update the test expectations accordingly.
---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 8 ++++----
 llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1432,13 +1432,13 @@ class LowerMatrixIntrinsics {
       case Intrinsic::vector_reduce_fadd:
         return Builder.CreateFAdd(LHS, RHS);
       case Intrinsic::vector_reduce_fmax:
-        return Builder.CreateMaximum(LHS, RHS); // FIXME: is this correct re: nans?
+        return Builder.CreateMaxNum(LHS, RHS);
       case Intrinsic::vector_reduce_fmaximum:
-        return Builder.CreateMaximumNum(LHS, RHS); // FIXME: is this correct re: nans?
+        return Builder.CreateMaximum(LHS, RHS);
       case Intrinsic::vector_reduce_fmin:
-        return Builder.CreateMinimum(LHS, RHS); // FIXME: is this correct re: nans?
+        return Builder.CreateMinNum(LHS, RHS);
       case Intrinsic::vector_reduce_fminimum:
-        return Builder.CreateMinimumNum(LHS, RHS); // FIXME: is this correct re: nans?
+        return Builder.CreateMinimum(LHS, RHS);
       case Intrinsic::vector_reduce_fmul:
         return Builder.CreateFMul(LHS, RHS);
       case Intrinsic::vector_reduce_mul:
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -148,7 +148,7 @@ define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.maxnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
 ; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[TMP1]])
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
@@ -162,7 +162,7 @@ define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.maximumnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.maximum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
 ; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmaximum.v4f32(<4 x float> [[TMP1]])
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
@@ -176,7 +176,7 @@ define float @reduce_fmin_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.minimum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.minnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
 ; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmin.v4f32(<4 x float> [[TMP1]])
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
@@ -190,7 +190,7 @@ define float @reduce_fminimum_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.minimumnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.minimum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
 ; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fminimum.v4f32(<4 x float> [[TMP1]])
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;

>From e85328ff8140ddecfbc1a46d376e7edb7727dad3 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 9 Jun 2025 16:58:25 -0700
Subject: [PATCH 05/10] simplify how we exclude weirdstart cases

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 22 +++++++----------
 .../LowerMatrixIntrinsics/reduce.ll           | 24 +++++++++----------
 2 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 342aedbd98d3b..41ecee70f7027 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1384,6 +1384,15 @@ class LowerMatrixIntrinsics {
       FastMathFlags FMF = getFastMathFlags(Inst);
       if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
         return false;
+
+      if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fadd>(
+                          m_Unless(m_AnyZeroFP()), m_Value())))
+        return false;
+
+      if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fmul>(
+                          m_Unless(m_FPOne()), m_Value())))
+        return false;
+
       Start = Inst->getOperand(0);
       Op = Inst->getOperand(1);
     } break;
@@ -1402,19 +1411,6 @@ class LowerMatrixIntrinsics {
       return false;
     }
 
-    switch (Inst->getIntrinsicID()) {
-    case Intrinsic::vector_reduce_fadd: {
-      if (!match(Start, m_AnyZeroFP()))
-        return false;
-    } break;
-    case Intrinsic::vector_reduce_fmul: {
-      if (!match(Start, m_FPOne()))
-        return false;
-    } break;
-    default:
-      break;
-    }
-
     auto *I = Inst2ColumnMatrix.find(Op);
     if (I == Inst2ColumnMatrix.end())
       return false;
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
index 41f65e01fec79..1d8129b715fd0 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -143,8 +143,8 @@ define float @reduce_fmul_weirdstart(ptr %in, ptr %out) {
   ret float %reduce
 }
 
-define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_fmax_reassoc(
+define float @reduce_fmax(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmax(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
@@ -153,12 +153,12 @@ define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
   %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
-  %reduce = call reassoc float @llvm.vector.reduce.fmax(<8 x float> %inv)
+  %reduce = call float @llvm.vector.reduce.fmax(<8 x float> %inv)
   ret float %reduce
 }
 
-define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_fmaximum_reassoc(
+define float @reduce_fmaximum(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmaximum(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
@@ -167,12 +167,12 @@ define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
   %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
-  %reduce = call reassoc float @llvm.vector.reduce.fmaximum(<8 x float> %inv)
+  %reduce = call float @llvm.vector.reduce.fmaximum(<8 x float> %inv)
   ret float %reduce
 }
 
-define float @reduce_fmin_reassoc(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_fmin_reassoc(
+define float @reduce_fmin(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmin(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
@@ -181,12 +181,12 @@ define float @reduce_fmin_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
   %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
-  %reduce = call reassoc float @llvm.vector.reduce.fmin(<8 x float> %inv)
+  %reduce = call float @llvm.vector.reduce.fmin(<8 x float> %inv)
   ret float %reduce
 }
 
-define float @reduce_fminimum_reassoc(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_fminimum_reassoc(
+define float @reduce_fminimum(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fminimum(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
@@ -195,6 +195,6 @@ define float @reduce_fminimum_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
   %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
-  %reduce = call reassoc float @llvm.vector.reduce.fminimum(<8 x float> %inv)
+  %reduce = call float @llvm.vector.reduce.fminimum(<8 x float> %inv)
   ret float %reduce
 }

>From e9cfab5fdec66943c726116d0ec4cd55b2ec73d2 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 10 Jun 2025 10:57:29 -0700
Subject: [PATCH 06/10] remove debug print

---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 41ecee70f7027..2d0bc881885ce 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1113,7 +1113,6 @@ class LowerMatrixIntrinsics {
     if (!PoisonedInsts.empty()) {
       // If we didn't remove all poisoned instructions, it's a hard error.
       dbgs() << "Poisoned but present instructions:\n";
-      Func.dump();
       for (auto *I : PoisonedInsts)
         dbgs() << *I << "\n";
       llvm_unreachable("Poisoned but instruction not removed");

>From 77b59b9976a22aee4f0469f93b0f057962459a06 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 10 Jun 2025 11:07:40 -0700
Subject: [PATCH 07/10] use range-for

---
 .../lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 2d0bc881885ce..591875c759d93 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/LoopInfo.h"
@@ -425,6 +426,10 @@ class LowerMatrixIntrinsics {
       return make_range(Vectors.begin(), Vectors.end());
     }
 
+    iterator_range<SmallVector<Value *, 8>::const_iterator> vectors() const {
+      return make_range(Vectors.begin(), Vectors.end());
+    }
+
     /// Embed the vectors of the matrix into a flat vector by concatenating
     /// them.
     Value *embedInVector(IRBuilder<> &Builder) const {
@@ -1452,12 +1457,12 @@ class LowerMatrixIntrinsics {
         Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
       ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()),
                                           Start);
-      for (unsigned VI = 0, VE = M.getNumVectors(); VI != VE; VI++)
-        ResultV = CreateVReduce(ResultV, M.getVector(VI));
+      for (auto &Vector : M.vectors())
+        ResultV = CreateVReduce(ResultV, Vector);
     } else {
       ResultV = M.getVector(0);
-      for (unsigned VI = 1, VE = M.getNumVectors(); VI != VE; VI++)
-        ResultV = CreateVReduce(ResultV, M.getVector(VI));
+      for (auto &Vector : drop_begin(M.vectors()))
+        ResultV = CreateVReduce(ResultV, Vector);
     }
 
     auto CreateHReduce = [&](Value *V) {

>From e2815dffc4e92193ecce8a3a71ddbcfbad3e3374 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 10 Jun 2025 12:16:48 -0700
Subject: [PATCH 08/10] clang-format

---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 4261011828da2..29733e5020691 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -19,11 +19,11 @@
 
 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/LoopInfo.h"

>From f43d02eb95fc476abc2982edd19439ecb18bf75a Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 16 Jun 2025 07:13:07 -0700
Subject: [PATCH 09/10] review feedback: propagate fast-math flags, add shape tests

---
 .../Scalar/LowerMatrixIntrinsics.cpp          |  22 ++--
 .../LowerMatrixIntrinsics/reduce.ll           | 116 ++++++++++++++++--
 2 files changed, 123 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 29733e5020691..a213d03380f06 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/MatrixBuilder.h"
+#include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/CommandLine.h"
@@ -438,10 +439,6 @@ class LowerMatrixIntrinsics {
       return make_range(Vectors.begin(), Vectors.end());
     }
 
-    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
-      return make_range(Vectors.begin(), Vectors.end());
-    }
-
     iterator_range<SmallVector<Value *, 8>::const_iterator> vectors() const {
       return make_range(Vectors.begin(), Vectors.end());
     }
@@ -1420,12 +1417,12 @@ class LowerMatrixIntrinsics {
   }
 
   bool tryLowerIntrinsic(IntrinsicInst *Inst) {
+    FastMathFlags FMF = getFastMathFlags(Inst);
     Value *Start = nullptr;
     Value *Op = nullptr;
     switch (Inst->getIntrinsicID()) {
     case Intrinsic::vector_reduce_fadd:
     case Intrinsic::vector_reduce_fmul: {
-      FastMathFlags FMF = getFastMathFlags(Inst);
       if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
         return false;
 
@@ -1497,12 +1494,20 @@ class LowerMatrixIntrinsics {
         Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
       ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()),
                                           Start);
-      for (auto &Vector : M.vectors())
+      for (auto &Vector : M.vectors()) {
         ResultV = CreateVReduce(ResultV, Vector);
+        if (isa<FPMathOperator>(ResultV))
+          if (auto *ResultVI = dyn_cast<Instruction>(ResultV))
+            ResultVI->setFastMathFlags(FMF);
+      }
     } else {
       ResultV = M.getVector(0);
-      for (auto &Vector : drop_begin(M.vectors()))
+      for (auto &Vector : drop_begin(M.vectors())) {
         ResultV = CreateVReduce(ResultV, Vector);
+        if (isa<FPMathOperator>(ResultV))
+          if (auto *ResultVI = dyn_cast<Instruction>(ResultV))
+            ResultVI->setFastMathFlags(FMF);
+      }
     }
 
     auto CreateHReduce = [&](Value *V) {
@@ -1535,6 +1540,9 @@ class LowerMatrixIntrinsics {
     };
 
     Value *Result = CreateHReduce(ResultV);
+    if (isa<FPMathOperator>(Result))
+      if (auto *ResultI = dyn_cast<Instruction>(Result))
+        ResultI->setFastMathFlags(FMF);
     Inst->replaceAllUsesWith(Result);
     Result->takeName(Inst);
     Inst->eraseFromParent();
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
index 1d8129b715fd0..503378bebb85b 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
 
-define i32 @reduce_add(ptr %in, ptr %out) {
-; CHECK-LABEL: @reduce_add(
+define i32 @reduce_add_4x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add_4x2(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
@@ -15,6 +15,77 @@ define i32 @reduce_add(ptr %in, ptr %out) {
   ret i32 %reduce
 }
 
+define i32 @reduce_add_8x1(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add_8x1(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <8 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[COL_LOAD]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 8, i1 1, i32 8, i32 1)
+  %reduce = call i32 @llvm.vector.reduce.add(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_add_1x8(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add_1x8(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <1 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 1
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr i32, ptr [[IN]], i64 3
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP4]], align 4
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD7:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP6]], align 4
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr i32, ptr [[IN]], i64 5
+; CHECK-NEXT:    [[COL_LOAD9:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP8]], align 4
+; CHECK-NEXT:    [[VEC_GEP10:%.*]] = getelementptr i32, ptr [[IN]], i64 6
+; CHECK-NEXT:    [[COL_LOAD11:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP10]], align 4
+; CHECK-NEXT:    [[VEC_GEP12:%.*]] = getelementptr i32, ptr [[IN]], i64 7
+; CHECK-NEXT:    [[COL_LOAD13:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP12]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = add <1 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = add <1 x i32> [[TMP1]], [[COL_LOAD3]]
+; CHECK-NEXT:    [[TMP3:%.*]] = add <1 x i32> [[TMP2]], [[COL_LOAD5]]
+; CHECK-NEXT:    [[TMP4:%.*]] = add <1 x i32> [[TMP3]], [[COL_LOAD7]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add <1 x i32> [[TMP4]], [[COL_LOAD9]]
+; CHECK-NEXT:    [[TMP6:%.*]] = add <1 x i32> [[TMP5]], [[COL_LOAD11]]
+; CHECK-NEXT:    [[TMP7:%.*]] = add <1 x i32> [[TMP6]], [[COL_LOAD13]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v1i32(<1 x i32> [[TMP7]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 1, i1 1, i32 1, i32 8)
+  %reduce = call i32 @llvm.vector.reduce.add(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_add_1x3(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add_1x3(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <1 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 1
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load volatile <1 x i32>, ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = add <1 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = add <1 x i32> [[TMP1]], [[COL_LOAD3]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v1i32(<1 x i32> [[TMP2]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <3 x i32> @llvm.matrix.column.major.load(ptr %in, i64 1, i1 1, i32 1, i32 3)
+  %reduce = call i32 @llvm.vector.reduce.add(<3 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_add_3x1(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add_3x1(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <3 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v3i32(<3 x i32> [[COL_LOAD]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <3 x i32> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 1, i32 3, i32 1)
+  %reduce = call i32 @llvm.vector.reduce.add(<3 x i32> %inv)
+  ret i32 %reduce
+}
+
 define i32 @reduce_and(ptr %in, ptr %out) {
 ; CHECK-LABEL: @reduce_and(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
@@ -90,9 +161,9 @@ define float @reduce_fadd_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> zeroinitializer, [[COL_LOAD]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fadd <4 x float> [[TMP1]], [[COL_LOAD1]]
-; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd reassoc <4 x float> zeroinitializer, [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd reassoc <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP2]])
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
   %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
@@ -100,6 +171,35 @@ define float @reduce_fadd_reassoc(ptr %in, ptr %out) {
   ret float %reduce
 }
 
+define float @reduce_fadd_contract(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_contract(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call contract float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call contract float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fadd_reassoccontract(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_reassoccontract(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd reassoc contract <4 x float> zeroinitializer, [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd reassoc contract <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc contract float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc contract float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
 define float @reduce_fadd_weirdstart(ptr %in, ptr %out) {
 ; CHECK-LABEL: @reduce_fadd_weirdstart(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
@@ -119,9 +219,9 @@ define float @reduce_fmul_reassoc(ptr %in, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = fmul <4 x float> splat (float 1.000000e+00), [[COL_LOAD]]
-; CHECK-NEXT:    [[TMP2:%.*]] = fmul <4 x float> [[TMP1]], [[COL_LOAD1]]
-; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fmul.v4f32(float 1.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc <4 x float> splat (float 1.000000e+00), [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul reassoc <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmul.v4f32(float 1.000000e+00, <4 x float> [[TMP2]])
 ; CHECK-NEXT:    ret float [[REDUCE]]
 ;
   %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)

>From 90c2640552fd04bfcc2d3785dac6b0ca28087f30 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 16 Jun 2025 07:52:31 -0700
Subject: [PATCH 10/10] review feedback: make reductions have 1x1 shape info

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 308 ++++++++++--------
 1 file changed, 171 insertions(+), 137 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index a213d03380f06..3be7cfec5d489 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -298,6 +298,25 @@ computeShapeInfoForInst(Instruction *I,
       return OpShape->second;
   }
 
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::vector_reduce_fadd:
+    case Intrinsic::vector_reduce_fmul:
+    case Intrinsic::vector_reduce_fmax:
+    case Intrinsic::vector_reduce_fmaximum:
+    case Intrinsic::vector_reduce_fmin:
+    case Intrinsic::vector_reduce_fminimum:
+    case Intrinsic::vector_reduce_add:
+    case Intrinsic::vector_reduce_and:
+    case Intrinsic::vector_reduce_mul:
+    case Intrinsic::vector_reduce_or:
+    case Intrinsic::vector_reduce_xor:
+      return ShapeInfo(1, 1);
+    default:
+      break;
+    }
+  }
+
   if (isUniformShape(I)) {
     // Find the first operand that has a known shape and use that.
     for (auto &Op : I->operands()) {
@@ -669,7 +688,31 @@ class LowerMatrixIntrinsics {
       case Intrinsic::matrix_transpose:
       case Intrinsic::matrix_column_major_load:
       case Intrinsic::matrix_column_major_store:
+      case Intrinsic::vector_reduce_fmax:
+      case Intrinsic::vector_reduce_fmaximum:
+      case Intrinsic::vector_reduce_fmin:
+      case Intrinsic::vector_reduce_fminimum:
+      case Intrinsic::vector_reduce_add:
+      case Intrinsic::vector_reduce_and:
+      case Intrinsic::vector_reduce_mul:
+      case Intrinsic::vector_reduce_or:
+      case Intrinsic::vector_reduce_xor:
         return true;
+      case Intrinsic::vector_reduce_fadd:
+      case Intrinsic::vector_reduce_fmul: {
+        FastMathFlags FMF = getFastMathFlags(Inst);
+        if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+          return false;
+
+        if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fadd>(
+                            m_Unless(m_AnyZeroFP()), m_Value())))
+          return false;
+
+        if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fmul>(
+                            m_Unless(m_FPOne()), m_Value())))
+          return false;
+        return true;
+      }
       default:
         return isUniformShape(II);
       }
@@ -1206,6 +1249,134 @@ class LowerMatrixIntrinsics {
                        Builder);
       return;
     }
+    case Intrinsic::vector_reduce_fadd:
+    case Intrinsic::vector_reduce_fmul:
+    case Intrinsic::vector_reduce_fmax:
+    case Intrinsic::vector_reduce_fmaximum:
+    case Intrinsic::vector_reduce_fmin:
+    case Intrinsic::vector_reduce_fminimum:
+    case Intrinsic::vector_reduce_add:
+    case Intrinsic::vector_reduce_and:
+    case Intrinsic::vector_reduce_mul:
+    case Intrinsic::vector_reduce_or:
+    case Intrinsic::vector_reduce_xor: {
+      IRBuilder<> Builder(Inst);
+
+      FastMathFlags FMF = getFastMathFlags(Inst);
+      Value *Start = nullptr;
+      Value *Op = nullptr;
+      switch (Inst->getIntrinsicID()) {
+      case Intrinsic::vector_reduce_fadd:
+      case Intrinsic::vector_reduce_fmul: {
+        Start = Inst->getOperand(0);
+        Op = Inst->getOperand(1);
+      } break;
+      case Intrinsic::vector_reduce_fmax:
+      case Intrinsic::vector_reduce_fmaximum:
+      case Intrinsic::vector_reduce_fmin:
+      case Intrinsic::vector_reduce_fminimum:
+      case Intrinsic::vector_reduce_add:
+      case Intrinsic::vector_reduce_and:
+      case Intrinsic::vector_reduce_mul:
+      case Intrinsic::vector_reduce_or:
+      case Intrinsic::vector_reduce_xor:
+        Op = Inst->getOperand(0);
+        break;
+      default:
+        llvm_unreachable("unexpected reduction intrinsic");
+      }
+
+      auto *I = Inst2ColumnMatrix.find(Op);
+      assert(I != Inst2ColumnMatrix.end());
+      const MatrixTy &M = I->second;
+
+      auto CreateVReduce = [&](Value *LHS, Value *RHS) -> Value * {
+        switch (Inst->getIntrinsicID()) {
+        case Intrinsic::vector_reduce_add:
+          return Builder.CreateAdd(LHS, RHS);
+        case Intrinsic::vector_reduce_and:
+          return Builder.CreateAnd(LHS, RHS);
+        case Intrinsic::vector_reduce_fadd:
+          return Builder.CreateFAdd(LHS, RHS);
+        case Intrinsic::vector_reduce_fmax: // llvm.maxnum
+          return Builder.CreateMaxNum(LHS, RHS);
+        case Intrinsic::vector_reduce_fmaximum: // llvm.maximum
+          return Builder.CreateMaximum(LHS, RHS);
+        case Intrinsic::vector_reduce_fmin: // llvm.minnum
+          return Builder.CreateMinNum(LHS, RHS);
+        case Intrinsic::vector_reduce_fminimum: // llvm.minimum
+          return Builder.CreateMinimum(LHS, RHS);
+        case Intrinsic::vector_reduce_fmul:
+          return Builder.CreateFMul(LHS, RHS);
+        case Intrinsic::vector_reduce_mul:
+          return Builder.CreateMul(LHS, RHS);
+        case Intrinsic::vector_reduce_or:
+          return Builder.CreateOr(LHS, RHS);
+        case Intrinsic::vector_reduce_xor:
+          return Builder.CreateXor(LHS, RHS);
+        default:
+          llvm_unreachable("unexpected reduction intrinsic");
+        }
+      };
+
+      Value *ResultV;
+      if (Inst->getIntrinsicID() == Intrinsic::vector_reduce_fadd ||
+          Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
+        ResultV = Builder.CreateVectorSplat(
+            ElementCount::getFixed(M.getStride()), Start);
+        for (auto &Vector : M.vectors()) {
+          ResultV = CreateVReduce(ResultV, Vector);
+          if (isa<FPMathOperator>(ResultV))
+            if (auto *ResultVI = dyn_cast<Instruction>(ResultV))
+              ResultVI->setFastMathFlags(FMF);
+        }
+      } else {
+        ResultV = M.getVector(0);
+        for (auto &Vector : drop_begin(M.vectors())) {
+          ResultV = CreateVReduce(ResultV, Vector);
+          if (isa<FPMathOperator>(ResultV))
+            if (auto *ResultVI = dyn_cast<Instruction>(ResultV))
+              ResultVI->setFastMathFlags(FMF);
+        }
+      }
+
+      auto CreateHReduce = [&](Value *V) {
+        switch (Inst->getIntrinsicID()) {
+        case Intrinsic::vector_reduce_add:
+          return Builder.CreateAddReduce(V);
+        case Intrinsic::vector_reduce_and:
+          return Builder.CreateAndReduce(V);
+        case Intrinsic::vector_reduce_fadd:
+          return Builder.CreateFAddReduce(Start, V);
+        case Intrinsic::vector_reduce_fmax:
+          return Builder.CreateFPMaxReduce(V);
+        case Intrinsic::vector_reduce_fmaximum:
+          return Builder.CreateFPMaximumReduce(V);
+        case Intrinsic::vector_reduce_fmin:
+          return Builder.CreateFPMinReduce(V);
+        case Intrinsic::vector_reduce_fminimum:
+          return Builder.CreateFPMinimumReduce(V);
+        case Intrinsic::vector_reduce_fmul:
+          return Builder.CreateFMulReduce(Start, V);
+        case Intrinsic::vector_reduce_mul:
+          return Builder.CreateMulReduce(V);
+        case Intrinsic::vector_reduce_or:
+          return Builder.CreateOrReduce(V);
+        case Intrinsic::vector_reduce_xor:
+          return Builder.CreateXorReduce(V);
+        default:
+          llvm_unreachable("unexpected reduction intrinsic");
+        }
+      };
+
+      Value *Result = CreateHReduce(ResultV);
+      if (isa<FPMathOperator>(Result))
+        if (auto *ResultI = dyn_cast<Instruction>(Result))
+          ResultI->setFastMathFlags(FMF);
+      Inst->replaceAllUsesWith(Result);
+      Result->takeName(Inst);
+      finalizeLowering(Inst, {Result}, Builder);
+    } break;
     default:
       llvm_unreachable(
           "only intrinsics supporting shape info should be seen here");
@@ -1416,139 +1587,6 @@ class LowerMatrixIntrinsics {
     return Builder.CreateAdd(Sum, Mul);
   }
 
-  bool tryLowerIntrinsic(IntrinsicInst *Inst) {
-    FastMathFlags FMF = getFastMathFlags(Inst);
-    Value *Start = nullptr;
-    Value *Op = nullptr;
-    switch (Inst->getIntrinsicID()) {
-    case Intrinsic::vector_reduce_fadd:
-    case Intrinsic::vector_reduce_fmul: {
-      if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
-        return false;
-
-      if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fadd>(
-                          m_Unless(m_AnyZeroFP()), m_Value())))
-        return false;
-
-      if (match(Inst, m_Intrinsic<Intrinsic::vector_reduce_fmul>(
-                          m_Unless(m_FPOne()), m_Value())))
-        return false;
-
-      Start = Inst->getOperand(0);
-      Op = Inst->getOperand(1);
-    } break;
-    case Intrinsic::vector_reduce_fmax:
-    case Intrinsic::vector_reduce_fmaximum:
-    case Intrinsic::vector_reduce_fmin:
-    case Intrinsic::vector_reduce_fminimum:
-    case Intrinsic::vector_reduce_add:
-    case Intrinsic::vector_reduce_and:
-    case Intrinsic::vector_reduce_mul:
-    case Intrinsic::vector_reduce_or:
-    case Intrinsic::vector_reduce_xor:
-      Op = Inst->getOperand(0);
-      break;
-    default:
-      return false;
-    }
-
-    auto *I = Inst2ColumnMatrix.find(Op);
-    if (I == Inst2ColumnMatrix.end())
-      return false;
-
-    IRBuilder<> Builder(Inst);
-
-    const MatrixTy &M = I->second;
-
-    auto CreateVReduce = [&](Value *LHS, Value *RHS) {
-      switch (Inst->getIntrinsicID()) {
-      case Intrinsic::vector_reduce_add:
-        return Builder.CreateAdd(LHS, RHS);
-      case Intrinsic::vector_reduce_and:
-        return Builder.CreateAnd(LHS, RHS);
-      case Intrinsic::vector_reduce_fadd:
-        return Builder.CreateFAdd(LHS, RHS);
-      case Intrinsic::vector_reduce_fmax:
-        return Builder.CreateMaximum(LHS, RHS);
-      case Intrinsic::vector_reduce_fmaximum:
-        return Builder.CreateMaximumNum(LHS, RHS);
-      case Intrinsic::vector_reduce_fmin:
-        return Builder.CreateMinimum(LHS, RHS);
-      case Intrinsic::vector_reduce_fminimum:
-        return Builder.CreateMinimumNum(LHS, RHS);
-      case Intrinsic::vector_reduce_fmul:
-        return Builder.CreateFMul(LHS, RHS);
-      case Intrinsic::vector_reduce_mul:
-        return Builder.CreateMul(LHS, RHS);
-      case Intrinsic::vector_reduce_or:
-        return Builder.CreateOr(LHS, RHS);
-      case Intrinsic::vector_reduce_xor:
-        return Builder.CreateXor(LHS, RHS);
-      default:
-        llvm_unreachable("unexpected intrinsic");
-      }
-    };
-
-    Value *ResultV;
-    if (Inst->getIntrinsicID() == Intrinsic::vector_reduce_fadd ||
-        Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
-      ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()),
-                                          Start);
-      for (auto &Vector : M.vectors()) {
-        ResultV = CreateVReduce(ResultV, Vector);
-        if (isa<FPMathOperator>(ResultV))
-          if (auto *ResultVI = dyn_cast<Instruction>(ResultV))
-            ResultVI->setFastMathFlags(FMF);
-      }
-    } else {
-      ResultV = M.getVector(0);
-      for (auto &Vector : drop_begin(M.vectors())) {
-        ResultV = CreateVReduce(ResultV, Vector);
-        if (isa<FPMathOperator>(ResultV))
-          if (auto *ResultVI = dyn_cast<Instruction>(ResultV))
-            ResultVI->setFastMathFlags(FMF);
-      }
-    }
-
-    auto CreateHReduce = [&](Value *V) {
-      switch (Inst->getIntrinsicID()) {
-      case Intrinsic::vector_reduce_add:
-        return Builder.CreateAddReduce(V);
-      case Intrinsic::vector_reduce_and:
-        return Builder.CreateAndReduce(V);
-      case Intrinsic::vector_reduce_fadd:
-        return Builder.CreateFAddReduce(Start, V);
-      case Intrinsic::vector_reduce_fmax:
-        return Builder.CreateFPMaxReduce(V);
-      case Intrinsic::vector_reduce_fmaximum:
-        return Builder.CreateFPMaximumReduce(V);
-      case Intrinsic::vector_reduce_fmin:
-        return Builder.CreateFPMinReduce(V);
-      case Intrinsic::vector_reduce_fminimum:
-        return Builder.CreateFPMinimumReduce(V);
-      case Intrinsic::vector_reduce_fmul:
-        return Builder.CreateFMulReduce(Start, V);
-      case Intrinsic::vector_reduce_mul:
-        return Builder.CreateMulReduce(V);
-      case Intrinsic::vector_reduce_or:
-        return Builder.CreateOrReduce(V);
-      case Intrinsic::vector_reduce_xor:
-        return Builder.CreateXorReduce(V);
-      default:
-        llvm_unreachable("unexpected intrinsic");
-      }
-    };
-
-    Value *Result = CreateHReduce(ResultV);
-    if (isa<FPMathOperator>(Result))
-      if (auto *ResultI = dyn_cast<Instruction>(Result))
-        ResultI->setFastMathFlags(FMF);
-    Inst->replaceAllUsesWith(Result);
-    Result->takeName(Inst);
-    Inst->eraseFromParent();
-    return true;
-  }
-
   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
   /// users with shape information, there's nothing to do: they will use the
   /// cached value when they are lowered. For other users, \p Matrix is
@@ -1566,10 +1604,6 @@ class LowerMatrixIntrinsics {
       if (ShapeMap.contains(U.getUser()))
         continue;
 
-      if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser()))
-        if (tryLowerIntrinsic(Intr))
-          continue;
-
       if (!Flattened) {
         Flattened = Matrix.embedInVector(Builder);
         LLVM_DEBUG(



More information about the llvm-commits mailing list