[llvm] [Matrix] Propagate shape information through cast insts (PR #141869)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu May 29 07:45:20 PDT 2025
================
@@ -2198,6 +2223,57 @@ class LowerMatrixIntrinsics {
return true;
}
+ /// Lower cast instructions, if shape information is available.
+ bool VisitCastInstruction(CastInst *Inst) {
+ switch (Inst->getOpcode()) {
+ case llvm::Instruction::Trunc:
+ case llvm::Instruction::ZExt:
+ case llvm::Instruction::SExt:
+ case llvm::Instruction::FPToUI:
+ case llvm::Instruction::FPToSI:
+ case llvm::Instruction::UIToFP:
+ case llvm::Instruction::SIToFP:
+ case llvm::Instruction::FPTrunc:
+ case llvm::Instruction::FPExt:
+ break;
+ case llvm::Instruction::AddrSpaceCast:
+ case CastInst::PtrToInt:
+ case CastInst::IntToPtr:
+ case CastInst::BitCast:
+ return false;
+ case llvm::Instruction::CastOpsEnd:
+ llvm_unreachable("not an actual cast op");
+ }
+
+ auto I = ShapeMap.find(Inst);
+ if (I == ShapeMap.end())
+ return false;
+
+ Value *Op = Inst->getOperand(0);
+
+ IRBuilder<> Builder(Inst);
+ ShapeInfo &Shape = I->second;
+
+ MatrixTy Result;
+ MatrixTy M = getMatrix(Op, Shape, Builder);
+
+ Builder.setFastMathFlags(getFastMathFlags(Inst));
+
+ for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
+ auto *OrigTy = cast<VectorType>(Inst->getType());
+ auto *NewTy = VectorType::get(OrigTy->getElementType(),
+ ElementCount::getFixed(M.getStride()));
----------------
fhahn wrote:
Those look loop invariant, can we move them out?
https://github.com/llvm/llvm-project/pull/141869
More information about the llvm-commits
mailing list