[PATCH] D131125: [Matrix] Add special case dot product lowering

Francis Visoiu Mistrih via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 4 09:41:33 PDT 2022


thegameg added inline comments.


================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1240
+      Reduce = Intrinsic::getDeclaration(
+          Func.getParent(), Intrinsic::vector_reduce_fadd, LHS->getType());
+      Add = Intrinsic::getDeclaration(Func.getParent(), Instruction::FAdd,
----------------
You can probably remove this and use `IRBuilder::CreateAddReduce` / `IRBuilder::CreateFAddReduce` directly below?
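For reference, here is a plain C++ model (not LLVM code) of what those helpers emit. Per the LangRef, `llvm.vector.reduce.add` sums all lanes with wrapping integer addition. `llvm.vector.reduce.fadd` takes a scalar start value and, unless the call has the `reassoc` flag, adds the lanes sequentially. `IRBuilder::CreateAddReduce(Src)` and `IRBuilder::CreateFAddReduce(Acc, Src)` build those calls. The function names below are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Integer add reduction: sums every lane with wrapping arithmetic, like
// llvm.vector.reduce.add on <N x i32>. uint32_t models i32 wraparound
// without signed-overflow UB in C++.
std::uint32_t addReduce(const std::vector<std::uint32_t>& src) {
  std::uint32_t sum = 0;
  for (std::uint32_t lane : src) sum += lane;
  return sum;
}

// Floating-point add reduction without 'reassoc': starts from the
// accumulator and adds the lanes strictly in order, lane 0 first.
float fAddReduceOrdered(float acc, const std::vector<float>& src) {
  for (float lane : src) acc += lane;
  return acc;
}
```

Passing `-0.0f` as the start value leaves the sum unchanged, which is the usual choice when there is no existing accumulator.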


================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1244
+      AddOpCode = Instruction::FAdd;
+      if (!FMF.allowReassoc()) {
+        return;
----------------
Maybe we can do all the FMF and cost checks up front, before creating any instructions?
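The `allowReassoc()` check only reads fast-math flags, so nothing has to be built before it runs. The check is needed because a reassociated `fadd` reduction may sum lanes in a different order than the sequential chain, and floating-point addition is not associative. The following minimal C++ sketch uses hypothetical function names, and the inputs are chosen to expose rounding.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// In-order sum: ((a + b) + c) + d, the order of a sequential fadd chain.
float orderedSum(const std::vector<float>& v) {
  float acc = 0.0f;
  for (float x : v) acc += x;
  return acc;
}

// Pairwise (tree) sum: (a + b) + (c + d), one order that a reassociated
// reduction is allowed to use.
float treeSum(std::vector<float> v) {
  if (v.empty()) return 0.0f;
  while (v.size() > 1) {
    std::vector<float> next;
    for (std::size_t i = 0; i + 1 < v.size(); i += 2)
      next.push_back(v[i] + v[i + 1]);
    if (v.size() % 2 != 0) next.push_back(v.back());  // carry the odd lane
    v = std::move(next);
  }
  return v[0];
}
```

For `{1e30f, 1.0f, -1e30f, 1.0f}` the in-order sum is `1.0f` but the pairwise sum is `0.0f`, because each `1.0f` is absorbed when it is added to ±1e30. With integer lanes, both orders agree since wrapping addition is associative, so only the floating-point path needs the flag.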


================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1272
+    // Since row vectors loads have to be lowered differently, matmul must be
+    // the only user
+    CallInst *LHSBuiltinLoad = getBuiltinLoad(LHS);
----------------
You might want to spell out what "differently" means here (i.e. not using `LowerLoad`, because that would produce scalar loads for the row vector).
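For context on "scalar loads": the removed CHECK lines in the test show one operand lowered column by column, as eight `<1 x i32>` loads at element offsets 0, 8, ..., 56. Below is a small model of that per-column pattern. It assumes a column-major layout in which column `j` starts at element `j * stride`, and the names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One vector load: starting element offset and number of contiguous elements.
struct VectorLoad {
  std::size_t offset;
  std::size_t length;
};

// Hypothetical model of a per-column lowering of a column-major
// rows x cols matrix: column j starts at element j * stride and is loaded
// as one vector of `rows` elements.
std::vector<VectorLoad> perColumnLoads(std::size_t rows, std::size_t cols,
                                       std::size_t stride) {
  std::vector<VectorLoad> loads;
  for (std::size_t j = 0; j < cols; ++j)
    loads.push_back({j * stride, rows});
  return loads;
}
```

For a 1 x 8 row vector with stride 8, `perColumnLoads(1, 8, 8)` yields eight single-element loads. That matches the removed `<1 x i32>` loads rather than one wide load.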


================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1274
+    CallInst *LHSBuiltinLoad = getBuiltinLoad(LHS);
+    if (LHSBuiltinLoad || isa<LoadInst>(LHS)) {
+      if (LHS->hasOneUse())
----------------
To add to what @fhahn said above, you can combine matchers like this:

```
// Matches either a plain load or a call to llvm.matrix.column.major.load.
match(LHS, m_CombineOr(m_Load(m_Value()), m_Intrinsic<Intrinsic::matrix_column_major_load>()))
```
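In case the semantics are unfamiliar: `m_CombineOr(L, R)` succeeds when either sub-pattern matches, and it tries `L` first. Below is a toy C++ model of that composition over a made-up value type. `ToyValue`, `matchLoad`, `matchIntrinsic`, `combineOr`, and the intrinsic ID used in the example are all hypothetical. LLVM's PatternMatch uses templated matcher structs rather than `std::function`.

```cpp
#include <cassert>
#include <functional>

// Toy stand-in for an IR value: only the facts the matchers inspect.
enum class Kind { Load, Call, Add };

struct ToyValue {
  Kind kind;
  int intrinsicId;  // meaningful only when kind == Kind::Call
};

using Matcher = std::function<bool(const ToyValue&)>;

// Matches a plain load.
Matcher matchLoad() {
  return [](const ToyValue& v) { return v.kind == Kind::Load; };
}

// Matches a call to the intrinsic with the given (hypothetical) ID.
Matcher matchIntrinsic(int id) {
  return [id](const ToyValue& v) {
    return v.kind == Kind::Call && v.intrinsicId == id;
  };
}

// Succeeds if either sub-matcher succeeds, trying the left one first.
Matcher combineOr(Matcher lhs, Matcher rhs) {
  return [lhs, rhs](const ToyValue& v) { return lhs(v) || rhs(v); };
}
```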


================
Comment at: llvm/test/Transforms/LowerMatrixIntrinsics/dot-product.ll:165
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <8 x i32>, ptr [[LHS_ADDRESS:%.*]], align 4
-; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <1 x i32>, ptr [[RHS_ADDRESS:%.*]], align 4
-; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 8
-; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <1 x i32>, ptr [[VEC_GEP]], align 4
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 16
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <1 x i32>, ptr [[VEC_GEP3]], align 4
-; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 24
-; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <1 x i32>, ptr [[VEC_GEP5]], align 4
-; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 32
-; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <1 x i32>, ptr [[VEC_GEP7]], align 4
-; CHECK-NEXT:    [[VEC_GEP9:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 40
-; CHECK-NEXT:    [[COL_LOAD10:%.*]] = load <1 x i32>, ptr [[VEC_GEP9]], align 4
-; CHECK-NEXT:    [[VEC_GEP11:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 48
-; CHECK-NEXT:    [[COL_LOAD12:%.*]] = load <1 x i32>, ptr [[VEC_GEP11]], align 4
-; CHECK-NEXT:    [[VEC_GEP13:%.*]] = getelementptr i32, ptr [[RHS_ADDRESS]], i64 56
-; CHECK-NEXT:    [[COL_LOAD14:%.*]] = load <1 x i32>, ptr [[VEC_GEP13]], align 4
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[SPLIT15:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> <i32 1>
-; CHECK-NEXT:    [[SPLIT16:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> <i32 2>
-; CHECK-NEXT:    [[SPLIT17:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> <i32 3>
-; CHECK-NEXT:    [[SPLIT18:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> <i32 4>
-; CHECK-NEXT:    [[SPLIT19:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> <i32 5>
-; CHECK-NEXT:    [[SPLIT20:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> <i32 6>
-; CHECK-NEXT:    [[SPLIT21:%.*]] = shufflevector <8 x i32> [[COL_LOAD]], <8 x i32> poison, <1 x i32> <i32 7>
-; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <1 x i32> [[COL_LOAD1]], <1 x i32> [[COL_LOAD2]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <1 x i32> [[COL_LOAD4]], <1 x i32> [[COL_LOAD6]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <1 x i32> [[COL_LOAD8]], <1 x i32> [[COL_LOAD10]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <1 x i32> [[COL_LOAD12]], <1 x i32> [[COL_LOAD14]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i32> [[TMP0]], <2 x i32> [[TMP1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP3]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x i32> [[TMP4]], <4 x i32> [[TMP5]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[SPLIT22:%.*]] = shufflevector <8 x i32> [[TMP6]], <8 x i32> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT:    [[BLOCK:%.*]] = shufflevector <1 x i32> [[SPLIT]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <8 x i32> [[SPLIT22]], i64 0
-; CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x i32> poison, i32 [[TMP7]], i32 0
-; CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = mul <1 x i32> [[BLOCK]], [[SPLAT_SPLAT]]
-; CHECK-NEXT:    [[BLOCK23:%.*]] = shufflevector <1 x i32> [[SPLIT15]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <8 x i32> [[SPLIT22]], i64 1
-; CHECK-NEXT:    [[SPLAT_SPLATINSERT24:%.*]] = insertelement <1 x i32> poison, i32 [[TMP9]], i32 0
-; CHECK-NEXT:    [[SPLAT_SPLAT25:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT24]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP10:%.*]] = mul <1 x i32> [[BLOCK23]], [[SPLAT_SPLAT25]]
-; CHECK-NEXT:    [[TMP11:%.*]] = add <1 x i32> [[TMP8]], [[TMP10]]
-; CHECK-NEXT:    [[BLOCK26:%.*]] = shufflevector <1 x i32> [[SPLIT16]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <8 x i32> [[SPLIT22]], i64 2
-; CHECK-NEXT:    [[SPLAT_SPLATINSERT27:%.*]] = insertelement <1 x i32> poison, i32 [[TMP12]], i32 0
-; CHECK-NEXT:    [[SPLAT_SPLAT28:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT27]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP13:%.*]] = mul <1 x i32> [[BLOCK26]], [[SPLAT_SPLAT28]]
-; CHECK-NEXT:    [[TMP14:%.*]] = add <1 x i32> [[TMP11]], [[TMP13]]
-; CHECK-NEXT:    [[BLOCK29:%.*]] = shufflevector <1 x i32> [[SPLIT17]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <8 x i32> [[SPLIT22]], i64 3
-; CHECK-NEXT:    [[SPLAT_SPLATINSERT30:%.*]] = insertelement <1 x i32> poison, i32 [[TMP15]], i32 0
-; CHECK-NEXT:    [[SPLAT_SPLAT31:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT30]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP16:%.*]] = mul <1 x i32> [[BLOCK29]], [[SPLAT_SPLAT31]]
-; CHECK-NEXT:    [[TMP17:%.*]] = add <1 x i32> [[TMP14]], [[TMP16]]
-; CHECK-NEXT:    [[BLOCK32:%.*]] = shufflevector <1 x i32> [[SPLIT18]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <8 x i32> [[SPLIT22]], i64 4
-; CHECK-NEXT:    [[SPLAT_SPLATINSERT33:%.*]] = insertelement <1 x i32> poison, i32 [[TMP18]], i32 0
-; CHECK-NEXT:    [[SPLAT_SPLAT34:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT33]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP19:%.*]] = mul <1 x i32> [[BLOCK32]], [[SPLAT_SPLAT34]]
-; CHECK-NEXT:    [[TMP20:%.*]] = add <1 x i32> [[TMP17]], [[TMP19]]
-; CHECK-NEXT:    [[BLOCK35:%.*]] = shufflevector <1 x i32> [[SPLIT19]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP21:%.*]] = extractelement <8 x i32> [[SPLIT22]], i64 5
-; CHECK-NEXT:    [[SPLAT_SPLATINSERT36:%.*]] = insertelement <1 x i32> poison, i32 [[TMP21]], i32 0
-; CHECK-NEXT:    [[SPLAT_SPLAT37:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT36]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP22:%.*]] = mul <1 x i32> [[BLOCK35]], [[SPLAT_SPLAT37]]
-; CHECK-NEXT:    [[TMP23:%.*]] = add <1 x i32> [[TMP20]], [[TMP22]]
-; CHECK-NEXT:    [[BLOCK38:%.*]] = shufflevector <1 x i32> [[SPLIT20]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP24:%.*]] = extractelement <8 x i32> [[SPLIT22]], i64 6
-; CHECK-NEXT:    [[SPLAT_SPLATINSERT39:%.*]] = insertelement <1 x i32> poison, i32 [[TMP24]], i32 0
-; CHECK-NEXT:    [[SPLAT_SPLAT40:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT39]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP25:%.*]] = mul <1 x i32> [[BLOCK38]], [[SPLAT_SPLAT40]]
-; CHECK-NEXT:    [[TMP26:%.*]] = add <1 x i32> [[TMP23]], [[TMP25]]
-; CHECK-NEXT:    [[BLOCK41:%.*]] = shufflevector <1 x i32> [[SPLIT21]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP27:%.*]] = extractelement <8 x i32> [[SPLIT22]], i64 7
-; CHECK-NEXT:    [[SPLAT_SPLATINSERT42:%.*]] = insertelement <1 x i32> poison, i32 [[TMP27]], i32 0
-; CHECK-NEXT:    [[SPLAT_SPLAT43:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT42]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP28:%.*]] = mul <1 x i32> [[BLOCK41]], [[SPLAT_SPLAT43]]
-; CHECK-NEXT:    [[TMP29:%.*]] = add <1 x i32> [[TMP26]], [[TMP28]]
-; CHECK-NEXT:    [[TMP30:%.*]] = shufflevector <1 x i32> [[TMP29]], <1 x i32> poison, <1 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP31:%.*]] = shufflevector <1 x i32> undef, <1 x i32> [[TMP30]], <1 x i32> <i32 1>
-; CHECK-NEXT:    ret <1 x i32> [[TMP31]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load <8 x i32>, ptr [[LHS_ADDRESS:%.*]], align 32
+; CHECK-NEXT:    [[TMP1:%.*]] = load <8 x i32>, ptr [[RHS_ADDRESS:%.*]], align 32
----------------
Very nice!
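The integer case in this test is safe regardless of fast-math flags. The removed CHECK lines fold eight `mul`s into a chain of `add`s. The new output starts from two `<8 x i32>` loads, and the remaining new CHECK lines are not part of this excerpt. The C++ sketch below compares a multiply-add chain with a multiply-then-reduce shape, presumably what the patch builds around `vector_reduce_*`. It uses `uint32_t` to model wrapping i32 arithmetic, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Old shape: one multiply per element, folded into a running sum.
// Assumes a and b have the same nonzero length.
std::uint32_t dotChain(const std::vector<std::uint32_t>& a,
                       const std::vector<std::uint32_t>& b) {
  std::uint32_t acc = a[0] * b[0];
  for (std::size_t i = 1; i < a.size(); ++i) acc += a[i] * b[i];
  return acc;
}

// Special-case shape: elementwise multiply of the two vectors, then one
// horizontal add reduction over the product vector.
std::uint32_t dotMulReduce(const std::vector<std::uint32_t>& a,
                           const std::vector<std::uint32_t>& b) {
  std::vector<std::uint32_t> prod(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) prod[i] = a[i] * b[i];
  std::uint32_t sum = 0;
  for (std::uint32_t lane : prod) sum += lane;
  return sum;
}
```

Both functions agree even when intermediate products wrap, since wrapping addition is associative and commutative.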


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D131125/new/

https://reviews.llvm.org/D131125


