[llvm] bfd3883 - [Matrix] Refactor transpose distribution. NFC
Francis Visoiu Mistrih via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 28 17:30:16 PDT 2022
Author: Francis Visoiu Mistrih
Date: 2022-07-28T17:30:00-07:00
New Revision: bfd3883e83dd61c68bd920d933f8ca679b788ad9
URL: https://github.com/llvm/llvm-project/commit/bfd3883e83dd61c68bd920d933f8ca679b788ad9
DIFF: https://github.com/llvm/llvm-project/commit/bfd3883e83dd61c68bd920d933f8ca679b788ad9.diff
LOG: [Matrix] Refactor transpose distribution. NFC
Use a function to distribute transposes. Preparation for future patches.
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index f1e1359255bdd..c4281b6060fe1 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -384,6 +384,9 @@ class LowerMatrixIntrinsics {
return NumColumns;
return NumRows;
}
+
+ /// Returns the transposed shape, i.e. with NumRows and NumColumns swapped.
+ ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
/// Maps instructions to their shape information. The shape information
@@ -684,6 +687,25 @@ class LowerMatrixIntrinsics {
return NewWorkList;
}
+ /// Transpose \p Op0 (shape \p Shape0) and \p Op1 (shape \p Shape1), record
+ /// the transposed shapes, and pass both to \p Operation. Callers choose the
+ /// order: for (A * B)^T -> B^T * A^T, pass B as \p Op0 and A as \p Op1.
+ Instruction *distributeTransposes(
+ Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
+ MatrixBuilder &Builder,
+ function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
+ Operation) {
+ Value *T0 = Builder.CreateMatrixTranspose(
+ Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
+ // We are being run after shape prop, add shape for newly created
+ // instructions so that we lower them later.
+ setShapeInfo(T0, Shape0.t());
+ Value *T1 = Builder.CreateMatrixTranspose(
+ Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
+ setShapeInfo(T1, Shape1.t());
+ return Operation(T0, Shape0.t(), T1, Shape1.t());
+ }
+
/// Try moving transposes in order to fold them away or into multiplies.
void optimizeTransposes() {
auto ReplaceAllUsesWith = [this](Instruction &Old, Value *New) {
@@ -741,19 +763,13 @@ class LowerMatrixIntrinsics {
else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
m_ConstantInt(K), m_ConstantInt(C)))) {
- Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
- C->getZExtValue(),
- TAMB->getName() + "_t");
- // We are being run after shape prop, add shape for newly created
- // instructions so that we lower them later.
- setShapeInfo(T0, {C, K});
- Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
- K->getZExtValue(),
- TAMA->getName() + "_t");
- setShapeInfo(T1, {K, R});
- NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
- K->getZExtValue(),
- R->getZExtValue(), "mmul");
+ NewInst = distributeTransposes(
+ TAMB, {K, C}, TAMA, {R, K}, Builder,
+ [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+ return Builder.CreateMatrixMultiply(
+ T0, T1, Shape0.NumRows, Shape0.NumColumns,
+ Shape1.NumColumns, "mmul");
+ });
ReplaceAllUsesWith(I, NewInst);
EraseFromParent(&I);
EraseFromParent(TA);
More information about the llvm-commits
mailing list