[llvm] 81bdb40 - [Matrix] Simplify matmuls with scalars

Francis Visoiu Mistrih via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 2 15:50:36 PDT 2022


Author: Francis Visoiu Mistrih
Date: 2022-09-02T15:50:25-07:00
New Revision: 81bdb4068d92de50e25e4c214960aa5b9598dbbe

URL: https://github.com/llvm/llvm-project/commit/81bdb4068d92de50e25e4c214960aa5b9598dbbe
DIFF: https://github.com/llvm/llvm-project/commit/81bdb4068d92de50e25e4c214960aa5b9598dbbe.diff

LOG: [Matrix] Simplify matmuls with scalars

If one of the operands is a transposed splat, the transpose can be
removed.

This is useful to simplify when transposes are distributed to operands
of a matmul:

* k^T -> k
* (A * k)^t -> A^t * k

Differential Revision: https://reviews.llvm.org/D130177

Added: 
    llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 4c434155d162a..0a46a551beae4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -80,6 +80,9 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
                clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                           "Use row-major layout")));
 
+static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
+                                            cl::init(false));
+
 /// Helper function to either return Scope, if it is a subprogram or the
 /// attached subprogram for a local scope.
 static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -88,6 +91,20 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
   return cast<DILocalScope>(Scope)->getSubprogram();
 }
 
+/// Return true if V is a zero-element splat shufflevector (the form a scalar
+/// operand takes when it is multiplied with a matrix); false for anything else.
+static bool isSplat(Value *V) {
+  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
+    return SV->isZeroEltSplat();
+  return false;
+}
+
+/// Match a mul of L and R, either integer (m_Mul) or floating point (m_FMul).
+template <typename LTy, typename RTy>
+auto m_AnyMul(const LTy &L, const RTy &R) {
+  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
+}
+
 namespace {
 
 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
@@ -747,8 +764,8 @@ class LowerMatrixIntrinsics {
 
         Value *TA, *TAMA, *TAMB;
         ConstantInt *R, *K, *C;
-        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
-
+        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
+                          m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) {
           // Transpose of a transpose is a nop
           Value *TATA;
           if (match(TA,
@@ -757,7 +774,11 @@ class LowerMatrixIntrinsics {
             EraseFromParent(&I);
             EraseFromParent(TA);
           }
-
+          // k^T -> k
+          else if (isSplat(TA)) {
+            ReplaceAllUsesWith(I, TA);
+            EraseFromParent(&I);
+          }
           // (A * B)^t -> B^t * A^t
           // RxK KxC      CxK   KxR
           else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
@@ -773,6 +794,28 @@ class LowerMatrixIntrinsics {
             ReplaceAllUsesWith(I, NewInst);
             EraseFromParent(&I);
             EraseFromParent(TA);
+            // Same as above, but with a mul, which occurs when multiplied
+            // with a scalar.
+            // (A * k)^t -> A^t * k
+            //  R  x  C     RxC
+          } else if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
+                     (isSplat(TAMA) || isSplat(TAMB))) {
+            IRBuilder<> LocalBuilder(&I);
+            // We know that the transposed operand is of shape RxC.
+            // And when multiplied with a scalar, the shape is preserved.
+            NewInst = distributeTransposes(
+                TAMA, {R, C}, TAMB, {R, C}, Builder,
+                [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+                  bool IsFP = I.getType()->isFPOrFPVectorTy();
+                  auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
+                                   : LocalBuilder.CreateMul(T0, T1, "mmul");
+                  auto *Result = cast<Instruction>(Mul);
+                  setShapeInfo(Result, Shape0);
+                  return Result;
+                });
+            ReplaceAllUsesWith(I, NewInst);
+            EraseFromParent(&I);
+            EraseFromParent(TA);
           }
         }
 
@@ -848,10 +891,10 @@ class LowerMatrixIntrinsics {
 
     if (!isMinimal()) {
       optimizeTransposes();
-      LLVM_DEBUG({
+      if (PrintAfterTransposeOpt) {
         dbgs() << "Dump after matrix transpose optimization:\n";
         Func.dump();
-      });
+      }
     }
 
     bool Changed = false;

diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
new file mode 100644
index 0000000000000..5e741c1c99fd7
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; REQUIRES: aarch64-registered-target
+
+; RUN: opt -passes='lower-matrix-intrinsics' -matrix-print-after-transpose-opt -disable-output %s 2>&1 | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+; k * A^T
+define void @kat(<9 x double>* %Aptr, double %k, <9 x double>* %C) {
+; CHECK-LABEL: @kat(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT:    [[MUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[AT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[MUL]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %at, i32 3, i32 3, i32 3)
+  store <9 x double> %mul, <9 x double>* %C
+  ret void
+}
+
+; (k * A)^T -> A^T * k
+define void @ka_t(<9 x double>* %Aptr, double %k, <9 x double>* %C) {
+; CHECK-LABEL: @ka_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT:    [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[A_T]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %a, i32 3, i32 3, i32 3)
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %C
+  ret void
+}
+
+; (k * A)^T -> k * A^T with fmul
+define void @ka_t_fmul(<9 x double>* %Aptr, double %k, <9 x double>* %C) {
+; CHECK-LABEL: @ka_t_fmul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT:    [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT:    [[MMUL:%.*]] = fmul <9 x double> [[SPLAT]], [[A_T]]
+; CHECK-NEXT:    store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %mul = fmul <9 x double> %splat, %a
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %C
+  ret void
+}
+
+; (k * A)^T -> k * A^T with mul (non-fp types)
+define void @ka_t_mul(<9 x i32>* %Aptr, i32 %k, <9 x i32>* %C) {
+; CHECK-LABEL: @ka_t_mul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x i32>, <9 x i32>* [[APTR:%.*]], align 64
+; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x i32> poison, i32 [[K:%.*]], i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x i32> [[VECK]], <9 x i32> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT:    [[A_T:%.*]] = call <9 x i32> @llvm.matrix.transpose.v9i32(<9 x i32> [[A]], i32 3, i32 3)
+; CHECK-NEXT:    [[MMUL:%.*]] = mul <9 x i32> [[SPLAT]], [[A_T]]
+; CHECK-NEXT:    store <9 x i32> [[MMUL]], <9 x i32>* [[C:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x i32>, <9 x i32>* %Aptr
+  %veck = insertelement <9 x i32> poison, i32 %k, i64 0
+  %splat = shufflevector <9 x i32> %veck, <9 x i32> poison, <9 x i32> zeroinitializer
+  %mul = mul <9 x i32> %splat, %a
+  %t = call <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32> %mul, i32 3, i32 3)
+  store <9 x i32> %t, <9 x i32>* %C
+  ret void
+}
+
+declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
+declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
+declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)


        


More information about the llvm-commits mailing list