[llvm] [Matrix] Propagate shape information through cast insts (PR #141869)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 07:45:20 PDT 2025


================
@@ -2198,6 +2223,57 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  /// Lower cast instructions, if shape information is available.
+  bool VisitCastInstruction(CastInst *Inst) {
+    switch (Inst->getOpcode()) {
+    case llvm::Instruction::Trunc:
+    case llvm::Instruction::ZExt:
+    case llvm::Instruction::SExt:
+    case llvm::Instruction::FPToUI:
+    case llvm::Instruction::FPToSI:
+    case llvm::Instruction::UIToFP:
+    case llvm::Instruction::SIToFP:
+    case llvm::Instruction::FPTrunc:
+    case llvm::Instruction::FPExt:
+      break;
+    case llvm::Instruction::AddrSpaceCast:
+    case CastInst::PtrToInt:
+    case CastInst::IntToPtr:
+    case CastInst::BitCast:
+      return false;
+    case llvm::Instruction::CastOpsEnd:
+      llvm_unreachable("not an actual cast op");
+    }
+
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false;
+
+    Value *Op = Inst->getOperand(0);
+
+    IRBuilder<> Builder(Inst);
+    ShapeInfo &Shape = I->second;
+
+    MatrixTy Result;
+    MatrixTy M = getMatrix(Op, Shape, Builder);
+
+    Builder.setFastMathFlags(getFastMathFlags(Inst));
+
+    for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
+      auto *OrigTy = cast<VectorType>(Inst->getType());
+      auto *NewTy = VectorType::get(OrigTy->getElementType(),
+                                    ElementCount::getFixed(M.getStride()));
----------------
fhahn wrote:

Those look loop-invariant; can we move them out of the loop?

https://github.com/llvm/llvm-project/pull/141869


More information about the llvm-commits mailing list