[llvm] [Matrix] Lower vector reductions using shape info (PR #142055)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 15:50:45 PDT 2025


https://github.com/jroelofs created https://github.com/llvm/llvm-project/pull/142055

When possible, this avoids a bunch of shuffles in & out of the flattened
layout.

>From 323c9a8d2de459a0f81f32f7537e8b2e087ffc00 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 09:30:50 -0700
Subject: [PATCH 1/4] [Matrix] Optimize static extracts with ShapeInfo

For ExtractElementInsts with static indices that extract from a Matrix, use the
known layout of the Rows/Columns, avoiding some of the shuffles that
embedInVector creates.
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 43 ++++++++++++++---
 .../LowerMatrixIntrinsics/extract.ll          | 47 +++++++++++++++++++
 2 files changed, 84 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..8b322afd9b6e4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -34,6 +34,7 @@
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/MatrixBuilder.h"
@@ -623,7 +624,8 @@ class LowerMatrixIntrinsics {
       default:
         return false;
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
+           isa<ExtractElementInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -1337,6 +1339,28 @@ class LowerMatrixIntrinsics {
     return Builder.CreateAdd(Sum, Mul);
   }
 
+  /// Fold a constant-index extract from a lowered matrix into an extract from
+  /// the matching column/row vector. Returns false if nothing was rewritten.
+  bool VisitExtractElt(ExtractElementInst *Inst, uint64_t Index) {
+    Value *Op0 = Inst->getOperand(0);
+    auto *VTy = cast<VectorType>(Op0->getType());
+    // Valid indices are [0, NumElts); any index at or past the end is poison.
+    if (VTy->getElementCount().getKnownMinValue() <= Index) {
+      Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType()));
+      Inst->eraseFromParent();
+      return true;
+    }
+    auto *I = Inst2ColumnMatrix.find(Op0);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+    // Map the flat index to (vector, lane) in the lowered layout.
+    const MatrixTy &M = I->second;
+    IRBuilder<> Builder(Inst);
+    Inst->setOperand(0, M.getVector(Index / M.getStride()));
+    Inst->setOperand(1, Builder.getInt32(Index % M.getStride()));
+    return true;
+  }
+
   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
   /// users with shape information, there's nothing to do: they will use the
   /// cached value when they are lowered. For other users, \p Matrix is
@@ -1351,11 +1375,18 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
-      if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened)
-          Flattened = Matrix.embedInVector(Builder);
-        U.set(Flattened);
-      }
+      if (ShapeMap.contains(U.getUser()))
+        continue;
+
+      Value *Op1;
+      uint64_t Index;
+      if (match(U.getUser(), m_ExtractElt(m_Value(Op1), m_ConstantInt(Index))))
+        if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
+          continue;
+
+      if (!Flattened)
+        Flattened = Matrix.embedInVector(Builder);
+      U.set(Flattened);
     }
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
new file mode 100644
index 0000000000000..0bac9492d654a
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define float @extract_static(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
+; CHECK-NEXT:    ret float [[EXTRACT]]
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 3
+  ret float %extract
+}
+
+define float @extract_static_outofbounds(ptr %in, ptr %out) {
+; CHECK-LABEL: @extract_static_outofbounds(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    ret float poison
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 5
+  ret float %extract
+}
+
+define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
+; CHECK-LABEL: @extract_dynamic(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
+; CHECK-NEXT:    ret float [[EXTRACT]]
+;
+  %inv = load <4 x float>, ptr %in
+  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
+  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
+  %extract = extractelement <4 x float> %invtt, i32 %idx
+  ret float %extract
+}

>From c962ce9a79028d9514951d8ec07a818114aeea92 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 11:07:13 -0700
Subject: [PATCH 2/4] use column major load intrinsic for shape info

---
 .../LowerMatrixIntrinsics/extract.ll          | 30 ++++++++-----------
 1 file changed, 12 insertions(+), 18 deletions(-)

diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
index 0bac9492d654a..db5444ca036ae 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll
@@ -3,45 +3,39 @@
 
 define float @extract_static(ptr %in, ptr %out) {
 ; CHECK-LABEL: @extract_static(
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
-; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
 ; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1
 ; CHECK-NEXT:    ret float [[EXTRACT]]
 ;
-  %inv = load <4 x float>, ptr %in
-  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
-  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
-  %extract = extractelement <4 x float> %invtt, i32 3
+  %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+  %extract = extractelement <4 x float> %inv, i32 3
   ret float %extract
 }
 
 define float @extract_static_outofbounds(ptr %in, ptr %out) {
 ; CHECK-LABEL: @extract_static_outofbounds(
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
-; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
 ; CHECK-NEXT:    ret float poison
 ;
-  %inv = load <4 x float>, ptr %in
-  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
-  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
-  %extract = extractelement <4 x float> %invtt, i32 5
+  %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+  %extract = extractelement <4 x float> %inv, i32 5
   ret float %extract
 }
 
 define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) {
 ; CHECK-LABEL: @extract_dynamic(
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
-; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
 ; CHECK-NEXT:    [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]]
 ; CHECK-NEXT:    ret float [[EXTRACT]]
 ;
-  %inv = load <4 x float>, ptr %in
-  %invt  = call <4 x float> @llvm.matrix.transpose(<4 x float> %inv, i32 2, i32 2)
-  %invtt = call <4 x float> @llvm.matrix.transpose(<4 x float> %invt, i32 2, i32 2)
-  %extract = extractelement <4 x float> %invtt, i32 %idx
+  %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2)
+  %extract = extractelement <4 x float> %inv, i32 %idx
   ret float %extract
 }

>From 169ed5650b3583d0d4046aafb35958593f267803 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 15:48:58 -0700
Subject: [PATCH 3/4] extractelement shouldn't be in supportsShapeInfo

---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 8b322afd9b6e4..341b4cc5d75c4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -624,8 +624,7 @@ class LowerMatrixIntrinsics {
       default:
         return false;
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
-           isa<ExtractElementInst>(V);
+    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.

>From e1368d5e67f82ac0e98032d3b2568d8917686f07 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 15:46:45 -0700
Subject: [PATCH 4/4] [Matrix] Lower vector reductions using shape info

When possible, this avoids a bunch of shuffles in & out of the flattened
layout.
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 150 +++++++++++++
 .../LowerMatrixIntrinsics/reduce.ll           | 200 ++++++++++++++++++
 2 files changed, 350 insertions(+)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 341b4cc5d75c4..4f997f2133527 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -30,6 +30,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Function.h"
@@ -41,6 +42,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
@@ -1360,6 +1362,133 @@ class LowerMatrixIntrinsics {
     return true;
   }

+  /// Lower \p Inst, a reduction of a lowered matrix, by combining the matrix's
+  /// vectors element-wise and reducing the result. Returns false if unchanged.
+  bool VisitReduce(IntrinsicInst *Inst) {
+    FastMathFlags FMF = getFastMathFlags(Inst);
+    // Splitting an FP reduction reorders its operations; require reassoc.
+    if (Inst->getType()->isFloatingPointTy() && !FMF.allowReassoc())
+      return false;
+    Value *Start = nullptr;
+    Value *Op = nullptr;
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::vector_reduce_fadd:
+    case Intrinsic::vector_reduce_fmul:
+      Start = Inst->getOperand(0);
+      Op = Inst->getOperand(1);
+      break;
+    case Intrinsic::vector_reduce_fmax:
+    case Intrinsic::vector_reduce_fmaximum:
+    case Intrinsic::vector_reduce_fmin:
+    case Intrinsic::vector_reduce_fminimum:
+    case Intrinsic::vector_reduce_add:
+    case Intrinsic::vector_reduce_and:
+    case Intrinsic::vector_reduce_mul:
+    case Intrinsic::vector_reduce_or:
+    case Intrinsic::vector_reduce_xor:
+      Op = Inst->getOperand(0);
+      break;
+    default:
+      llvm_unreachable("unexpected intrinsic");
+    }
+    // Start is splatted and reused by the final reduce; it must be an identity.
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::vector_reduce_fadd: {
+      if (!match(Start, m_AnyZeroFP()))
+        return false;
+    } break;
+    case Intrinsic::vector_reduce_fmul: {
+      if (!match(Start, m_FPOne()))
+        return false;
+    } break;
+    default:
+      break;
+    }
+
+    auto *I = Inst2ColumnMatrix.find(Op);
+    if (I == Inst2ColumnMatrix.end())
+      return false;
+    const MatrixTy &M = I->second;
+    IRBuilder<> Builder(Inst);
+    Builder.setFastMathFlags(FMF); // New FP ops keep the reduction's flags.
+
+    auto CreateVReduce = [&](Value *LHS, Value *RHS) -> Value * {
+      switch (Inst->getIntrinsicID()) {
+      case Intrinsic::vector_reduce_add:
+        return Builder.CreateAdd(LHS, RHS);
+      case Intrinsic::vector_reduce_and:
+        return Builder.CreateAnd(LHS, RHS);
+      case Intrinsic::vector_reduce_fadd:
+        return Builder.CreateFAdd(LHS, RHS);
+      case Intrinsic::vector_reduce_fmax: // NaN-ignoring, like llvm.maxnum.
+        return Builder.CreateMaxNum(LHS, RHS);
+      case Intrinsic::vector_reduce_fmaximum: // NaN-propagating, llvm.maximum.
+        return Builder.CreateMaximum(LHS, RHS);
+      case Intrinsic::vector_reduce_fmin: // NaN-ignoring, like llvm.minnum.
+        return Builder.CreateMinNum(LHS, RHS);
+      case Intrinsic::vector_reduce_fminimum: // NaN-propagating, llvm.minimum.
+        return Builder.CreateMinimum(LHS, RHS);
+      case Intrinsic::vector_reduce_fmul:
+        return Builder.CreateFMul(LHS, RHS);
+      case Intrinsic::vector_reduce_mul:
+        return Builder.CreateMul(LHS, RHS);
+      case Intrinsic::vector_reduce_or:
+        return Builder.CreateOr(LHS, RHS);
+      case Intrinsic::vector_reduce_xor:
+        return Builder.CreateXor(LHS, RHS);
+      default:
+        llvm_unreachable("unexpected intrinsic");
+      }
+    };
+
+    Value *ResultV;
+    if (Inst->getIntrinsicID() == Intrinsic::vector_reduce_fadd ||
+        Inst->getIntrinsicID() == Intrinsic::vector_reduce_fmul) {
+      ResultV = Builder.CreateVectorSplat(ElementCount::getFixed(M.getStride()), Start);
+      for (unsigned VI = 0, VE = M.getNumVectors(); VI != VE; VI++)
+        ResultV = CreateVReduce(ResultV, M.getVector(VI));
+    } else {
+      ResultV = M.getVector(0);
+      for (unsigned VI = 1, VE = M.getNumVectors(); VI != VE; VI++)
+        ResultV = CreateVReduce(ResultV, M.getVector(VI));
+    }
+
+    auto CreateHReduce = [&](Value *V) -> Value * {
+      switch (Inst->getIntrinsicID()) {
+      case Intrinsic::vector_reduce_add:
+        return Builder.CreateAddReduce(V);
+      case Intrinsic::vector_reduce_and:
+        return Builder.CreateAndReduce(V);
+      case Intrinsic::vector_reduce_fadd:
+        return Builder.CreateFAddReduce(Start, V);
+      case Intrinsic::vector_reduce_fmax:
+        return Builder.CreateFPMaxReduce(V);
+      case Intrinsic::vector_reduce_fmaximum:
+        return Builder.CreateFPMaximumReduce(V);
+      case Intrinsic::vector_reduce_fmin:
+        return Builder.CreateFPMinReduce(V);
+      case Intrinsic::vector_reduce_fminimum:
+        return Builder.CreateFPMinimumReduce(V);
+      case Intrinsic::vector_reduce_fmul:
+        return Builder.CreateFMulReduce(Start, V);
+      case Intrinsic::vector_reduce_mul:
+        return Builder.CreateMulReduce(V);
+      case Intrinsic::vector_reduce_or:
+        return Builder.CreateOrReduce(V);
+      case Intrinsic::vector_reduce_xor:
+        return Builder.CreateXorReduce(V);
+      default:
+        llvm_unreachable("unexpected intrinsic");
+      }
+    };
+
+    Value *Result = CreateHReduce(ResultV);
+    Inst->replaceAllUsesWith(Result);
+    Result->takeName(Inst);
+    Inst->eraseFromParent();
+    return true;
+  }
+
   /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
   /// users with shape information, there's nothing to do: they will use the
   /// cached value when they are lowered. For other users, \p Matrix is
@@ -1383,6 +1512,27 @@ class LowerMatrixIntrinsics {
         if (VisitExtractElt(cast<ExtractElementInst>(U.getUser()), Index))
           continue;

+      if (auto *Intr = dyn_cast<IntrinsicInst>(U.getUser())) {
+        switch (Intr->getIntrinsicID()) {
+        case Intrinsic::vector_reduce_add:
+        case Intrinsic::vector_reduce_and:
+        case Intrinsic::vector_reduce_fadd:
+        case Intrinsic::vector_reduce_fmax:
+        case Intrinsic::vector_reduce_fmaximum:
+        case Intrinsic::vector_reduce_fmin:
+        case Intrinsic::vector_reduce_fminimum:
+        case Intrinsic::vector_reduce_fmul:
+        case Intrinsic::vector_reduce_mul:
+        case Intrinsic::vector_reduce_or:
+        case Intrinsic::vector_reduce_xor:
+          if (VisitReduce(Intr))
+            continue;
+          break;
+        default:
+          break;
+        }
+      }
+
       if (!Flattened)
         Flattened = Matrix.embedInVector(Builder);
       U.set(Flattened);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
new file mode 100644
index 0000000000000..41f65e01fec79
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/reduce.ll
@@ -0,0 +1,200 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define i32 @reduce_add(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_add(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.add(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_and(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_and(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.and(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_or(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_or(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = or <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.or(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_mul(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_mul(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.mul.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.mul(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define i32 @reduce_xor(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_xor(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x i32>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <4 x i32> [[COL_LOAD]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT:    ret i32 [[REDUCE]]
+;
+  %inv = call <8 x i32> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call i32 @llvm.vector.reduce.xor(<8 x i32> %inv)
+  ret i32 %reduce
+}
+
+define float @reduce_fadd(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fadd_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd reassoc <4 x float> zeroinitializer, [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd reassoc <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fadd(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fadd_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fadd_weirdstart(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fadd.v8f32(float 1.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fadd(float 1., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmul_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc <4 x float> splat (float 1.000000e+00), [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul reassoc <4 x float> [[TMP1]], [[COL_LOAD1]]
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmul.v4f32(float 1.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmul(float 1., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmul_weirdstart(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmul_weirdstart(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmul.v8f32(float 0.000000e+00, <8 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmul(float 0., <8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmax_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmax_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc <4 x float> @llvm.maxnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmax.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmax(<8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmaximum_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmaximum_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc <4 x float> @llvm.maximum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmaximum.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmaximum(<8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fmin_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fmin_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc <4 x float> @llvm.minnum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fmin.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fmin(<8 x float> %inv)
+  ret float %reduce
+}
+
+define float @reduce_fminimum_reassoc(ptr %in, ptr %out) {
+; CHECK-LABEL: @reduce_fminimum_reassoc(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load volatile <4 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load volatile <4 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call reassoc <4 x float> @llvm.minimum.v4f32(<4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD1]])
+; CHECK-NEXT:    [[REDUCE:%.*]] = call reassoc float @llvm.vector.reduce.fminimum.v4f32(<4 x float> [[TMP1]])
+; CHECK-NEXT:    ret float [[REDUCE]]
+;
+  %inv = call <8 x float> @llvm.matrix.column.major.load(ptr %in, i64 4, i1 1, i32 4, i32 2)
+  %reduce = call reassoc float @llvm.vector.reduce.fminimum(<8 x float> %inv)
+  ret float %reduce
+}



More information about the llvm-commits mailing list