[llvm] [Matrix] Hoist finalizeLowering into caller. NFC (PR #143038)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 10 10:04:56 PDT 2025


https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/143038

From e96975eacb353007502ac87b5d14079df2467d36 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 14:43:35 -0700
Subject: [PATCH] hoist finalizeLowering into caller

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 120 ++++++++----------
 1 file changed, 55 insertions(+), 65 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 20279bf69dd59..c2a89f3c97aa1 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1054,24 +1054,26 @@ class LowerMatrixIntrinsics {
       if (FusedInsts.count(Inst))
         continue;
 
-      IRBuilder<> Builder(Inst);
-
       const ShapeInfo &SI = ShapeMap.at(Inst);
 
       Value *Op1;
       Value *Op2;
+      MatrixTy Result;
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
-        VisitBinaryOperator(BinOp, SI);
+        Result = VisitBinaryOperator(BinOp, SI);
       else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
-        VisitUnaryOperator(UnOp, SI);
-      else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
-        VisitCallInst(CInst);
+        Result = VisitUnaryOperator(UnOp, SI);
+      else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
+        Result = VisitIntrinsicInst(Intr, SI);
       else if (match(Inst, m_Load(m_Value(Op1))))
-        VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
+        Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
-        VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
+        Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2);
       else
         continue;
+
+      IRBuilder<> Builder(Inst);
+      finalizeLowering(Inst, Result, Builder);
       Changed = true;
     }
 
@@ -1111,27 +1113,24 @@ class LowerMatrixIntrinsics {
   }
 
   /// Replace intrinsic calls.
-  void VisitCallInst(CallInst *Inst) {
+  MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) {
     assert(Inst->getCalledFunction() &&
            Inst->getCalledFunction()->isIntrinsic());
 
     switch (Inst->getCalledFunction()->getIntrinsicID()) {
     case Intrinsic::matrix_multiply:
-      LowerMultiply(Inst);
-      break;
+      return LowerMultiply(Inst);
     case Intrinsic::matrix_transpose:
-      LowerTranspose(Inst);
-      break;
+      return LowerTranspose(Inst);
     case Intrinsic::matrix_column_major_load:
-      LowerColumnMajorLoad(Inst);
-      break;
+      return LowerColumnMajorLoad(Inst);
     case Intrinsic::matrix_column_major_store:
-      LowerColumnMajorStore(Inst);
-      break;
+      return LowerColumnMajorStore(Inst);
     default:
-      llvm_unreachable(
-          "only intrinsics supporting shape info should be seen here");
+      break;
     }
+    llvm_unreachable(
+        "only intrinsics supporting shape info should be seen here");
   }
 
   /// Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -1197,26 +1196,24 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower a load instruction with shape information.
-  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
-                 bool IsVolatile, ShapeInfo Shape) {
+  MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
+                     Value *Stride, bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
-    finalizeLowering(Inst,
-                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
-                                Shape, Builder),
-                     Builder);
+    return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
+                      Builder);
   }
 
   /// Lowers llvm.matrix.column.major.load.
   ///
   /// The intrinsic loads a matrix from memory using a stride between columns.
-  void LowerColumnMajorLoad(CallInst *Inst) {
+  MatrixTy LowerColumnMajorLoad(CallInst *Inst) {
     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
            "Intrinsic only supports column-major layout!");
     Value *Ptr = Inst->getArgOperand(0);
     Value *Stride = Inst->getArgOperand(1);
-    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
-              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
-              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
+    return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
+                     cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
+                     {Inst->getArgOperand(3), Inst->getArgOperand(4)});
   }
 
   /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
@@ -1259,28 +1256,27 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower a store instruction with shape information.
-  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
-                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
+  MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
+                      MaybeAlign A, Value *Stride, bool IsVolatile,
+                      ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
     auto StoreVal = getMatrix(Matrix, Shape, Builder);
-    finalizeLowering(Inst,
-                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
-                                 IsVolatile, Builder),
-                     Builder);
+    return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
+                       Builder);
   }
 
   /// Lowers llvm.matrix.column.major.store.
   ///
   /// The intrinsic store a matrix back memory using a stride between columns.
-  void LowerColumnMajorStore(CallInst *Inst) {
+  MatrixTy LowerColumnMajorStore(CallInst *Inst) {
     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
            "Intrinsic only supports column-major layout!");
     Value *Matrix = Inst->getArgOperand(0);
     Value *Ptr = Inst->getArgOperand(1);
     Value *Stride = Inst->getArgOperand(2);
-    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
-               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
-               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
+    return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
+                      cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
+                      {Inst->getArgOperand(4), Inst->getArgOperand(5)});
   }
 
   // Set elements I..I+NumElts-1 to Block
@@ -2045,7 +2041,7 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lowers llvm.matrix.multiply.
-  void LowerMultiply(CallInst *MatMul) {
+  MatrixTy LowerMultiply(CallInst *MatMul) {
     IRBuilder<> Builder(MatMul);
     auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
@@ -2067,11 +2063,11 @@ class LowerMatrixIntrinsics {
 
     emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                        getFastMathFlags(MatMul));
-    finalizeLowering(MatMul, Result, Builder);
+    return Result;
   }
 
   /// Lowers llvm.matrix.transpose.
-  void LowerTranspose(CallInst *Inst) {
+  MatrixTy LowerTranspose(CallInst *Inst) {
     MatrixTy Result;
     IRBuilder<> Builder(Inst);
     Value *InputVal = Inst->getArgOperand(0);
@@ -2101,28 +2097,26 @@ class LowerMatrixIntrinsics {
     // TODO: Improve estimate of operations needed for transposes. Currently we
     // just count the insertelement/extractelement instructions, but do not
     // account for later simplifications/combines.
-    finalizeLowering(
-        Inst,
-        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
-            .addNumExposedTransposes(1),
-        Builder);
+    return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
+        .addNumExposedTransposes(1);
   }
 
   /// Lower load instructions.
-  void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
-                 IRBuilder<> &Builder) {
-    LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
-              Inst->isVolatile(), SI);
+  MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
+    IRBuilder<> Builder(Inst);
+    return LowerLoad(Inst, Ptr, Inst->getAlign(),
+                     Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
   }
 
-  void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
-                  Value *Ptr, IRBuilder<> &Builder) {
-    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
-               Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
+  MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
+                      Value *Ptr) {
+    IRBuilder<> Builder(Inst);
+    return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+                      Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
   }
 
   /// Lower binary operators.
-  void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
     Value *Lhs = Inst->getOperand(0);
     Value *Rhs = Inst->getOperand(1);
 
@@ -2141,14 +2135,12 @@ class LowerMatrixIntrinsics {
       Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
                                            B.getVector(I)));
 
-    finalizeLowering(Inst,
-                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
-                                             Result.getNumVectors()),
-                     Builder);
+    return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                   Result.getNumVectors());
   }
 
   /// Lower unary operators.
-  void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
     Value *Op = Inst->getOperand(0);
 
     IRBuilder<> Builder(Inst);
@@ -2171,10 +2163,8 @@ class LowerMatrixIntrinsics {
     for (unsigned I = 0; I < SI.getNumVectors(); ++I)
       Result.addVector(BuildVectorOp(M.getVector(I)));
 
-    finalizeLowering(Inst,
-                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
-                                             Result.getNumVectors()),
-                     Builder);
+    return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                   Result.getNumVectors());
   }
 
   /// Helper to linearize a matrix expression tree into a string. Currently



More information about the llvm-commits mailing list