[PATCH] D131125: [Matrix] Add special case dot product lowering
Florian Hahn via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 4 05:59:17 PDT 2022
fhahn added inline comments.
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:885
+ }
+ Changed = !FusedInsts.empty();
+
----------------
I don't think this is necessary and it looks a bit odd to overwrite the earlier `Changed` status. It should be sufficient to remove setting `Changed` on line 877.
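
To illustrate the concern about overwriting, here is a minimal standalone sketch. It is not code from the patch: `runOverwriting`, `runAccumulating`, `EarlierChange` and `Fused` are hypothetical names, and a `std::vector<int>` stands in for `FusedInsts`. It only shows that a plain assignment discards an earlier `true`, whereas `|=` keeps it. The suggestion above is simpler still: drop the extra assignment.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a pass body. EarlierChange models a
// transformation that has already modified the IR; Fused models the set of
// fused instructions.
bool runOverwriting(bool EarlierChange, const std::vector<int> &Fused) {
  bool Changed = EarlierChange;
  Changed = !Fused.empty(); // Plain assignment discards EarlierChange.
  return Changed;
}

bool runAccumulating(bool EarlierChange, const std::vector<int> &Fused) {
  bool Changed = EarlierChange;
  Changed |= !Fused.empty(); // Accumulating keeps EarlierChange.
  return Changed;
}
```

With an earlier change and an empty fused set, the overwriting variant reports no change, while the accumulating variant still reports one.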
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1214
+ void lowerDotProduct(CallInst *MatMul,
+ SmallPtrSet<Instruction *, 16> &FusedInsts,
----------------
Please add a comment describing what this function does.
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1217
+ FastMathFlags FMF) {
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
----------------
I think we should be able to just look up the shape of operands 0 and 1 in ShapeMap instead?
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1220
+
+ if (LShape.NumRows != 1 || RShape.NumColumns != 1) { // not a dot product
+ return;
----------------
nit: coding style recommends omitting {} for single-line blocks.
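
For context on the shape check itself, a small standalone sketch (plain C++, no LLVM APIs; `matMulColumnMajor` and `dotProduct` are illustrative names, not functions from the patch) shows why a 1 x N by N x 1 multiply is just an elementwise multiply followed by an add reduction:

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Generic column-major multiply: A is M x K, B is K x N, the result is M x N.
// Element (R, C) of a column-major matrix with NumRows rows is stored at
// index C * NumRows + R.
std::vector<int> matMulColumnMajor(const std::vector<int> &A,
                                   const std::vector<int> &B, std::size_t M,
                                   std::size_t K, std::size_t N) {
  std::vector<int> Result(M * N, 0);
  for (std::size_t C = 0; C < N; ++C)
    for (std::size_t R = 0; R < M; ++R)
      for (std::size_t I = 0; I < K; ++I)
        Result[C * M + R] += A[I * M + R] * B[C * K + I];
  return Result;
}

// Special case: elementwise multiply followed by an add reduction.
int dotProduct(const std::vector<int> &A, const std::vector<int> &B) {
  return std::inner_product(A.begin(), A.end(), B.begin(), 0);
}
```

For A = {1, 2, 3, 4} treated as 1 x 4 and B = {5, 6, 7, 8} treated as 4 x 1, `matMulColumnMajor(A, B, 1, 4, 1)` yields a single element equal to `dotProduct(A, B)`, which is 70.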
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1229
+ Type *ElementType = cast<VectorType>(LHS->getType())->getElementType();
+ bool integerOperands = ElementType->isIntegerTy();
+ Function *Reduce, *Add;
----------------
nit: Coding style uses UpperCase for variables.
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1249
+
+ // Check that dot product lowering is profitable
+ FastMathFlags FMFReassoc;
----------------
It might be good to explain which costs we are comparing.
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1252
+ FMFReassoc.setAllowReassoc();
+ auto ReductionCost = TTI.getArithmeticReductionCost(
+ AddOpCode, cast<VectorType>(LHS->getType()), FMFReassoc);
----------------
nit: the style guide recommends avoiding `auto` in most cases where the type is not entirely obvious from the context (e.g. auto for dyn_cast is fine).
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1262
+ // lambda which functions as dyn_cast<BuiltinLoad>
+ auto getBuiltinLoad = [](Value *Val) -> CallInst * {
+ CallInst *CI = dyn_cast<CallInst>(Val);
----------------
nit: use UpperCase for variables.
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1264
+ CallInst *CI = dyn_cast<CallInst>(Val);
+ if (CI && CI->getCalledFunction()->getName().startswith(
+ "llvm.matrix.column.major.load")) {
----------------
To check for a specific intrinsic, you could either use something like `match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))` or cast to `IntrinsicInst` and check `getIntrinsicID()`.
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1292
+ if (LHSBuiltinLoad) {
+ LHS =
+ Builder.CreateLoad(LHS->getType(), LHSBuiltinLoad->getArgOperand(0));
----------------
I am not sure if that's correct. Here we need to load a vector with 1 row and N columns. `llvm.matrix.column.major.load` may have a stride > 1 between columns, which is ignored. We should probably limit this to `llvm.matrix.column.major.load` with a stride of 1 for now. Needs a test case.
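
A toy model of the problem, in plain C++ rather than LLVM IR: `stridedColumnMajorLoad` and `contiguousLoad` are hypothetical helpers, and the sketch assumes the stride counts elements between the starts of consecutive columns. For a 1 x N matrix the two loads agree only when the stride is 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model of a strided column-major load of a Rows x Cols matrix.
// Assumption: column C starts at Mem[C * Stride], with Stride >= Rows.
std::vector<int> stridedColumnMajorLoad(const std::vector<int> &Mem,
                                        std::size_t Rows, std::size_t Cols,
                                        std::size_t Stride) {
  std::vector<int> Result;
  for (std::size_t C = 0; C < Cols; ++C)
    for (std::size_t R = 0; R < Rows; ++R)
      Result.push_back(Mem[C * Stride + R]);
  return Result;
}

// Toy model of a plain vector load: the first N contiguous elements.
std::vector<int> contiguousLoad(const std::vector<int> &Mem, std::size_t N) {
  return std::vector<int>(Mem.begin(), Mem.begin() + N);
}
```

With memory {0, 1, ..., 7}, a 1 x 4 load with stride 2 reads {0, 2, 4, 6}, whereas the contiguous load reads {0, 1, 2, 3}. Replacing the former with the latter changes the result.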
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1317
+ // pack scalar back into a matrix and then replace matmul inst
+ Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()),
+ Result, uint64_t(0));
----------------
nit: according to the style guide, comments are meant to be full sentences, with the first letter capitalized and a period at the end. Applies to most comments added in the patch.
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1323
+ // Remove BuiltinLoad if we already generated a LoadInst for it
+ if (LHSBuiltinLoad) {
+ LHSBuiltinLoad->eraseFromParent();
----------------
nit: no need for {}
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1326
+ }
+ if (RHSBuiltinLoad) {
+ RHSBuiltinLoad->eraseFromParent();
----------------
nit: no need for {}
================
Comment at: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp:1330
+
+ return;
+ }
----------------
nit: no need to return at the end of a `void` function.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D131125/new/
https://reviews.llvm.org/D131125