[llvm] 81bdb40 - [Matrix] Simplify matmuls with scalars
Francis Visoiu Mistrih via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 2 15:50:36 PDT 2022
Author: Francis Visoiu Mistrih
Date: 2022-09-02T15:50:25-07:00
New Revision: 81bdb4068d92de50e25e4c214960aa5b9598dbbe
URL: https://github.com/llvm/llvm-project/commit/81bdb4068d92de50e25e4c214960aa5b9598dbbe
DIFF: https://github.com/llvm/llvm-project/commit/81bdb4068d92de50e25e4c214960aa5b9598dbbe.diff
LOG: [Matrix] Simplify matmuls with scalars
If one of the operands of a matmul is a transposed splat, the transpose can be
removed, since transposing a splat is a no-op.
This is useful for simplification when transposes are distributed to the
operands of a matmul:
* k^T -> k
* (A * k)^t -> A^t * k
Differential Revision: https://reviews.llvm.org/D130177
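For illustration, a minimal sketch of the k^T -> k case (hypothetical IR, not
taken from the patch; it mirrors the pattern the new isSplat() check targets):

    %veck  = insertelement <9 x double> poison, double %k, i64 0
    %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
    %t     = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %splat, i32 3, i32 3)

With this change, uses of %t are replaced with %splat and the transpose call is
erased, since transposing a splat yields the same splat.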
Added:
llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
Modified:
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 4c434155d162a..0a46a551beae4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -80,6 +80,9 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
"Use row-major layout")));
+static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
+ cl::init(false));
+
/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -88,6 +91,20 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
return cast<DILocalScope>(Scope)->getSubprogram();
}
+/// Return true if V is a splat of a value (which is used when multiplying a
+/// matrix with a scalar).
+static bool isSplat(Value *V) {
+ if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
+ return SV->isZeroEltSplat();
+ return false;
+}
+
+/// Match any mul operation (fp or integer).
+template <typename LTy, typename RTy>
+auto m_AnyMul(const LTy &L, const RTy &R) {
+ return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
+}
+
namespace {
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
@@ -747,8 +764,8 @@ class LowerMatrixIntrinsics {
Value *TA, *TAMA, *TAMB;
ConstantInt *R, *K, *C;
- if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
-
+ if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) {
// Transpose of a transpose is a nop
Value *TATA;
if (match(TA,
@@ -757,7 +774,11 @@ class LowerMatrixIntrinsics {
EraseFromParent(&I);
EraseFromParent(TA);
}
-
+ // k^T -> k
+ else if (isSplat(TA)) {
+ ReplaceAllUsesWith(I, TA);
+ EraseFromParent(&I);
+ }
// (A * B)^t -> B^t * A^t
// RxK KxC CxK KxR
else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
@@ -773,6 +794,28 @@ class LowerMatrixIntrinsics {
ReplaceAllUsesWith(I, NewInst);
EraseFromParent(&I);
EraseFromParent(TA);
+ // Same as above, but with an elementwise mul, which occurs when the
+ // matrix is multiplied with a scalar.
+ // (A * k)^t -> A^t * k
+ // R x C RxC
+ } else if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
+ (isSplat(TAMA) || isSplat(TAMB))) {
+ IRBuilder<> LocalBuilder(&I);
+ // We know that the transposed operand is of shape RxC.
+ // And when multiplied with a scalar, the shape is preserved.
+ NewInst = distributeTransposes(
+ TAMA, {R, C}, TAMB, {R, C}, Builder,
+ [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+ bool IsFP = I.getType()->isFPOrFPVectorTy();
+ auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
+ : LocalBuilder.CreateMul(T0, T1, "mmul");
+ auto *Result = cast<Instruction>(Mul);
+ setShapeInfo(Result, Shape0);
+ return Result;
+ });
+ ReplaceAllUsesWith(I, NewInst);
+ EraseFromParent(&I);
+ EraseFromParent(TA);
}
}
@@ -848,10 +891,10 @@ class LowerMatrixIntrinsics {
if (!isMinimal()) {
optimizeTransposes();
- LLVM_DEBUG({
+ if (PrintAfterTransposeOpt) {
dbgs() << "Dump after matrix transpose optimization:\n";
Func.dump();
- });
+ }
}
bool Changed = false;
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
new file mode 100644
index 0000000000000..5e741c1c99fd7
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; REQUIRES: aarch64-registered-target
+
+; RUN: opt -passes='lower-matrix-intrinsics' -matrix-print-after-transpose-opt -disable-output %s 2>&1 | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+; k * A^T
+define void @kat(<9 x double>* %Aptr, double %k, <9 x double>* %C) {
+; CHECK-LABEL: @kat(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT: [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT: [[MUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[AT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT: store <9 x double> [[MUL]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT: ret void
+;
+entry:
+ %a = load <9 x double>, <9 x double>* %Aptr
+ %veck = insertelement <9 x double> poison, double %k, i64 0
+ %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+ %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+ %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %at, i32 3, i32 3, i32 3)
+ store <9 x double> %mul, <9 x double>* %C
+ ret void
+}
+
+; (k * A)^T -> A^T * k
+define void @ka_t(<9 x double>* %Aptr, double %k, <9 x double>* %C) {
+; CHECK-LABEL: @ka_t(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT: [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT: [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[A_T]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT: store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT: ret void
+;
+entry:
+ %a = load <9 x double>, <9 x double>* %Aptr
+ %veck = insertelement <9 x double> poison, double %k, i64 0
+ %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+ %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %a, i32 3, i32 3, i32 3)
+ %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
+ store <9 x double> %t, <9 x double>* %C
+ ret void
+}
+
+; (k * A)^T -> A^T * k with fmul
+define void @ka_t_fmul(<9 x double>* %Aptr, double %k, <9 x double>* %C) {
+; CHECK-LABEL: @ka_t_fmul(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT: [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT: [[MMUL:%.*]] = fmul <9 x double> [[SPLAT]], [[A_T]]
+; CHECK-NEXT: store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT: ret void
+;
+entry:
+ %a = load <9 x double>, <9 x double>* %Aptr
+ %veck = insertelement <9 x double> poison, double %k, i64 0
+ %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+ %mul = fmul <9 x double> %splat, %a
+ %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
+ store <9 x double> %t, <9 x double>* %C
+ ret void
+}
+
+; (k * A)^T -> A^T * k with mul (non-fp types)
+define void @ka_t_mul(<9 x i32>* %Aptr, i32 %k, <9 x i32>* %C) {
+; CHECK-LABEL: @ka_t_mul(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A:%.*]] = load <9 x i32>, <9 x i32>* [[APTR:%.*]], align 64
+; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x i32> poison, i32 [[K:%.*]], i64 0
+; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x i32> [[VECK]], <9 x i32> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT: [[A_T:%.*]] = call <9 x i32> @llvm.matrix.transpose.v9i32(<9 x i32> [[A]], i32 3, i32 3)
+; CHECK-NEXT: [[MMUL:%.*]] = mul <9 x i32> [[SPLAT]], [[A_T]]
+; CHECK-NEXT: store <9 x i32> [[MMUL]], <9 x i32>* [[C:%.*]], align 64
+; CHECK-NEXT: ret void
+;
+entry:
+ %a = load <9 x i32>, <9 x i32>* %Aptr
+ %veck = insertelement <9 x i32> poison, i32 %k, i64 0
+ %splat = shufflevector <9 x i32> %veck, <9 x i32> poison, <9 x i32> zeroinitializer
+ %mul = mul <9 x i32> %splat, %a
+ %t = call <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32> %mul, i32 3, i32 3)
+ store <9 x i32> %t, <9 x i32>* %C
+ ret void
+}
+
+declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
+declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
+declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)