[llvm] 0fcc99a - [Matrix] Add tests for addition transpose optimizations

Francis Visoiu Mistrih via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 26 13:27:40 PDT 2022


Author: Francis Visoiu Mistrih
Date: 2022-09-26T13:27:03-07:00
New Revision: 0fcc99ade4d6f5661c0d3ea06a77c0a6d74b152c

URL: https://github.com/llvm/llvm-project/commit/0fcc99ade4d6f5661c0d3ea06a77c0a6d74b152c
DIFF: https://github.com/llvm/llvm-project/commit/0fcc99ade4d6f5661c0d3ea06a77c0a6d74b152c.diff

LOG: [Matrix] Add tests for addition transpose optimizations

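Pre-commit tests for the upcoming transpose optimizations around additions; the expected output records the current behavior, before those optimizations land.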

Differential Revision: https://reviews.llvm.org/D133656

Added: 
    

Modified: 
    llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll

Removed: 
    


################################################################################
diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
index 5e741c1c99fd7..4f0b119b3145d 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll
@@ -94,6 +94,172 @@ entry:
   ret void
 }
 
+; A^T + B^T -> (A + B)^T. Not folded yet: the expected output keeps both transposes feeding the fadd.
+define void @at_plus_bt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
+; CHECK-LABEL: @at_plus_bt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT:    [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[AT]], [[BT]]
+; CHECK-NEXT:    store <9 x double> [[FADD]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %fadd = fadd <9 x double> %at, %bt
+  store <9 x double> %fadd, <9 x double>* %C
+  ret void
+}
+
+; (A + B)^T -> A^T + B^T -> (A + B)^T: sinking the transpose into the fadd operands and lifting it back out must round-trip, so the expected output is the unchanged input.
+define void @a_plus_b_t(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
+; CHECK-LABEL: @a_plus_b_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[A]], [[B]]
+; CHECK-NEXT:    [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[T]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %fadd = fadd <9 x double> %a, %b
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %C
+  ret void
+}
+
+; A^T * B^T + C^T * D^T -> (B * A + D * C)^T. For now each product is rewritten to (B * A)^T and (D * C)^T, but the two transposes are not yet merged across the fadd.
+define void @atbt_plus_ctdt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %Cptr, <9 x double>* %Dptr, <9 x double>* %E) {
+; CHECK-LABEL: @atbt_plus_ctdt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, <9 x double>* [[CPTR:%.*]], align 128
+; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, <9 x double>* [[DPTR:%.*]], align 128
+; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[TMP3:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP2]], i32 3, i32 3)
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[TMP3]]
+; CHECK-NEXT:    store <9 x double> [[FADD]], <9 x double>* [[E:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %c = load <9 x double>, <9 x double>* %Cptr
+  %d = load <9 x double>, <9 x double>* %Dptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
+  %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
+  %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
+  %ctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %ct, <9 x double> %dt, i32 3, i32 3, i32 3)
+  %fadd = fadd <9 x double> %atbt, %ctdt
+  store <9 x double> %fadd, <9 x double>* %E
+  ret void
+}
+
+; -(A^T) + B^T, i.e. (B - A)^T. The fneg between the transpose and the fadd is not handled yet; the expected output is the unchanged input.
+define void @negat_plus_bt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
+; CHECK-LABEL: @negat_plus_bt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
+; CHECK-NEXT:    [[NEGAT:%.*]] = fneg <9 x double> [[AT]]
+; CHECK-NEXT:    [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[NEGAT]], [[BT]]
+; CHECK-NEXT:    store <9 x double> [[FADD]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %negat = fneg <9 x double> %at
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %fadd = fadd <9 x double> %negat, %bt
+  store <9 x double> %fadd, <9 x double>* %C
+  ret void
+}
+
+; (A^T * B^T + k * C^T * D^T)^T -> (B * A) + (D * C * k). Here k is the 3x3 all-k splat matrix combined via llvm.matrix.multiply (symmetric, so k^T == k). For now only A^T * B^T is rewritten, to (B * A)^T.
+define void @atbt_plus_kctdt_t(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %Cptr, <9 x double>* %Dptr, double %k, <9 x double>* %E) {
+; CHECK-LABEL: @atbt_plus_kctdt_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[C:%.*]] = load <9 x double>, <9 x double>* [[CPTR:%.*]], align 128
+; CHECK-NEXT:    [[D:%.*]] = load <9 x double>, <9 x double>* [[DPTR:%.*]], align 128
+; CHECK-NEXT:    [[CT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[C]], i32 3, i32 3)
+; CHECK-NEXT:    [[DT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[D]], i32 3, i32 3)
+; CHECK-NEXT:    [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
+; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT:    [[KCT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[CT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[KCTDT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[KCT]], <9 x double> [[DT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[KCTDT]]
+; CHECK-NEXT:    [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[T]], <9 x double>* [[E:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %c = load <9 x double>, <9 x double>* %Cptr
+  %d = load <9 x double>, <9 x double>* %Dptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
+  %dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
+  %atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %kct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %ct, i32 3, i32 3, i32 3)
+  %kctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %kct, <9 x double> %dt, i32 3, i32 3, i32 3)
+  %fadd = fadd <9 x double> %atbt, %kctdt
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %E
+  ret void
+}
+
+; (A^T * (k * B^T))^T -> (B * k) * A, with k the 3x3 all-k splat matrix (symmetric, so k^T == k). This is already fully folded: no transposes remain in the expected output.
+define void @atkbt_t(<9 x double>* %Aptr, <9 x double>* %Bptr, double %k, <9 x double>* %C) {
+; CHECK-LABEL: @atkbt_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
+; CHECK-NEXT:    [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
+; CHECK-NEXT:    [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
+; CHECK-NEXT:    [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[MMUL1]], <9 x double> [[A]], i32 3, i32 3, i32 3)
+; CHECK-NEXT:    store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <9 x double>, <9 x double>* %Aptr
+  %b = load <9 x double>, <9 x double>* %Bptr
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %kbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %bt, i32 3, i32 3, i32 3)
+  %atkbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %kbt, i32 3, i32 3, i32 3)
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %atkbt, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %C
+  ret void
+}
+
 declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
 declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
 declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)


        


More information about the llvm-commits mailing list