[llvm] ccf2422 - [Matrix] Update shape propagation to iterate until done.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 9 02:53:55 PST 2020
Author: Florian Hahn
Date: 2020-01-09T10:52:52Z
New Revision: ccf24225e3f2356ebf0e73bb114a831bf1721222
URL: https://github.com/llvm/llvm-project/commit/ccf24225e3f2356ebf0e73bb114a831bf1721222
DIFF: https://github.com/llvm/llvm-project/commit/ccf24225e3f2356ebf0e73bb114a831bf1721222.diff
LOG: [Matrix] Update shape propagation to iterate until done.
This patch updates the shape propagation to iterate until no new shape
information is discovered.
As the initial seed for forward propagation, we use the matrix intrinsic
instructions. Both propagateShapeForward and propagateShapeBackward
return new work lists containing the instructions to be used for the next
iteration. When propagating forward, we record every instruction for which
we discovered new shape information. When propagating backward, we record
the users of every instruction for which we discovered new shape information.
Reviewers: anemet, Gerolf, reames, hfinkel, andrew.w.kaylor
Reviewed By: anemet
Differential Revision: https://reviews.llvm.org/D70901
Added:
llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index afe1b4e7cc78..0ff6ee8bcfcc 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -10,9 +10,6 @@
//
// TODO:
// * Implement multiply & add fusion
-// * Implement shape propagation
-// * Implement optimizations to reduce or eliminateshufflevector uses by using
-// shape information.
// * Add remark, summarizing the available matrix optimization opportunities.
//
//===----------------------------------------------------------------------===//
@@ -321,32 +318,12 @@ class LowerMatrixIntrinsics {
}
/// Propagate the shape information of instructions to their users.
- void propagateShapeForward() {
- // The work list contains instructions for which we can compute the shape,
- // either based on the information provided by matrix intrinsics or known
- // shapes of operands.
- SmallVector<Instruction *, 8> WorkList;
-
- // Initialize the work list with ops carrying shape information. Initially
- // only the shape of matrix intrinsics is known.
- for (BasicBlock &BB : Func)
- for (Instruction &Inst : BB) {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
- if (!II)
- continue;
-
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_columnwise_load:
- case Intrinsic::matrix_columnwise_store:
- WorkList.push_back(&Inst);
- break;
- default:
- break;
- }
- }
-
+ /// The work list contains instructions for which we can compute the shape,
+ /// either based on the information provided by matrix intrinsics or known
+ /// shapes of operands.
+ SmallVector<Instruction *, 32>
+ propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
+ SmallVector<Instruction *, 32> NewWorkList;
// Pop an element for which we guaranteed to have at least one of the
// operand shapes. Add the shape for this and then add users to the work
// list.
@@ -395,20 +372,29 @@ class LowerMatrixIntrinsics {
}
}
- if (Propagate)
+ if (Propagate) {
+ NewWorkList.push_back(Inst);
for (auto *User : Inst->users())
if (ShapeMap.count(User) == 0)
WorkList.push_back(cast<Instruction>(User));
+ }
}
+
+ return NewWorkList;
}
/// Propagate the shape to operands of instructions with shape information.
- void propagateShapeBackward() {
- SmallVector<Value *, 8> WorkList;
- // Worklist contains instruction for which we already know the shape.
- for (auto &V : ShapeMap)
- WorkList.push_back(V.first);
-
+ /// \p WorkList contains the instructions for which we already know the shape.
+ SmallVector<Instruction *, 32>
+ propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
+ SmallVector<Instruction *, 32> NewWorkList;
+
+ auto pushInstruction = [](Value *V,
+ SmallVectorImpl<Instruction *> &WorkList) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (I)
+ WorkList.push_back(I);
+ };
// Pop an element with known shape. Traverse the operands, if their shape
// derives from the result shape and is unknown, add it and add them to the
// worklist.
@@ -417,6 +403,7 @@ class LowerMatrixIntrinsics {
Value *V = WorkList.back();
WorkList.pop_back();
+ size_t BeforeProcessingV = WorkList.size();
if (!isa<Instruction>(V))
continue;
@@ -429,21 +416,21 @@ class LowerMatrixIntrinsics {
m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
m_Value(N), m_Value(K)))) {
if (setShapeInfo(MatrixA, {M, N}))
- WorkList.push_back(MatrixA);
+ pushInstruction(MatrixA, WorkList);
if (setShapeInfo(MatrixB, {N, K}))
- WorkList.push_back(MatrixB);
+ pushInstruction(MatrixB, WorkList);
} else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
m_Value(MatrixA), m_Value(M), m_Value(N)))) {
// Flip dimensions.
if (setShapeInfo(MatrixA, {M, N}))
- WorkList.push_back(MatrixA);
+ pushInstruction(MatrixA, WorkList);
} else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
m_Value(MatrixA), m_Value(), m_Value(),
m_Value(M), m_Value(N)))) {
if (setShapeInfo(MatrixA, {M, N})) {
- WorkList.push_back(MatrixA);
+ pushInstruction(MatrixA, WorkList);
}
} else if (isa<LoadInst>(V) ||
match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
@@ -456,16 +443,48 @@ class LowerMatrixIntrinsics {
ShapeInfo Shape = ShapeMap[V];
for (Use &U : cast<Instruction>(V)->operands()) {
if (setShapeInfo(U.get(), Shape))
- WorkList.push_back(U.get());
+ pushInstruction(U.get(), WorkList);
}
}
+ // After we discovered new shape info for new instructions in the
+ // worklist, we use their users as seeds for the next round of forward
+ // propagation.
+ for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
+ for (User *U : WorkList[I]->users())
+ if (isa<Instruction>(U) && V != U)
+ NewWorkList.push_back(cast<Instruction>(U));
}
+ return NewWorkList;
}
bool Visit() {
if (EnableShapePropagation) {
- propagateShapeForward();
- propagateShapeBackward();
+ SmallVector<Instruction *, 32> WorkList;
+
+ // Initially only the shape of matrix intrinsics is known.
+ // Initialize the work list with ops carrying shape information.
+ for (BasicBlock &BB : Func)
+ for (Instruction &Inst : BB) {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
+ if (!II)
+ continue;
+
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ case Intrinsic::matrix_transpose:
+ case Intrinsic::matrix_columnwise_load:
+ case Intrinsic::matrix_columnwise_store:
+ WorkList.push_back(&Inst);
+ break;
+ default:
+ break;
+ }
+ }
+ // Propagate shapes until nothing changes any longer.
+ while (!WorkList.empty()) {
+ WorkList = propagateShapeForward(WorkList);
+ WorkList = propagateShapeBackward(WorkList);
+ }
}
ReversePostOrderTraversal<Function *> RPOT(&Func);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll
new file mode 100644
index 000000000000..38200b3883dc
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll
@@ -0,0 +1,84 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+
+; Make sure we propagate in multiple iterations. First, we back-propagate the
+; shape information from the transpose to %A. In the next iteration, we
+; forward-propagate it to %Mul, and then back-propagate it from %Mul to %B.
+define <16 x double> @backpropagation_iterations(<16 x double>* %A.Ptr, <16 x double>* %B.Ptr) {
+; CHECK-LABEL: @backpropagation_iterations(
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <16 x double>* [[A_PTR:%.*]] to double*
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast double* [[TMP1]] to <4 x double>*
+; CHECK-NEXT: [[TMP3:%.*]] = load <4 x double>, <4 x double>* [[TMP2]], align 8
+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr double, double* [[TMP1]], i32 4
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast double* [[TMP5]] to <4 x double>*
+; CHECK-NEXT: [[TMP7:%.*]] = load <4 x double>, <4 x double>* [[TMP6]], align 8
+; CHECK-NEXT: [[TMP9:%.*]] = getelementptr double, double* [[TMP1]], i32 8
+; CHECK-NEXT: [[TMP10:%.*]] = bitcast double* [[TMP9]] to <4 x double>*
+; CHECK-NEXT: [[TMP11:%.*]] = load <4 x double>, <4 x double>* [[TMP10]], align 8
+; CHECK-NEXT: [[TMP13:%.*]] = getelementptr double, double* [[TMP1]], i32 12
+; CHECK-NEXT: [[TMP14:%.*]] = bitcast double* [[TMP13]] to <4 x double>*
+; CHECK-NEXT: [[TMP15:%.*]] = load <4 x double>, <4 x double>* [[TMP14]], align 8
+; CHECK-NEXT: [[TMP16:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
+; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x double> undef, double [[TMP16]], i64 0
+; CHECK-NEXT: [[TMP18:%.*]] = extractelement <4 x double> [[TMP7]], i64 0
+; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 1
+; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x double> [[TMP11]], i64 0
+; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x double> [[TMP19]], double [[TMP20]], i64 2
+; CHECK-NEXT: [[TMP22:%.*]] = extractelement <4 x double> [[TMP15]], i64 0
+; CHECK-NEXT: [[TMP23:%.*]] = insertelement <4 x double> [[TMP21]], double [[TMP22]], i64 3
+; CHECK-NEXT: [[TMP24:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
+; CHECK-NEXT: [[TMP25:%.*]] = insertelement <4 x double> undef, double [[TMP24]], i64 0
+; CHECK-NEXT: [[TMP26:%.*]] = extractelement <4 x double> [[TMP7]], i64 1
+; CHECK-NEXT: [[TMP27:%.*]] = insertelement <4 x double> [[TMP25]], double [[TMP26]], i64 1
+; CHECK-NEXT: [[TMP28:%.*]] = extractelement <4 x double> [[TMP11]], i64 1
+; CHECK-NEXT: [[TMP29:%.*]] = insertelement <4 x double> [[TMP27]], double [[TMP28]], i64 2
+; CHECK-NEXT: [[TMP30:%.*]] = extractelement <4 x double> [[TMP15]], i64 1
+; CHECK-NEXT: [[TMP31:%.*]] = insertelement <4 x double> [[TMP29]], double [[TMP30]], i64 3
+; CHECK-NEXT: [[TMP32:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
+; CHECK-NEXT: [[TMP33:%.*]] = insertelement <4 x double> undef, double [[TMP32]], i64 0
+; CHECK-NEXT: [[TMP34:%.*]] = extractelement <4 x double> [[TMP7]], i64 2
+; CHECK-NEXT: [[TMP35:%.*]] = insertelement <4 x double> [[TMP33]], double [[TMP34]], i64 1
+; CHECK-NEXT: [[TMP36:%.*]] = extractelement <4 x double> [[TMP11]], i64 2
+; CHECK-NEXT: [[TMP37:%.*]] = insertelement <4 x double> [[TMP35]], double [[TMP36]], i64 2
+; CHECK-NEXT: [[TMP38:%.*]] = extractelement <4 x double> [[TMP15]], i64 2
+; CHECK-NEXT: [[TMP39:%.*]] = insertelement <4 x double> [[TMP37]], double [[TMP38]], i64 3
+; CHECK-NEXT: [[TMP40:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
+; CHECK-NEXT: [[TMP41:%.*]] = insertelement <4 x double> undef, double [[TMP40]], i64 0
+; CHECK-NEXT: [[TMP42:%.*]] = extractelement <4 x double> [[TMP7]], i64 3
+; CHECK-NEXT: [[TMP43:%.*]] = insertelement <4 x double> [[TMP41]], double [[TMP42]], i64 1
+; CHECK-NEXT: [[TMP44:%.*]] = extractelement <4 x double> [[TMP11]], i64 3
+; CHECK-NEXT: [[TMP45:%.*]] = insertelement <4 x double> [[TMP43]], double [[TMP44]], i64 2
+; CHECK-NEXT: [[TMP46:%.*]] = extractelement <4 x double> [[TMP15]], i64 3
+; CHECK-NEXT: [[TMP47:%.*]] = insertelement <4 x double> [[TMP45]], double [[TMP46]], i64 3
+; CHECK-NEXT: [[TMP48:%.*]] = bitcast <16 x double>* [[B_PTR:%.*]] to double*
+; CHECK-NEXT: [[TMP49:%.*]] = bitcast double* [[TMP48]] to <4 x double>*
+; CHECK-NEXT: [[TMP50:%.*]] = load <4 x double>, <4 x double>* [[TMP49]], align 8
+; CHECK-NEXT: [[TMP52:%.*]] = getelementptr double, double* [[TMP48]], i32 4
+; CHECK-NEXT: [[TMP53:%.*]] = bitcast double* [[TMP52]] to <4 x double>*
+; CHECK-NEXT: [[TMP54:%.*]] = load <4 x double>, <4 x double>* [[TMP53]], align 8
+; CHECK-NEXT: [[TMP56:%.*]] = getelementptr double, double* [[TMP48]], i32 8
+; CHECK-NEXT: [[TMP57:%.*]] = bitcast double* [[TMP56]] to <4 x double>*
+; CHECK-NEXT: [[TMP58:%.*]] = load <4 x double>, <4 x double>* [[TMP57]], align 8
+; CHECK-NEXT: [[TMP60:%.*]] = getelementptr double, double* [[TMP48]], i32 12
+; CHECK-NEXT: [[TMP61:%.*]] = bitcast double* [[TMP60]] to <4 x double>*
+; CHECK-NEXT: [[TMP62:%.*]] = load <4 x double>, <4 x double>* [[TMP61]], align 8
+; CHECK-NEXT: [[TMP63:%.*]] = fmul <4 x double> [[TMP3]], [[TMP50]]
+; CHECK-NEXT: [[TMP64:%.*]] = fmul <4 x double> [[TMP7]], [[TMP54]]
+; CHECK-NEXT: [[TMP65:%.*]] = fmul <4 x double> [[TMP11]], [[TMP58]]
+; CHECK-NEXT: [[TMP66:%.*]] = fmul <4 x double> [[TMP15]], [[TMP62]]
+; CHECK-NEXT: [[TMP67:%.*]] = shufflevector <4 x double> [[TMP63]], <4 x double> [[TMP64]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP68:%.*]] = shufflevector <4 x double> [[TMP65]], <4 x double> [[TMP66]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP69:%.*]] = shufflevector <8 x double> [[TMP67]], <8 x double> [[TMP68]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT: ret <16 x double> [[TMP69]]
+;
+ %A = load <16 x double>, <16 x double>* %A.Ptr
+ %A.trans = tail call <16 x double> @llvm.matrix.transpose.v16f64(<16 x double> %A, i32 4, i32 4)
+ %B = load <16 x double>, <16 x double>* %B.Ptr
+ %Mul = fmul <16 x double> %A, %B
+ ret <16 x double> %Mul
+}
+
+declare <16 x double> @llvm.matrix.multiply.v16f64.v16f64.v16f64(<16 x double>, <16 x double>, i32 immarg, i32 immarg, i32 immarg)
+declare <16 x double> @llvm.matrix.transpose.v16f64(<16 x double>, i32 immarg, i32 immarg)
More information about the llvm-commits
mailing list