[llvm] [Matrix] Use FixedVectorType everywhere in LowerMatrixIntrinsics. NFC (PR #142316)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 1 10:55:17 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Jon Roelofs (jroelofs)

<details>
<summary>Changes</summary>

The matrix intrinsics do not support scalable vectors, so we should be explicit about that by using FixedVectorType throughout, which also avoids casting mistakes.

---
Full diff: https://github.com/llvm/llvm-project/pull/142316.diff


1 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+22-24) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..787e107464c0a 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -383,25 +383,25 @@ class LowerMatrixIntrinsics {
         return Vectors.size();
       else {
         assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
-        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+        return getVectorTy()->getNumElements();
       }
     }
     unsigned getNumRows() const {
       if (isColumnMajor()) {
         assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
-        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+        return getVectorTy()->getNumElements();
       } else
         return Vectors.size();
     }
 
     void addVector(Value *V) { Vectors.push_back(V); }
-    VectorType *getColumnTy() {
+    FixedVectorType *getColumnTy() {
       assert(isColumnMajor() && "only supported for column-major matrixes");
       return getVectorTy();
     }
 
-    VectorType *getVectorTy() const {
-      return cast<VectorType>(Vectors[0]->getType());
+    FixedVectorType *getVectorTy() const {
+      return cast<FixedVectorType>(Vectors[0]->getType());
     }
 
     iterator_range<SmallVector<Value *, 8>::iterator> columns() {
@@ -514,7 +514,7 @@ class LowerMatrixIntrinsics {
       : Func(F), DL(F.getDataLayout()), TTI(TTI), AM(AM) {}
 
   unsigned getNumOps(Type *VT) {
-    assert(isa<VectorType>(VT) && "Expected vector type");
+    assert(isa<FixedVectorType>(VT) && "Expected vector type");
     return getNumOps(VT->getScalarType(),
                      cast<FixedVectorType>(VT)->getNumElements());
   }
@@ -540,10 +540,8 @@ class LowerMatrixIntrinsics {
   /// into vectors.
   MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                      IRBuilder<> &Builder) {
-    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
-    assert(VType && "MatrixVal must be a vector type");
-    assert(cast<FixedVectorType>(VType)->getNumElements() ==
-               SI.NumRows * SI.NumColumns &&
+    FixedVectorType *VType = cast<FixedVectorType>(MatrixVal->getType());
+    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
            "The vector size must match the number of matrix elements");
 
     // Check if we lowered MatrixVal using shape information. In that case,
@@ -563,8 +561,7 @@ class LowerMatrixIntrinsics {
 
     // Otherwise split MatrixVal.
     SmallVector<Value *, 16> SplitVecs;
-    for (unsigned MaskStart = 0;
-         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
+    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
          MaskStart += SI.getStride()) {
       Value *V = Builder.CreateShuffleVector(
           MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
@@ -1157,7 +1154,7 @@ class LowerMatrixIntrinsics {
   /// vectors.
   MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                       bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
-    auto *VType = cast<VectorType>(Ty);
+    auto *VType = cast<FixedVectorType>(Ty);
     Type *EltTy = VType->getElementType();
     Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
     Value *EltPtr = Ptr;
@@ -1239,7 +1236,7 @@ class LowerMatrixIntrinsics {
   MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                        MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                        IRBuilder<> &Builder) {
-    auto VType = cast<VectorType>(Ty);
+    auto *VType = cast<FixedVectorType>(Ty);
     Value *EltPtr = Ptr;
     for (auto Vec : enumerate(StoreVal.vectors())) {
       Value *GEP = computeVectorAddr(
@@ -1377,7 +1374,7 @@ class LowerMatrixIntrinsics {
     Value *LHS = MatMul->getArgOperand(0);
     Value *RHS = MatMul->getArgOperand(1);
 
-    Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
+    Type *ElementType = cast<FixedVectorType>(LHS->getType())->getElementType();
     bool IsIntVec = ElementType->isIntegerTy();
 
     // Floating point reductions require reassocation.
@@ -1475,7 +1472,7 @@ class LowerMatrixIntrinsics {
     int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
     InstructionCost ReductionCost =
         TTI.getArithmeticReductionCost(
-            AddOpCode, cast<VectorType>(LHS->getType()),
+            AddOpCode, cast<FixedVectorType>(LHS->getType()),
             IsIntVec ? std::nullopt : std::optional(FMF)) +
         TTI.getArithmeticInstrCost(MulOpCode, LHS->getType());
     InstructionCost SequentialAddCost =
@@ -1535,8 +1532,8 @@ class LowerMatrixIntrinsics {
       Result = Builder.CreateAddReduce(Mul);
     else {
       Result = Builder.CreateFAddReduce(
-          ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(),
-                          0.0),
+          ConstantFP::get(
+              cast<FixedVectorType>(LHS->getType())->getElementType(), 0.0),
           Mul);
       cast<Instruction>(Result)->setFastMathFlags(FMF);
     }
@@ -1735,7 +1732,7 @@ class LowerMatrixIntrinsics {
     const unsigned R = LShape.NumRows;
     const unsigned C = RShape.NumColumns;
     const unsigned M = LShape.NumColumns;
-    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+    auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
 
     const unsigned VF = std::max<unsigned>(
         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
@@ -1771,7 +1768,7 @@ class LowerMatrixIntrinsics {
 
   void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
                         Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
-    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+    auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
 
     // Create the main tiling loop nest.
     TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
@@ -1842,7 +1839,7 @@ class LowerMatrixIntrinsics {
     const unsigned R = LShape.NumRows;
     const unsigned C = RShape.NumColumns;
     const unsigned M = LShape.NumColumns;
-    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+    auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
 
     Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
     Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
@@ -1914,7 +1911,8 @@ class LowerMatrixIntrinsics {
             ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
             : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
       IRBuilder<> Builder(MatMul);
-      auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+      auto *EltType =
+          cast<FixedVectorType>(MatMul->getType())->getElementType();
       ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
       ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
       const unsigned R = LShape.NumRows;
@@ -2045,7 +2043,7 @@ class LowerMatrixIntrinsics {
   /// Lowers llvm.matrix.multiply.
   void LowerMultiply(CallInst *MatMul) {
     IRBuilder<> Builder(MatMul);
-    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+    auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
     ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
 
@@ -2073,7 +2071,7 @@ class LowerMatrixIntrinsics {
     MatrixTy Result;
     IRBuilder<> Builder(Inst);
     Value *InputVal = Inst->getArgOperand(0);
-    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
+    FixedVectorType *VectorTy = cast<FixedVectorType>(InputVal->getType());
     ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
     MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/142316


More information about the llvm-commits mailing list