[llvm] [Matrix] Use FixedVectorType everywhere in LowerMatrixIntrinsics. NFC (PR #142316)
Jon Roelofs via llvm-commits
llvm-commits at lists.llvm.org
Sun Jun 1 10:54:44 PDT 2025
https://github.com/jroelofs created https://github.com/llvm/llvm-project/pull/142316
These matrix ops do not support scalable vectors, so we should be really explicit about that and avoid casting mistakes.
From fe2bf1cff2611cfb200605c303f34a3b6ef10720 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Sun, 1 Jun 2025 10:51:20 -0700
Subject: [PATCH] [Matrix] Use FixedVectorType everywhere in the
LowerMatrixIntrinsics pass. NFC
These matrix ops do not support scalable vectors, so we should be really
explicit about that and avoid casting mistakes.
---
.../Scalar/LowerMatrixIntrinsics.cpp | 46 +++++++++----------
1 file changed, 22 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..787e107464c0a 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -383,25 +383,25 @@ class LowerMatrixIntrinsics {
return Vectors.size();
else {
assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
- return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ return getVectorTy()->getNumElements();
}
}
unsigned getNumRows() const {
if (isColumnMajor()) {
assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
- return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ return getVectorTy()->getNumElements();
} else
return Vectors.size();
}
void addVector(Value *V) { Vectors.push_back(V); }
- VectorType *getColumnTy() {
+ FixedVectorType *getColumnTy() {
assert(isColumnMajor() && "only supported for column-major matrixes");
return getVectorTy();
}
- VectorType *getVectorTy() const {
- return cast<VectorType>(Vectors[0]->getType());
+ FixedVectorType *getVectorTy() const {
+ return cast<FixedVectorType>(Vectors[0]->getType());
}
iterator_range<SmallVector<Value *, 8>::iterator> columns() {
@@ -514,7 +514,7 @@ class LowerMatrixIntrinsics {
: Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
unsigned getNumOps(Type *VT) {
- assert(isa<VectorType>(VT) && "Expected vector type");
+ assert(isa<FixedVectorType>(VT) && "Expected vector type");
return getNumOps(VT->getScalarType(),
cast<FixedVectorType>(VT)->getNumElements());
}
@@ -540,10 +540,8 @@ class LowerMatrixIntrinsics {
/// into vectors.
MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
IRBuilder<> &Builder) {
- VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
- assert(VType && "MatrixVal must be a vector type");
- assert(cast<FixedVectorType>(VType)->getNumElements() ==
- SI.NumRows * SI.NumColumns &&
+ FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType());
+ assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
"The vector size must match the number of matrix elements");
// Check if we lowered MatrixVal using shape information. In that case,
@@ -563,8 +561,7 @@ class LowerMatrixIntrinsics {
// Otherwise split MatrixVal.
SmallVector<Value *, 16> SplitVecs;
- for (unsigned MaskStart = 0;
- MaskStart < cast<FixedVectorType>(VType)->getNumElements();
+ for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
MaskStart += SI.getStride()) {
Value *V = Builder.CreateShuffleVector(
MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
@@ -1157,7 +1154,7 @@ class LowerMatrixIntrinsics {
/// vectors.
MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
- auto *VType = cast<VectorType>(Ty);
+ auto *VType = cast<FixedVectorType>(Ty);
Type *EltTy = VType->getElementType();
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
Value *EltPtr = Ptr;
@@ -1239,7 +1236,7 @@ class LowerMatrixIntrinsics {
MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
MaybeAlign MAlign, Value *Stride, bool IsVolatile,
IRBuilder<> &Builder) {
- auto VType = cast<VectorType>(Ty);
+ auto *VType = cast<FixedVectorType>(Ty);
Value *EltPtr = Ptr;
for (auto Vec : enumerate(StoreVal.vectors())) {
Value *GEP = computeVectorAddr(
@@ -1377,7 +1374,7 @@ class LowerMatrixIntrinsics {
Value *LHS = MatMul->getArgOperand(0);
Value *RHS = MatMul->getArgOperand(1);
- Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
+ Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
bool IsIntVec = ElementType->isIntegerTy();
// Floating point reductions require reassocation.
@@ -1475,7 +1472,7 @@ class LowerMatrixIntrinsics {
int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
InstructionCost ReductionCost =
TTI.getArithmeticReductionCost(
- AddOpCode, cast<VectorType>(LHS->getType()),
+ AddOpCode, cast<FixedVectorType>(LHS->getType()),
IsIntVec ? std::nullopt : std::optional(FMF)) +
TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
InstructionCost SequentialAddCost =
@@ -1535,8 +1532,8 @@ class LowerMatrixIntrinsics {
Result = Builder.CreateAddReduce(Mul);
else {
Result = Builder.CreateFAddReduce(
- ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
- 0.0),
+ ConstantFP::get(
+ cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0),
Mul);
cast<Instruction>(Result)->setFastMathFlags(FMF);
}
@@ -1735,7 +1732,7 @@ class LowerMatrixIntrinsics {
const unsigned R = LShape.NumRows;
const unsigned C = RShape.NumColumns;
const unsigned M = LShape.NumColumns;
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
const unsigned VF = std::max<unsigned>(
TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
@@ -1771,7 +1768,7 @@ class LowerMatrixIntrinsics {
void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
// Create the main tiling loop nest.
TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
@@ -1842,7 +1839,7 @@ class LowerMatrixIntrinsics {
const unsigned R = LShape.NumRows;
const unsigned C = RShape.NumColumns;
const unsigned M = LShape.NumColumns;
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
@@ -1914,7 +1911,8 @@ class LowerMatrixIntrinsics {
? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
: match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
IRBuilder<> Builder(MatMul);
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType =
+ cast<FixedVectorType>(MatMul->getType())->getElementType();
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
const unsigned R = LShape.NumRows;
@@ -2045,7 +2043,7 @@ class LowerMatrixIntrinsics {
/// Lowers llvm.matrix.multiply.
void LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
@@ -2073,7 +2071,7 @@ class LowerMatrixIntrinsics {
MatrixTy Result;
IRBuilder<> Builder(Inst);
Value *InputVal = Inst->getArgOperand(0);
- VectorType *VectorTy = cast<VectorType>(InputVal->getType());
+ FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
More information about the llvm-commits mailing list