[PATCH] D80181: [mlir][spirv] Add remaining cooperative matrix instructions.
Lei Zhang via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Thu May 21 05:54:21 PDT 2020
antiagainst accepted this revision.
antiagainst marked 2 inline comments as done.
antiagainst added a comment.
This revision is now accepted and ready to land.
Cool! Sorry, just a few more nits.
================
Comment at: mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td:56
+ );
+ let verifier = [{ return success(); }];
+}
----------------
ThomasRaoux wrote:
> antiagainst wrote:
> > I think we can add
> >
> > ```
> > let assemblyFormat = "attr-dict `:` type($result)";
> > ```
> >
> > So that we don't need to manually write the parser and printer?
> Here type is not the type of $result so this syntax wouldn't work. Not sure if there is a way to get it to pick up the type for the argument?
Ah, you are right. The op actually takes in the id for a type directly. Sorry, I missed that.
================
Comment at: mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td:176
+ sssa-use `,` ssa-use `,` ssa-use ` : `
+ cooperative-matrix-type
+ ```
----------------
ThomasRaoux wrote:
> antiagainst wrote:
> > One general rule in MLIR regarding assembly is that it should be parsable on its own. I think we need to give at least two types (for `$a` and `$b`) for this? Otherwise by only looking at this op, I'm not able to tell what the type is for `$a` and `$b`.
> My bad, I had wrongly assumed the types were matching. I added 2 types and a result type even though the result type could be deduced from the operand types.
> I wasn't able to get the assemblyFormat to work as I couldn't find a way to make the type of c the same as the type of result. (type($c, $result) doesn't work.) I haven't looked very deeply though; I can investigate more if you want. For now I left the custom parsing/printing functions.
>
> I also added a verify function for muladd and some tests along with it. I was planning to add those later to keep the patch small but it is better to add it now so that I don't forget.
SGTM. The assembly format is a bit opaque; River added support for it so he knows all the nitty gritty. Feel free to ping him later if you have questions there. (But he is OOO ATM I think.)
================
Comment at: mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td:179
+ ssa-use `,` ssa-use `,` ssa-use ` : `
+ cooperative-matrix-type
+ ```
----------------
This needs to be updated, right?
================
Comment at: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp:2672
+static LogicalResult verifyPointerandCoopMatrixType(Operation *op, Type pointer,
+ Type coopMatrix) {
----------------
Nit: capitalize "And" in the name (i.e. `verifyPointerAndCoopMatrixType`)?
================
Comment at: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp:2677
+ return op->emitError(
+ "expected the same type for pointer snd the coop matrix"
+ "element, bu provided ")
----------------
s/snd/and/
s/coop/cooperative/
For error messages it's better to be clear, so spelling things out completely is preferable.
================
Comment at: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp:2766
+
+static void print(spirv::CooperativeMatrixMulAddNVOp M, OpAsmPrinter &printer) {
+ printer << M.getOperationName() << ' ' << M.getOperand(0) << ", "
----------------
Nit: In general MLIR does not use capitalized variable names. So I'd suggest to s/M/coopMatrix/ or something.
================
Comment at: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp:2777
+ return op.emitOpError(
+ "destination and 3rd operand must have the same type.");
+ auto typeA = op.a().getType().dyn_cast<spirv::CooperativeMatrixNVType>();
----------------
result and the third ... ?
No need to have the trailing dot here.
================
Comment at: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp:2782
+ auto typeR = op.result().getType().dyn_cast<spirv::CooperativeMatrixNVType>();
+ if (typeA == nullptr || typeB == nullptr || typeC == nullptr ||
+ typeR == nullptr)
----------------
We can do an unconditional `cast` in the above and omit the check here. These are guaranteed by ODS. Normally the constraints in ODS do not need to be repeated here. We just need to check additional constraints using C++ here.
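To illustrate the difference, here is a minimal standalone sketch of LLVM-style casting. It is a simplified model, not MLIR's actual classes: `Type`, `CoopMatrixType`, `PointerType`, `dynCast` and `castTo` are all hypothetical names made up for this example. The checked variant returns null on a mismatch and so forces a null check at the call site, while the unconditional variant only asserts, which is appropriate when something else (here, ODS) already guarantees the kind.

```cpp
#include <cassert>

// Simplified standalone model of LLVM-style casts; these are NOT MLIR's
// real classes, only hypothetical stand-ins for illustration.
struct Type {
  enum Kind { CoopMatrix, Pointer };
  explicit Type(Kind k) : kind(k) {}
  Kind kind;
};

struct CoopMatrixType : Type {
  CoopMatrixType(unsigned r, unsigned c)
      : Type(CoopMatrix), rows(r), columns(c) {}
  static bool classof(const Type *t) { return t->kind == CoopMatrix; }
  unsigned rows, columns;
};

struct PointerType : Type {
  PointerType() : Type(Pointer) {}
  static bool classof(const Type *t) { return t->kind == Pointer; }
};

// Checked cast: returns nullptr when the dynamic kind does not match,
// so every caller has to test the result.
template <typename To> const To *dynCast(const Type *t) {
  return To::classof(t) ? static_cast<const To *>(t) : nullptr;
}

// Unconditional cast: the caller promises the kind matches. The promise
// is only asserted (in builds with assertions enabled), never returned
// as an error.
template <typename To> const To *castTo(const Type *t) {
  assert(To::classof(t) && "cast to incompatible type");
  return static_cast<const To *>(t);
}
```

With a guarantee from elsewhere that the operand is a cooperative matrix, `castTo<CoopMatrixType>(t)->rows` can be used directly and the null checks disappear.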
================
Comment at: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp:2789
+ typeB.getColumns() != typeR.getColumns())
+ return op.emitOpError("Matrix size mismatch");
+ if (typeR.getScope() != typeA.getScope() ||
----------------
In general error messages should stay all lower-case in MLIR. I think that's also the case for Clang/etc. So here: "matrix ..."
Besides, it's better to be consistent so either "matrix .. mismatch" or "matrix ... must match" for this and the following two cases.
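As a rough illustration of what consistent wording could look like, here is a standalone sketch of the checks. It does not use the MLIR API: `CoopMatrixInfo` and `verifyMulAdd` are hypothetical names, and the row/column rule (A is MxK, B is KxN, C and the result are MxN) is the usual matrix-multiply shape rule assumed here; only the `typeB.getColumns() != typeR.getColumns()` and scope comparisons are visible in the quoted diff.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-in for the shape and scope carried by a cooperative
// matrix type; not the real spirv::CooperativeMatrixNVType.
struct CoopMatrixInfo {
  unsigned rows;
  unsigned columns;
  int scope;
};

// Returns an error message, or std::nullopt when r = a * b + c is
// well-formed. Messages are all lower-case, have no trailing dot, and
// share the same "... must ..." phrasing.
std::optional<std::string> verifyMulAdd(const CoopMatrixInfo &a,
                                        const CoopMatrixInfo &b,
                                        const CoopMatrixInfo &c,
                                        const CoopMatrixInfo &r) {
  if (c.rows != r.rows || c.columns != r.columns || c.scope != r.scope)
    return "result and third operand must have the same type";
  // Assumed shape rule: (MxK) * (KxN) + (MxN) -> (MxN).
  if (a.rows != r.rows || a.columns != b.rows || b.columns != r.columns)
    return "matrix size must match";
  if (r.scope != a.scope || r.scope != b.scope)
    return "matrix scope must match";
  return std::nullopt;
}
```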
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D80181/new/
https://reviews.llvm.org/D80181