[Mlir-commits] [mlir] c652c30 - [mlir][spirv] Clean up coop matrix assembly declaration.
Thomas Raoux
llvmlistbot at llvm.org
Fri May 29 16:38:01 PDT 2020
Author: Thomas Raoux
Date: 2020-05-29T16:37:35-07:00
New Revision: c652c306a6aa3b356cebae78caf4b33b63afb866
URL: https://github.com/llvm/llvm-project/commit/c652c306a6aa3b356cebae78caf4b33b63afb866
DIFF: https://github.com/llvm/llvm-project/commit/c652c306a6aa3b356cebae78caf4b33b63afb866.diff
LOG: [mlir][spirv] Clean up coop matrix assembly declaration.
Address code review feedback and use declarative assembly format.
Differential Revision: https://reviews.llvm.org/D80687
Added: (none)
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
Removed: (none)
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
index 4645765b66ba..9c3462a2e5bf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
@@ -39,6 +39,8 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
```
}];
+ let assemblyFormat = "attr-dict `:` $type";
+
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
@@ -139,7 +141,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
// -----
def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
- [NoSideEffect]> {
+ [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{
@@ -188,6 +190,10 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
```
}];
+ let assemblyFormat = [{
+ operands attr-dict`:` type($a) `,` type($b) `->` type($c)
+ }];
+
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 4f48ef9d7d7c..ac8fee8619b6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1134,12 +1134,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
return compositeConstructOp.emitError(
"has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
- } else {
- if (constituents.size() != cType.getNumElements())
- return compositeConstructOp.emitError(
- "has incorrect number of operands: expected ")
- << cType.getNumElements() << ", but provided "
- << constituents.size();
+ } else if (constituents.size() != cType.getNumElements()) {
+ return compositeConstructOp.emitError(
+ "has incorrect number of operands: expected ")
+ << cType.getNumElements() << ", but provided "
+ << constituents.size();
}
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
@@ -2735,57 +2734,10 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
printer << " : " << coopMatrix.getOperand(1).getType();
}
-//===----------------------------------------------------------------------===//
-// spv.CooperativeMatrixLengthNV
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::OperandType operandInfo;
- Type dstType = parser.getBuilder().getIntegerType(32);
- Type type;
- if (parser.parseColonType(type)) {
- return failure();
- }
- state.addAttribute(kTypeAttrName, TypeAttr::get(type));
- state.addTypes(dstType);
- return success();
-}
-
-static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix,
- OpAsmPrinter &printer) {
- printer << coopMatrix.getOperationName() << " : " << coopMatrix.type();
-}
-
//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixMulAddNV
//===----------------------------------------------------------------------===//
-static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser,
- OperationState &state) {
- SmallVector<OpAsmParser::OperandType, 3> ops;
- SmallVector<Type, 3> types(3);
- if (parser.parseOperandList(ops, 3) || parser.parseColon() ||
- parser.parseType(types[0]) || parser.parseComma() ||
- parser.parseType(types[1]) || parser.parseArrow() ||
- parser.parseType(types[2]) ||
- parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) {
- return failure();
- }
- state.addTypes(types[2]);
- return success();
-}
-
-static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix,
- OpAsmPrinter &printer) {
- printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0)
- << ", " << coopMatrix.getOperand(1) << ", "
- << coopMatrix.getOperand(2) << ", "
- << " : " << coopMatrix.getOperand(0).getType() << ", "
- << coopMatrix.getOperand(1).getType() << " -> "
- << coopMatrix.getOperand(2).getType();
-}
-
static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
if (op.c().getType() != op.result().getType())
diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
index 12f710ea1b46..0d58fea18a11 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -38,7 +38,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
// CHECK-LABEL: @cooperative_matrix_muladd
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index e30352625da6..51c709067f6f 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -38,7 +38,7 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
// CHECK-LABEL: @cooperative_matrix_muladd
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}, : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}
More information about the Mlir-commits mailing list.