[llvm-branch-commits] [llvm] 0cc38ac - [Matrix] Propagate shape information through fneg
Francis Visoiu Mistrih via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 22 14:39:46 PST 2021
Author: Francis Visoiu Mistrih
Date: 2021-01-22T14:34:28-08:00
New Revision: 0cc38acfc4e1dcdc2a9b6287bc93eef57acfd105
URL: https://github.com/llvm/llvm-project/commit/0cc38acfc4e1dcdc2a9b6287bc93eef57acfd105
DIFF: https://github.com/llvm/llvm-project/commit/0cc38acfc4e1dcdc2a9b6287bc93eef57acfd105.diff
LOG: [Matrix] Propagate shape information through fneg
Similar to binary operators like fadd/fmul/fsub, propagate shape information
through unary operators. FNeg is currently the only unary operator, so it is
the only opcode handled.
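As a minimal illustrative sketch (not part of this patch; the function name
is made up, and the unmangled transpose declaration simply mirrors the one in
the tests below), a fneg feeding the transpose intrinsic now inherits the 2x4
shape of its user, so the lowering negates the four <2 x double> column
vectors instead of treating the flat <8 x double> as opaque:

; Illustrative example only, not one of the patch's tests.
define <8 x double> @negate_then_transpose(<8 x double> %a) {
entry:
  %neg = fneg <8 x double> %a
  %t = call <8 x double> @llvm.matrix.transpose(<8 x double> %neg, i32 2, i32 4)
  ret <8 x double> %t
}

declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)

Before this change, the fneg had no entry in the ShapeMap and acted as a
barrier to shape propagation between the load/transpose and its user.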
Differential Revision: https://reviews.llvm.org/D95252
Added:
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index aa9894ca32b3..812922c49cfa 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -488,6 +488,7 @@ class LowerMatrixIntrinsics {
case Instruction::FAdd:
case Instruction::FSub:
case Instruction::FMul: // Scalar multiply.
+ case Instruction::FNeg:
case Instruction::Add:
case Instruction::Mul:
case Instruction::Sub:
@@ -724,6 +725,8 @@ class LowerMatrixIntrinsics {
Value *Op2;
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
Changed |= VisitBinaryOperator(BinOp);
+ if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+ Changed |= VisitUnaryOperator(UnOp);
if (match(Inst, m_Load(m_Value(Op1))))
Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -1499,6 +1502,40 @@ class LowerMatrixIntrinsics {
return true;
}
+ /// Lower unary operators, if shape information is available.
+ bool VisitUnaryOperator(UnaryOperator *Inst) {
+ auto I = ShapeMap.find(Inst);
+ if (I == ShapeMap.end())
+ return false;
+
+ Value *Op = Inst->getOperand(0);
+
+ IRBuilder<> Builder(Inst);
+ ShapeInfo &Shape = I->second;
+
+ MatrixTy Result;
+ MatrixTy M = getMatrix(Op, Shape, Builder);
+
+ // Helper to perform unary op on vectors.
+ auto BuildVectorOp = [&Builder, Inst](Value *Op) {
+ switch (Inst->getOpcode()) {
+ case Instruction::FNeg:
+ return Builder.CreateFNeg(Op);
+ default:
+ llvm_unreachable("Unsupported unary operator for matrix");
+ }
+ };
+
+ for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ Result.addVector(BuildVectorOp(M.getVector(I)));
+
+ finalizeLowering(Inst,
+ Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors()),
+ Builder);
+ return true;
+ }
+
/// Helper to linearize a matrix expression tree into a string. Currently
/// matrix expressions are linearized by starting at an expression leaf and
/// linearizing bottom up.
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
index 584f47d8530b..0860d7adb608 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-backward.ll
@@ -93,4 +93,48 @@ entry:
ret <8 x double> %c
}
+define <8 x double> @load_fneg_transpose(<8 x double>* %A.Ptr) {
+; CHECK-LABEL: @load_fneg_transpose(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x double>* [[A_PTR:%.*]] to double*
+; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast double* [[TMP0]] to <2 x double>*
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr double, double* [[TMP0]], i64 2
+; CHECK-NEXT: [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr double, double* [[TMP0]], i64 4
+; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>*
+; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
+; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr double, double* [[TMP0]], i64 6
+; CHECK-NEXT: [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
+; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = fneg <2 x double> [[COL_LOAD]]
+; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x double> [[COL_LOAD2]]
+; CHECK-NEXT: [[TMP3:%.*]] = fneg <2 x double> [[COL_LOAD5]]
+; CHECK-NEXT: [[TMP4:%.*]] = fneg <2 x double> [[COL_LOAD8]]
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x double> [[TMP1]], i64 0
+; CHECK-NEXT: [[TMP6:%.*]] = insertelement <4 x double> undef, double [[TMP5]], i64 0
+; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x double> [[TMP2]], i64 0
+; CHECK-NEXT: [[TMP8:%.*]] = insertelement <4 x double> [[TMP6]], double [[TMP7]], i64 1
+; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x double> [[TMP3]], i64 0
+; CHECK-NEXT: [[TMP10:%.*]] = insertelement <4 x double> [[TMP8]], double [[TMP9]], i64 2
+; CHECK-NEXT: [[TMP11:%.*]] = extractelement <2 x double> [[TMP4]], i64 0
+; CHECK-NEXT: [[TMP12:%.*]] = insertelement <4 x double> [[TMP10]], double [[TMP11]], i64 3
+; CHECK-NEXT: [[TMP13:%.*]] = extractelement <2 x double> [[TMP1]], i64 1
+; CHECK-NEXT: [[TMP14:%.*]] = insertelement <4 x double> undef, double [[TMP13]], i64 0
+; CHECK-NEXT: [[TMP15:%.*]] = extractelement <2 x double> [[TMP2]], i64 1
+; CHECK-NEXT: [[TMP16:%.*]] = insertelement <4 x double> [[TMP14]], double [[TMP15]], i64 1
+; CHECK-NEXT: [[TMP17:%.*]] = extractelement <2 x double> [[TMP3]], i64 1
+; CHECK-NEXT: [[TMP18:%.*]] = insertelement <4 x double> [[TMP16]], double [[TMP17]], i64 2
+; CHECK-NEXT: [[TMP19:%.*]] = extractelement <2 x double> [[TMP4]], i64 1
+; CHECK-NEXT: [[TMP20:%.*]] = insertelement <4 x double> [[TMP18]], double [[TMP19]], i64 3
+; CHECK-NEXT: [[TMP21:%.*]] = shufflevector <4 x double> [[TMP12]], <4 x double> [[TMP20]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: ret <8 x double> [[TMP21]]
+;
+entry:
+ %a = load <8 x double>, <8 x double>* %A.Ptr, align 8
+ %neg = fneg <8 x double> %a
+ %c = call <8 x double> @llvm.matrix.transpose(<8 x double> %neg, i32 2, i32 4)
+ ret <8 x double> %c
+}
declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
index 51796f3d1b60..6095957a6559 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
@@ -114,3 +114,37 @@ entry:
%res = fmul <8 x double> %c, %a
ret <8 x double> %res
}
+
+define <8 x double> @transpose_fneg(<8 x double> %a) {
+; CHECK-LABEL: @transpose_fneg(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> poison, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
+; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
+; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
+; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
+; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
+; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
+; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
+; CHECK-NEXT: [[TMP16:%.*]] = fneg <4 x double> [[TMP7]]
+; CHECK-NEXT: [[TMP17:%.*]] = fneg <4 x double> [[TMP15]]
+; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: ret <8 x double> [[TMP18]]
+;
+entry:
+ %c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
+ %res = fneg <8 x double> %c
+ ret <8 x double> %res
+}