[llvm] 70caa31 - [Matrix] Refactor shape info computation (NFCI).
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 12 08:03:20 PST 2024
Author: Florian Hahn
Date: 2024-02-12T16:02:59Z
New Revision: 70caa316e955a35904e34961c79d75935b7d906f
URL: https://github.com/llvm/llvm-project/commit/70caa316e955a35904e34961c79d75935b7d906f
DIFF: https://github.com/llvm/llvm-project/commit/70caa316e955a35904e34961c79d75935b7d906f.diff
LOG: [Matrix] Refactor shape info computation (NFCI).
Factor out forward shape computation for a given instruction. This
allows re-use in a follow-up fix.
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index b528762b545659..03e289f7a087ac 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -192,6 +192,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
return VecStart;
}
+namespace {
+/// Shape of a matrix value (rows x columns) together with the layout,
+/// column- or row-major, used when lowering it.
+struct ShapeInfo {
+  unsigned NumRows;
+  unsigned NumColumns;
+
+  /// Derived from MatrixLayout when the shape is constructed.
+  bool IsColumnMajor;
+
+  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
+      : NumRows(NumRows), NumColumns(NumColumns),
+        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+
+  /// Build a shape from dimension operands. Both \p NumRows and
+  /// \p NumColumns must be ConstantInts (cast<> asserts this).
+  ShapeInfo(Value *NumRows, Value *NumColumns)
+      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
+                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}
+
+  /// Shapes compare equal when both dimensions match; IsColumnMajor is not
+  /// compared. Both operators are const so they can be used on const
+  /// ShapeInfo objects.
+  bool operator==(const ShapeInfo &other) const {
+    return NumRows == other.NumRows && NumColumns == other.NumColumns;
+  }
+  bool operator!=(const ShapeInfo &other) const { return !(*this == other); }
+
+  /// Returns true if shape-information is defined, meaning both dimensions
+  /// are != 0.
+  operator bool() const {
+    // Only NumRows is tested; the assert guards the invariant that a shape
+    // with a non-zero row count never has zero columns.
+    assert(NumRows == 0 || NumColumns != 0);
+    return NumRows != 0;
+  }
+
+  /// Returns the leading dimension for the current layout: NumRows for
+  /// column-major, NumColumns for row-major.
+  unsigned getStride() const {
+    if (IsColumnMajor)
+      return NumRows;
+    return NumColumns;
+  }
+
+  /// Returns the number of column vectors (column-major) or row vectors
+  /// (row-major), i.e. NumColumns or NumRows respectively.
+  unsigned getNumVectors() const {
+    if (IsColumnMajor)
+      return NumColumns;
+    return NumRows;
+  }
+
+  /// Returns the transposed shape. Its IsColumnMajor is re-derived from
+  /// MatrixLayout by the constructor rather than copied from *this.
+  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
+};
+} // namespace
+
+/// Returns true if \p V is treated as shape-uniform, i.e. its result shape
+/// can be taken from an operand whose shape is already known. Non-instruction
+/// values count as uniform; among instructions, only the element-wise
+/// floating-point and integer arithmetic opcodes checked below qualify.
+static bool isUniformShape(Value *V) {
+  auto *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return true;
+
+  const unsigned Opc = Inst->getOpcode();
+  // Floating-point operations; FMul here is the scalar (element-wise)
+  // multiply.
+  const bool IsFPOp = Opc == Instruction::FAdd || Opc == Instruction::FSub ||
+                      Opc == Instruction::FMul || Opc == Instruction::FNeg;
+  // Integer operations.
+  const bool IsIntOp = Opc == Instruction::Add || Opc == Instruction::Mul ||
+                       Opc == Instruction::Sub;
+  return IsFPOp || IsIntOp;
+}
+
+/// Return the ShapeInfo for the result of \p I, if it can be determined.
+///
+/// For the matrix intrinsics, the shape is read from their constant
+/// dimension operands. For stores and shape-uniform instructions (see
+/// isUniformShape), it is copied from an operand whose shape is already
+/// recorded in \p ShapeMap. Returns std::nullopt when no shape can be derived.
+static std::optional<ShapeInfo>
+computeShapeInfoForInst(Instruction *I,
+                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
+  Value *M;
+  Value *N;
+  Value *K;
+  // multiply(A, B, M, N, K) yields an M x K result; N is not needed here.
+  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
+                   m_Value(), m_Value(), m_Value(M), m_Value(), m_Value(K))))
+    return ShapeInfo(M, K);
+  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
+                                                        m_Value(N)))) {
+    // Flip dimensions.
+    return ShapeInfo(N, M);
+  }
+  // NOTE(review): the two dimension operands are recorded swapped, as (N, M),
+  // unlike the column-major load below; confirm this is intended.
+  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
+                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
+                   m_Value(N))))
+    return ShapeInfo(N, M);
+  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
+    return ShapeInfo(M, N);
+
+  // A store takes the shape of the value it stores, if that is known. Store
+  // is not one of the uniform opcodes, so there is nothing else to try.
+  Value *StoredVal;
+  if (match(I, m_Store(m_Value(StoredVal), m_Value()))) {
+    auto OpShape = ShapeMap.find(StoredVal);
+    if (OpShape != ShapeMap.end())
+      return OpShape->second;
+    return std::nullopt;
+  }
+
+  if (isUniformShape(I)) {
+    // Find the first operand that has a known shape and use that.
+    for (auto &Op : I->operands()) {
+      auto OpShape = ShapeMap.find(Op.get());
+      if (OpShape != ShapeMap.end())
+        return OpShape->second;
+    }
+  }
+  return std::nullopt;
+}
+
/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
@@ -383,48 +486,6 @@ class LowerMatrixIntrinsics {
}
};
- struct ShapeInfo {
- unsigned NumRows;
- unsigned NumColumns;
-
- bool IsColumnMajor;
-
- ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
- : NumRows(NumRows), NumColumns(NumColumns),
- IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
-
- ShapeInfo(Value *NumRows, Value *NumColumns)
- : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
- cast<ConstantInt>(NumColumns)->getZExtValue()) {}
-
- bool operator==(const ShapeInfo &other) {
- return NumRows == other.NumRows && NumColumns == other.NumColumns;
- }
- bool operator!=(const ShapeInfo &other) { return !(*this == other); }
-
- /// Returns true if shape-information is defined, meaning both dimensions
- /// are != 0.
- operator bool() const {
- assert(NumRows == 0 || NumColumns != 0);
- return NumRows != 0;
- }
-
- unsigned getStride() const {
- if (IsColumnMajor)
- return NumRows;
- return NumColumns;
- }
-
- unsigned getNumVectors() const {
- if (IsColumnMajor)
- return NumColumns;
- return NumRows;
- }
-
- /// Returns the transposed shape.
- ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
- };
-
/// Maps instructions to their shape information. The shape information
/// describes the shape to be used while lowering. This matches the shape of
/// the result value of the instruction, with the only exceptions being store
@@ -554,25 +615,6 @@ class LowerMatrixIntrinsics {
return true;
}
- bool isUniformShape(Value *V) {
- Instruction *I = dyn_cast<Instruction>(V);
- if (!I)
- return true;
-
- switch (I->getOpcode()) {
- case Instruction::FAdd:
- case Instruction::FSub:
- case Instruction::FMul: // Scalar multiply.
- case Instruction::FNeg:
- case Instruction::Add:
- case Instruction::Mul:
- case Instruction::Sub:
- return true;
- default:
- return false;
- }
- }
-
/// Returns true if shape information can be used for \p V. The supported
/// instructions must match the instructions that can be lowered by this pass.
bool supportsShapeInfo(Value *V) {
@@ -610,43 +652,8 @@ class LowerMatrixIntrinsics {
// New entry, set the value and insert operands
bool Propagate = false;
-
- Value *MatrixA;
- Value *MatrixB;
- Value *M;
- Value *N;
- Value *K;
- if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
- m_Value(N), m_Value(K)))) {
- Propagate = setShapeInfo(Inst, {M, K});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(MatrixA), m_Value(M), m_Value(N)))) {
- // Flip dimensions.
- Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
- m_Value(MatrixA), m_Value(), m_Value(),
- m_Value(), m_Value(M), m_Value(N)))) {
- Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
- m_Value(), m_Value(), m_Value(), m_Value(M),
- m_Value(N)))) {
- Propagate = setShapeInfo(Inst, {M, N});
- } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
- auto OpShape = ShapeMap.find(MatrixA);
- if (OpShape != ShapeMap.end())
- setShapeInfo(Inst, OpShape->second);
- continue;
- } else if (isUniformShape(Inst)) {
- // Find the first operand that has a known shape and use that.
- for (auto &Op : Inst->operands()) {
- auto OpShape = ShapeMap.find(Op.get());
- if (OpShape != ShapeMap.end()) {
- Propagate |= setShapeInfo(Inst, OpShape->second);
- break;
- }
- }
- }
+ if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
+ Propagate = setShapeInfo(Inst, *SI);
if (Propagate) {
NewWorkList.push_back(Inst);
More information about the llvm-commits
mailing list