[Mlir-commits] [mlir] c652c30 - [mlir][spirv] Clean up coop matrix assembly declaration.

Thomas Raoux llvmlistbot at llvm.org
Fri May 29 16:38:01 PDT 2020


Author: Thomas Raoux
Date: 2020-05-29T16:37:35-07:00
New Revision: c652c306a6aa3b356cebae78caf4b33b63afb866

URL: https://github.com/llvm/llvm-project/commit/c652c306a6aa3b356cebae78caf4b33b63afb866
DIFF: https://github.com/llvm/llvm-project/commit/c652c306a6aa3b356cebae78caf4b33b63afb866.diff

LOG: [mlir][spirv] Clean up coop matrix assembly declaration.

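Address code review feedback: switch spv.CooperativeMatrixLengthNV and spv.CooperativeMatrixMulAddNV to the declarative assembly format, replacing their hand-written parsers and printers.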

Differential Revision: https://reviews.llvm.org/D80687

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
    mlir/test/Dialect/SPIRV/cooperative-matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
index 4645765b66ba..9c3462a2e5bf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
@@ -39,6 +39,8 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
     ```
   }];
 
+  let assemblyFormat = "attr-dict `:` $type";
+
   let availability = [
     MinVersion<SPV_V_1_0>,
     MaxVersion<SPV_V_1_5>,
@@ -139,7 +141,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
 // -----
 
 def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
-  [NoSideEffect]> {
+  [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
   let summary = "See extension SPV_NV_cooperative_matrix";
 
   let description = [{
@@ -188,6 +190,10 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
     ```
   }];
 
+  let assemblyFormat = [{
+    operands attr-dict`:` type($a) `,` type($b) `->` type($c)
+  }];
+
   let availability = [
     MinVersion<SPV_V_1_0>,
     MaxVersion<SPV_V_1_5>,

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 4f48ef9d7d7c..ac8fee8619b6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1134,12 +1134,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
       return compositeConstructOp.emitError(
                  "has incorrect number of operands: expected ")
              << "1, but provided " << constituents.size();
-  } else {
-    if (constituents.size() != cType.getNumElements())
-      return compositeConstructOp.emitError(
-                 "has incorrect number of operands: expected ")
-             << cType.getNumElements() << ", but provided "
-             << constituents.size();
+  } else if (constituents.size() != cType.getNumElements()) {
+    return compositeConstructOp.emitError(
+               "has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided "
+           << constituents.size();
   }
 
   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
@@ -2735,57 +2734,10 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
   printer << " : " << coopMatrix.getOperand(1).getType();
 }
 
-//===----------------------------------------------------------------------===//
-// spv.CooperativeMatrixLengthNV
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseCooperativeMatrixLengthNVOp(OpAsmParser &parser,
-                                                    OperationState &state) {
-  OpAsmParser::OperandType operandInfo;
-  Type dstType = parser.getBuilder().getIntegerType(32);
-  Type type;
-  if (parser.parseColonType(type)) {
-    return failure();
-  }
-  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
-  state.addTypes(dstType);
-  return success();
-}
-
-static void print(spirv::CooperativeMatrixLengthNVOp coopMatrix,
-                  OpAsmPrinter &printer) {
-  printer << coopMatrix.getOperationName() << " : " << coopMatrix.type();
-}
-
 //===----------------------------------------------------------------------===//
 // spv.CooperativeMatrixMulAddNV
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseCooperativeMatrixMulAddNVOp(OpAsmParser &parser,
-                                                    OperationState &state) {
-  SmallVector<OpAsmParser::OperandType, 3> ops;
-  SmallVector<Type, 3> types(3);
-  if (parser.parseOperandList(ops, 3) || parser.parseColon() ||
-      parser.parseType(types[0]) || parser.parseComma() ||
-      parser.parseType(types[1]) || parser.parseArrow() ||
-      parser.parseType(types[2]) ||
-      parser.resolveOperands(ops, types, parser.getNameLoc(), state.operands)) {
-    return failure();
-  }
-  state.addTypes(types[2]);
-  return success();
-}
-
-static void print(spirv::CooperativeMatrixMulAddNVOp coopMatrix,
-                  OpAsmPrinter &printer) {
-  printer << coopMatrix.getOperationName() << ' ' << coopMatrix.getOperand(0)
-          << ", " << coopMatrix.getOperand(1) << ", "
-          << coopMatrix.getOperand(2) << ", "
-          << " : " << coopMatrix.getOperand(0).getType() << ", "
-          << coopMatrix.getOperand(1).getType() << " -> "
-          << coopMatrix.getOperand(2).getType();
-}
-
 static LogicalResult
 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
   if (op.c().getType() != op.result().getType())

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
index 12f710ea1b46..0d58fea18a11 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -38,7 +38,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
 
   // CHECK-LABEL: @cooperative_matrix_muladd
   spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}},  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+    // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
     %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
     spv.Return
   }

diff  --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index e30352625da6..51c709067f6f 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -38,7 +38,7 @@ spv.func @cooperative_matrix_length() -> i32 "None" {
 
 // CHECK-LABEL: @cooperative_matrix_muladd
 spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}},  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
+  // CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}}  : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
   %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
   spv.Return
 }


        


More information about the Mlir-commits mailing list