[llvm] 625877b - [Matrix] Add tests dot product with varied strides

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 11 11:09:39 PDT 2022


Author: Vir Narula
Date: 2022-08-11T19:09:21+01:00
New Revision: 625877b0ef709f1d16b3c7a018ef7b8f5149b0cc

URL: https://github.com/llvm/llvm-project/commit/625877b0ef709f1d16b3c7a018ef7b8f5149b0cc
DIFF: https://github.com/llvm/llvm-project/commit/625877b0ef709f1d16b3c7a018ef7b8f5149b0cc.diff

LOG: [Matrix] Add tests dot product with varied strides

Add more tests with varied strides. The corresponding lowering changes are coming in https://reviews.llvm.org/D131125

Reviewed By: fhahn

Differential Revision: https://reviews.llvm.org/D131444

Added: 
    

Modified: 
    llvm/test/Transforms/LowerMatrixIntrinsics/dot-product.ll

Removed: 
    


################################################################################
diff  --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product.ll
index 1edce450acff0..110802769c102 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product.ll
@@ -112,15 +112,15 @@ define <1 x float> @intrinsic_column_major_load_dot_product_float_v6(ptr %lhs_ad
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <6 x float>, ptr [[LHS_ADDRESS:%.*]], align 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <1 x float>, ptr [[RHS_ADDRESS:%.*]], align 4
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 6
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 1
 ; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <1 x float>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 12
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <1 x float>, ptr [[VEC_GEP3]], align 4
-; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 18
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 3
 ; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <1 x float>, ptr [[VEC_GEP5]], align 4
-; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 24
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <1 x float>, ptr [[VEC_GEP7]], align 4
-; CHECK-NEXT:    [[VEC_GEP9:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 30
+; CHECK-NEXT:    [[VEC_GEP9:%.*]] = getelementptr float, ptr [[RHS_ADDRESS]], i64 5
 ; CHECK-NEXT:    [[COL_LOAD10:%.*]] = load <1 x float>, ptr [[VEC_GEP9]], align 4
 ; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <6 x float> [[COL_LOAD]], <6 x float> poison, <1 x i32> zeroinitializer
 ; CHECK-NEXT:    [[SPLIT11:%.*]] = shufflevector <6 x float> [[COL_LOAD]], <6 x float> poison, <1 x i32> <i32 1>
@@ -171,7 +171,7 @@ define <1 x float> @intrinsic_column_major_load_dot_product_float_v6(ptr %lhs_ad
 ;
 entry:
   %lhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %lhs_address, i64 6, i1 false, i32 6, i32 1)
-  %rhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %rhs_address, i64 6, i1 false, i32 1, i32 6)
+  %rhs = tail call fast <6 x float> @llvm.matrix.column.major.load.v6f32.i64(ptr nonnull align 4 %rhs_address, i64 1, i1 false, i32 1, i32 6)
   %result = tail call fast <1 x float> @llvm.matrix.multiply.v1f32.v6f32.v6f32(<6 x float> %lhs, <6 x float> %rhs, i32 1, i32 6, i32 1)
   ret <1 x float> %result
 }
@@ -299,15 +299,15 @@ define <1 x double> @intrinsic_column_major_load_dot_product_double_v6(ptr %lhs_
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <6 x double>, ptr [[LHS_ADDRESS:%.*]], align 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <1 x double>, ptr [[RHS_ADDRESS:%.*]], align 4
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 6
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 1
 ; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <1 x double>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 12
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <1 x double>, ptr [[VEC_GEP3]], align 4
-; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 18
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 3
 ; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <1 x double>, ptr [[VEC_GEP5]], align 4
-; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 24
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <1 x double>, ptr [[VEC_GEP7]], align 4
-; CHECK-NEXT:    [[VEC_GEP9:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 30
+; CHECK-NEXT:    [[VEC_GEP9:%.*]] = getelementptr double, ptr [[RHS_ADDRESS]], i64 5
 ; CHECK-NEXT:    [[COL_LOAD10:%.*]] = load <1 x double>, ptr [[VEC_GEP9]], align 4
 ; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <6 x double> [[COL_LOAD]], <6 x double> poison, <1 x i32> zeroinitializer
 ; CHECK-NEXT:    [[SPLIT11:%.*]] = shufflevector <6 x double> [[COL_LOAD]], <6 x double> poison, <1 x i32> <i32 1>
@@ -358,7 +358,7 @@ define <1 x double> @intrinsic_column_major_load_dot_product_double_v6(ptr %lhs_
 ;
 entry:
   %lhs = tail call fast <6 x double> @llvm.matrix.column.major.load.v6f64.i64(ptr nonnull align 4 %lhs_address, i64 6, i1 false, i32 6, i32 1)
-  %rhs = tail call fast <6 x double> @llvm.matrix.column.major.load.v6f64.i64(ptr nonnull align 4 %rhs_address, i64 6, i1 false, i32 1, i32 6)
+  %rhs = tail call fast <6 x double> @llvm.matrix.column.major.load.v6f64.i64(ptr nonnull align 4 %rhs_address, i64 1, i1 false, i32 1, i32 6)
   %result = tail call fast <1 x double> @llvm.matrix.multiply.v1f64.v6f64.v6f64(<6 x double> %lhs, <6 x double> %rhs, i32 1, i32 6, i32 1)
   ret <1 x double> %result
 }
@@ -505,19 +505,19 @@ define <1 x i32> @intrinsic_column_major_load_dot_product_i32_v8(ptr %lhs_addres
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <8 x i32>, ptr [[LHS_ADDRESS:%.*]], align 4
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <1 x i32>, ptr [[RHS_ADDRESS:%.*]], align 4
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 1
 ; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <1 x i32>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <1 x i32>, ptr [[VEC_GEP3]], align 4
-; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 24
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 3
 ; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <1 x i32>, ptr [[VEC_GEP5]], align 4
-; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 32
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 4
 ; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <1 x i32>, ptr [[VEC_GEP7]], align 4
-; CHECK-NEXT:    [[VEC_GEP9:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 40
+; CHECK-NEXT:    [[VEC_GEP9:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 5
 ; CHECK-NEXT:    [[COL_LOAD10:%.*]] = load <1 x i32>, ptr [[VEC_GEP9]], align 4
-; CHECK-NEXT:    [[VEC_GEP11:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 48
+; CHECK-NEXT:    [[VEC_GEP11:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 6
 ; CHECK-NEXT:    [[COL_LOAD12:%.*]] = load <1 x i32>, ptr [[VEC_GEP11]], align 4
-; CHECK-NEXT:    [[VEC_GEP13:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 56
+; CHECK-NEXT:    [[VEC_GEP13:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 7
 ; CHECK-NEXT:    [[COL_LOAD14:%.*]] = load <1 x i32>, ptr [[VEC_GEP13]], align 4
 ; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> zeroinitializer
 ; CHECK-NEXT:    [[SPLIT15:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> <i32 1>
@@ -588,7 +588,7 @@ define <1 x i32> @intrinsic_column_major_load_dot_product_i32_v8(ptr %lhs_addres
 ;
 entry:
   %lhs = tail call <8 x i32> @llvm.matrix.column.major.load.v8i32.i64(ptr nonnull align 4 %lhs_address, i64 8, i1 false, i32 8, i32 1)
-  %rhs = tail call <8 x i32> @llvm.matrix.column.major.load.v8i32.i64(ptr nonnull align 4 %rhs_address, i64 8, i1 false, i32 1, i32 8)
+  %rhs = tail call <8 x i32> @llvm.matrix.column.major.load.v8i32.i64(ptr nonnull align 4 %rhs_address, i64 1, i1 false, i32 1, i32 8)
   %result = tail call <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32> %lhs, <8 x i32> %rhs, i32 1, i32 8, i32 1)
   ret <1 x i32> %result
 }
@@ -596,6 +596,64 @@ entry:
 declare <8 x i32> @llvm.matrix.column.major.load.v8i32.i64(ptr nonnull align 4, i64, i1, i32, i32)
 declare <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32>, <8 x i32>, i32, i32, i32)
 
+; This tests for a case where stride > rows in the load (stride 4, 1 row), so the RHS elements are not
+; contiguous in memory (offsets 0, 4, 8, 12) and we shouldn't use special lowering.
+define <1 x i32> @intrinsic_column_major_load_dot_product_i32_v4_strided(ptr %lhs_address, ptr %rhs_address) {
+; CHECK-LABEL: @intrinsic_column_major_load_dot_product_i32_v4_strided(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <4 x i32>, ptr [[LHS_ADDRESS:%.*]], align 4
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <1 x i32>, ptr [[RHS_ADDRESS:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 4
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <1 x i32>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 8
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <1 x i32>, ptr [[VEC_GEP3]], align 4
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 12
+; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <1 x i32>, ptr [[VEC_GEP5]], align 4
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i32> [[COL_LOAD]], <4 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <4 x i32> [[COL_LOAD]], <4 x i32> poison, <1 x i32> <i32 1>
+; CHECK-NEXT:    [[SPLIT8:%.*]] = shufflevector <4 x i32> [[COL_LOAD]], <4 x i32> poison, <1 x i32> <i32 2>
+; CHECK-NEXT:    [[SPLIT9:%.*]] = shufflevector <4 x i32> [[COL_LOAD]], <4 x i32> poison, <1 x i32> <i32 3>
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <1 x i32> [[COL_LOAD1]], <1 x i32> [[COL_LOAD2]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <1 x i32> [[COL_LOAD4]], <1 x i32> [[COL_LOAD6]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x i32> [[TMP0]], <2 x i32> [[TMP1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT10:%.*]] = shufflevector <4 x i32> [[TMP2]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[BLOCK:%.*]] = shufflevector <1 x i32> [[SPLIT]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i32> [[SPLIT10]], i64 0
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x i32> poison, i32 [[TMP3]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = mul <1 x i32> [[BLOCK]], [[SPLAT_SPLAT]]
+; CHECK-NEXT:    [[BLOCK11:%.*]] = shufflevector <1 x i32> [[SPLIT7]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x i32> [[SPLIT10]], i64 1
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x i32> poison, i32 [[TMP5]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT12]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = mul <1 x i32> [[BLOCK11]], [[SPLAT_SPLAT13]]
+; CHECK-NEXT:    [[TMP7:%.*]] = add <1 x i32> [[TMP4]], [[TMP6]]
+; CHECK-NEXT:    [[BLOCK14:%.*]] = shufflevector <1 x i32> [[SPLIT8]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <4 x i32> [[SPLIT10]], i64 2
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x i32> poison, i32 [[TMP8]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT15]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP9:%.*]] = mul <1 x i32> [[BLOCK14]], [[SPLAT_SPLAT16]]
+; CHECK-NEXT:    [[TMP10:%.*]] = add <1 x i32> [[TMP7]], [[TMP9]]
+; CHECK-NEXT:    [[BLOCK17:%.*]] = shufflevector <1 x i32> [[SPLIT9]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <4 x i32> [[SPLIT10]], i64 3
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x i32> poison, i32 [[TMP11]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT18]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = mul <1 x i32> [[BLOCK17]], [[SPLAT_SPLAT19]]
+; CHECK-NEXT:    [[TMP13:%.*]] = add <1 x i32> [[TMP10]], [[TMP12]]
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <1 x i32> [[TMP13]], <1 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <1 x i32> undef, <1 x i32> [[TMP14]], <1 x i32> <i32 1>
+; CHECK-NEXT:    ret <1 x i32> [[TMP15]]
+;
+entry:
+  %lhs = tail call <4 x i32> @llvm.matrix.column.major.load.v4i32.i64(ptr nonnull align 4 %lhs_address, i64 4, i1 false, i32 4, i32 1)
+  %rhs = tail call <4 x i32> @llvm.matrix.column.major.load.v4i32.i64(ptr nonnull align 4 %rhs_address, i64 4, i1 false, i32 1, i32 4)
+  %result = tail call <1 x i32> @llvm.matrix.multiply.v1i32.v4i32.v4i32(<4 x i32> %lhs, <4 x i32> %rhs, i32 1, i32 4, i32 1)
+  ret <1 x i32> %result
+}
+
+declare <4 x i32> @llvm.matrix.column.major.load.v4i32.i64(ptr nonnull align 4, i64, i1, i32, i32)
+declare <1 x i32> @llvm.matrix.multiply.v1i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32, i32, i32)
+
 define <1 x i16> @LoadInst_dot_product_i16_v6(ptr %lhs_address, ptr %rhs_address) {
 ; CHECK-LABEL: @LoadInst_dot_product_i16_v6(
 ; CHECK-NEXT:  entry:


        


More information about the llvm-commits mailing list