[llvm] 8345d62 - [Matrix] Hoist finalizeLowering into caller. NFC (#143038)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 10 12:36:41 PDT 2025
Author: Jon Roelofs
Date: 2025-06-10T12:36:37-07:00
New Revision: 8345d62478054d4ab97c6f28cfea6d1ecca837da
URL: https://github.com/llvm/llvm-project/commit/8345d62478054d4ab97c6f28cfea6d1ecca837da
DIFF: https://github.com/llvm/llvm-project/commit/8345d62478054d4ab97c6f28cfea6d1ecca837da.diff
LOG: [Matrix] Hoist finalizeLowering into caller. NFC (#143038)
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 21683089a4693..eb81d2ea49673 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1134,26 +1134,28 @@ class LowerMatrixIntrinsics {
if (FusedInsts.count(Inst))
continue;
- IRBuilder<> Builder(Inst);
-
const ShapeInfo &SI = ShapeMap.at(Inst);
Value *Op1;
Value *Op2;
+ MatrixTy Result;
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
- VisitBinaryOperator(BinOp, SI);
+ Result = VisitBinaryOperator(BinOp, SI);
else if (auto *Cast = dyn_cast<CastInst>(Inst))
- VisitCastInstruction(Cast, SI);
+ Result = VisitCastInstruction(Cast, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
- VisitUnaryOperator(UnOp, SI);
- else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
- VisitIntrinsicInst(Intr, SI);
+ Result = VisitUnaryOperator(UnOp, SI);
+ else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
+ Result = VisitIntrinsicInst(Intr, SI);
else if (match(Inst, m_Load(m_Value(Op1))))
- VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
+ Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
+ Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2);
else
continue;
+
+ IRBuilder<> Builder(Inst);
+ finalizeLowering(Inst, Result, Builder);
Changed = true;
}
@@ -1193,25 +1195,24 @@ class LowerMatrixIntrinsics {
}
/// Replace intrinsic calls.
- void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) {
- switch (Inst->getIntrinsicID()) {
+ MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) {
+ assert(Inst->getCalledFunction() &&
+ Inst->getCalledFunction()->isIntrinsic());
+
+ switch (Inst->getCalledFunction()->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
- LowerMultiply(Inst);
- return;
+ return LowerMultiply(Inst);
case Intrinsic::matrix_transpose:
- LowerTranspose(Inst);
- return;
+ return LowerTranspose(Inst);
case Intrinsic::matrix_column_major_load:
- LowerColumnMajorLoad(Inst);
- return;
+ return LowerColumnMajorLoad(Inst);
case Intrinsic::matrix_column_major_store:
- LowerColumnMajorStore(Inst);
- return;
+ return LowerColumnMajorStore(Inst);
case Intrinsic::abs:
case Intrinsic::fabs: {
IRBuilder<> Builder(Inst);
MatrixTy Result;
- MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);
+ MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
Builder.setFastMathFlags(getFastMathFlags(Inst));
for (auto &Vector : M.vectors()) {
@@ -1229,16 +1230,14 @@ class LowerMatrixIntrinsics {
}
}
- finalizeLowering(Inst,
- Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors()),
- Builder);
- return;
+ return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors());
}
default:
- llvm_unreachable(
- "only intrinsics supporting shape info should be seen here");
+ break;
}
+ llvm_unreachable(
+ "only intrinsics supporting shape info should be seen here");
}
/// Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -1304,26 +1303,24 @@ class LowerMatrixIntrinsics {
}
/// Lower a load instruction with shape information.
- void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
- bool IsVolatile, ShapeInfo Shape) {
+ MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
+ Value *Stride, bool IsVolatile, ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
- finalizeLowering(Inst,
- loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
- Shape, Builder),
- Builder);
+ return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
+ Builder);
}
/// Lowers llvm.matrix.column.major.load.
///
/// The intrinsic loads a matrix from memory using a stride between columns.
- void LowerColumnMajorLoad(CallInst *Inst) {
+ MatrixTy LowerColumnMajorLoad(CallInst *Inst) {
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
"Intrinsic only supports column-major layout!");
Value *Ptr = Inst->getArgOperand(0);
Value *Stride = Inst->getArgOperand(1);
- LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
- cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
- {Inst->getArgOperand(3), Inst->getArgOperand(4)});
+ return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
+ cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
+ {Inst->getArgOperand(3), Inst->getArgOperand(4)});
}
/// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
@@ -1366,28 +1363,27 @@ class LowerMatrixIntrinsics {
}
/// Lower a store instruction with shape information.
- void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
- Value *Stride, bool IsVolatile, ShapeInfo Shape) {
+ MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
+ MaybeAlign A, Value *Stride, bool IsVolatile,
+ ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
auto StoreVal = getMatrix(Matrix, Shape, Builder);
- finalizeLowering(Inst,
- storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
- IsVolatile, Builder),
- Builder);
+ return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
+ Builder);
}
/// Lowers llvm.matrix.column.major.store.
///
/// The intrinsic stores a matrix back to memory using a stride between columns.
- void LowerColumnMajorStore(CallInst *Inst) {
+ MatrixTy LowerColumnMajorStore(CallInst *Inst) {
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
"Intrinsic only supports column-major layout!");
Value *Matrix = Inst->getArgOperand(0);
Value *Ptr = Inst->getArgOperand(1);
Value *Stride = Inst->getArgOperand(2);
- LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
- cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
- {Inst->getArgOperand(4), Inst->getArgOperand(5)});
+ return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
+ cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
+ {Inst->getArgOperand(4), Inst->getArgOperand(5)});
}
// Set elements I..I+NumElts-1 to Block
@@ -2162,7 +2158,7 @@ class LowerMatrixIntrinsics {
}
/// Lowers llvm.matrix.multiply.
- void LowerMultiply(CallInst *MatMul) {
+ MatrixTy LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
@@ -2184,11 +2180,11 @@ class LowerMatrixIntrinsics {
emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
getFastMathFlags(MatMul));
- finalizeLowering(MatMul, Result, Builder);
+ return Result;
}
/// Lowers llvm.matrix.transpose.
- void LowerTranspose(CallInst *Inst) {
+ MatrixTy LowerTranspose(CallInst *Inst) {
MatrixTy Result;
IRBuilder<> Builder(Inst);
Value *InputVal = Inst->getArgOperand(0);
@@ -2218,28 +2214,26 @@ class LowerMatrixIntrinsics {
// TODO: Improve estimate of operations needed for transposes. Currently we
// just count the insertelement/extractelement instructions, but do not
// account for later simplifications/combines.
- finalizeLowering(
- Inst,
- Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
- .addNumExposedTransposes(1),
- Builder);
+ return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
+ .addNumExposedTransposes(1);
}
/// Lower load instructions.
- void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
- IRBuilder<> &Builder) {
- LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
- Inst->isVolatile(), SI);
+ MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
+ IRBuilder<> Builder(Inst);
+ return LowerLoad(Inst, Ptr, Inst->getAlign(),
+ Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
}
- void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
- Value *Ptr, IRBuilder<> &Builder) {
- LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
- Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
+ MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
+ Value *Ptr) {
+ IRBuilder<> Builder(Inst);
+ return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+ Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
}
/// Lower binary operators.
- void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
+ MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
Value *Lhs = Inst->getOperand(0);
Value *Rhs = Inst->getOperand(1);
@@ -2258,14 +2252,12 @@ class LowerMatrixIntrinsics {
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
B.getVector(I)));
- finalizeLowering(Inst,
- Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors()),
- Builder);
+ return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors());
}
/// Lower unary operators.
- void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
+ MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
Value *Op = Inst->getOperand(0);
IRBuilder<> Builder(Inst);
@@ -2288,14 +2280,12 @@ class LowerMatrixIntrinsics {
for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(BuildVectorOp(M.getVector(I)));
- finalizeLowering(Inst,
- Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors()),
- Builder);
+ return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors());
}
/// Lower cast instructions.
- void VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
+ MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
Value *Op = Inst->getOperand(0);
IRBuilder<> Builder(Inst);
@@ -2312,10 +2302,8 @@ class LowerMatrixIntrinsics {
for (auto &Vector : M.vectors())
Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy));
- finalizeLowering(Inst,
- Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors()),
- Builder);
+ return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors());
}
/// Helper to linearize a matrix expression tree into a string. Currently
More information about the llvm-commits mailing list