[llvm] 7b2ac8f - [Matrix] Pass ShapeInfo to Visit* methods (NFC). (#142487)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 5 11:22:20 PDT 2025
Author: Jon Roelofs
Date: 2025-06-05T11:22:17-07:00
New Revision: 7b2ac8ff54fbc194fd639be3f4073733c1e3d05c
URL: https://github.com/llvm/llvm-project/commit/7b2ac8ff54fbc194fd639be3f4073733c1e3d05c
DIFF: https://github.com/llvm/llvm-project/commit/7b2ac8ff54fbc194fd639be3f4073733c1e3d05c.diff
LOG: [Matrix] Pass ShapeInfo to Visit* methods (NFC). (#142487)
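They all require it now: every visited instruction must have shape info, so the caller looks it up once (via ShapeMap.at) and passes it to the Visit* methods that use it.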
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 38f92561a917d..20279bf69dd59 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1056,19 +1056,20 @@ class LowerMatrixIntrinsics {
IRBuilder<> Builder(Inst);
- if (CallInst *CInst = dyn_cast<CallInst>(Inst))
- Changed |= VisitCallInst(CInst);
+ const ShapeInfo &SI = ShapeMap.at(Inst);
Value *Op1;
Value *Op2;
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
- VisitBinaryOperator(BinOp);
+ VisitBinaryOperator(BinOp, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
- VisitUnaryOperator(UnOp);
+ VisitUnaryOperator(UnOp, SI);
+ else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
+ VisitCallInst(CInst);
else if (match(Inst, m_Load(m_Value(Op1))))
- VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
+ VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
+ VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
else
continue;
Changed = true;
@@ -1109,10 +1110,10 @@ class LowerMatrixIntrinsics {
return Changed;
}
- /// Replace intrinsic calls
- bool VisitCallInst(CallInst *Inst) {
- if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
- return false;
+ /// Replace intrinsic calls.
+ void VisitCallInst(CallInst *Inst) {
+ assert(Inst->getCalledFunction() &&
+ Inst->getCalledFunction()->isIntrinsic());
switch (Inst->getCalledFunction()->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
@@ -1128,9 +1129,9 @@ class LowerMatrixIntrinsics {
LowerColumnMajorStore(Inst);
break;
default:
- return false;
+ llvm_unreachable(
+ "only intrinsics supporting shape info should be seen here");
}
- return true;
}
/// Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -2107,48 +2108,36 @@ class LowerMatrixIntrinsics {
Builder);
}
- /// Lower load instructions, if shape information is available.
- void VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
- LowerLoad(Inst, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
+ /// Lower load instructions.
+ void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
+ IRBuilder<> &Builder) {
+ LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
+ Inst->isVolatile(), SI);
}
- void VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
- IRBuilder<> &Builder) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
+ void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
+ Value *Ptr, IRBuilder<> &Builder) {
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
+ Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
}
- /// Lower binary operators, if shape information is available.
- void VisitBinaryOperator(BinaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
-
+ /// Lower binary operators.
+ void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
Value *Lhs = Inst->getOperand(0);
Value *Rhs = Inst->getOperand(1);
IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
MatrixTy Result;
- MatrixTy A = getMatrix(Lhs, Shape, Builder);
- MatrixTy B = getMatrix(Rhs, Shape, Builder);
+ MatrixTy A = getMatrix(Lhs, SI, Builder);
+ MatrixTy B = getMatrix(Rhs, SI, Builder);
assert(A.isColumnMajor() == B.isColumnMajor() &&
Result.isColumnMajor() == A.isColumnMajor() &&
"operands must agree on matrix layout");
Builder.setFastMathFlags(getFastMathFlags(Inst));
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
B.getVector(I)));
@@ -2158,19 +2147,14 @@ class LowerMatrixIntrinsics {
Builder);
}
- /// Lower unary operators, if shape information is available.
- void VisitUnaryOperator(UnaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
-
+ /// Lower unary operators.
+ void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
Value *Op = Inst->getOperand(0);
IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
MatrixTy Result;
- MatrixTy M = getMatrix(Op, Shape, Builder);
+ MatrixTy M = getMatrix(Op, SI, Builder);
Builder.setFastMathFlags(getFastMathFlags(Inst));
@@ -2184,7 +2168,7 @@ class LowerMatrixIntrinsics {
}
};
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(BuildVectorOp(M.getVector(I)));
finalizeLowering(Inst,
More information about the llvm-commits mailing list