[Mlir-commits] [mlir] [mlir][spirv] Support `spirv.coopmatrix` type (de-)serialization (PR #65831)

Jakub Kuderski llvmlistbot at llvm.org
Fri Sep 8 21:37:14 PDT 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/65831:

>From 8b5439685f8ebb05bc433632e648a382037127bd Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 8 Sep 2023 17:07:06 -0400
Subject: [PATCH 1/2] [mlir][spirv] Support `spirv.coopmatrix` type
 (de-)serialization

Extend SPIR-V target serialization and deserialization to handle
cooperative matrix types. Add a roundtrip test. In addition to the
`FileCheck` checks, the resulting SPIR-V binary also passes `spirv-val`
(external tool).

Also fix a type attribute bug surfaced by the `CooperativeMatrixLength`
op.

Support for multiple matrix operand attributes is deferred to a future
patch to keep the scope of this one small.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  4 +
 .../SPIRV/Deserialization/DeserializeOps.cpp  |  1 +
 .../SPIRV/Deserialization/Deserializer.cpp    | 66 +++++++++++--
 .../SPIRV/Deserialization/Deserializer.h      |  4 +-
 .../Target/SPIRV/Serialization/Serializer.cpp | 22 +++++
 .../SPIRV/IR/cooperative-matrix-ops.mlir      |  2 +-
 .../SPIRV/khr-cooperative-matrix-ops.mlir     | 93 +++++++++++++++++++
 ...ps.mlir => nv-cooperative-matrix-ops.mlir} |  0
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      | 38 ++++----
 9 files changed, 202 insertions(+), 28 deletions(-)
 create mode 100644 mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
 rename mlir/test/Target/SPIRV/{cooperative-matrix-ops.mlir => nv-cooperative-matrix-ops.mlir} (100%)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index cc4417077d459c4..2ce3ad875fa45d1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4053,6 +4053,8 @@ def SPIRV_KHR_CMU_MatrixA   : I32EnumAttrCase<"MatrixA", 0>;
 def SPIRV_KHR_CMU_MatrixB   : I32EnumAttrCase<"MatrixB", 1>;
 def SPIRV_KHR_CMU_MatrixAcc : I32EnumAttrCase<"MatrixAcc", 2>;
 
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
 def SPIRV_KHR_CooperativeMatrixUseAttr :
     SPIRV_I32EnumAttr<"CooperativeMatrixUseKHR",
                       "valid SPIR-V Cooperative Matrix Use (KHR)",
@@ -4064,6 +4066,8 @@ def SPIRV_KHR_CooperativeMatrixUseAttr :
 def SPIRV_KHR_CML_RowMajor    : I32EnumAttrCase<"RowMajor", 0>;
 def SPIRV_KHR_CML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>;
 
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
 def SPIRV_KHR_CooperativeMatrixLayoutAttr :
     SPIRV_I32EnumAttr<"CooperativeMatrixLayoutKHR",
                       "valid SPIR-V Cooperative Matrix Layout (KHR)",
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 78afcc7003effa2..7510e1e2eb9b6f2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
+  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
   case spirv::Opcode::OpTypeCooperativeMatrixNV:
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b84d1d9c2187932..ce8b3ab3894606c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -765,8 +765,10 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
   } break;
   case spirv::Opcode::OpTypeArray:
     return processArrayType(operands);
+  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
+    return processCooperativeMatrixTypeKHR(operands);
   case spirv::Opcode::OpTypeCooperativeMatrixNV:
-    return processCooperativeMatrixType(operands);
+    return processCooperativeMatrixTypeNV(operands);
   case spirv::Opcode::OpTypeFunction:
     return processFunctionType(operands);
   case spirv::Opcode::OpTypeJointMatrixINTEL:
@@ -900,32 +902,76 @@ spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
-LogicalResult
-spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
+LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
+    ArrayRef<uint32_t> operands) {
+  if (operands.size() != 6) {
+    return emitError(unknownLoc,
+                     "OpTypeCooperativeMatrixKHR must have element type, "
+                     "scope, row and column parameters, and use");
+  }
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy) {
+    return emitError(unknownLoc,
+                     "OpTypeCooperativeMatrixKHR references undefined <id> ")
+           << operands[1];
+  }
+
+  std::optional<spirv::Scope> scope =
+      spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
+  if (!scope) {
+    return emitError(
+               unknownLoc,
+               "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
+           << operands[2];
+  }
+
+  unsigned rows = getConstantInt(operands[3]).getInt();
+  unsigned columns = getConstantInt(operands[4]).getInt();
+
+  std::optional<spirv::CooperativeMatrixUseKHR> use =
+      spirv::symbolizeCooperativeMatrixUseKHR(
+          getConstantInt(operands[5]).getInt());
+  if (!use) {
+    return emitError(
+               unknownLoc,
+               "OpTypeCooperativeMatrixKHR references undefined use <id> ")
+           << operands[5];
+  }
+
+  typeMap[operands[0]] =
+      spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
+  return success();
+}
+
+LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
+    ArrayRef<uint32_t> operands) {
   if (operands.size() != 5) {
-    return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
+    return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element "
                                  "type and row x column parameters");
   }
 
   Type elementTy = getType(operands[1]);
   if (!elementTy) {
     return emitError(unknownLoc,
-                     "OpTypeCooperativeMatrix references undefined <id> ")
+                     "OpTypeCooperativeMatrixNV references undefined <id> ")
            << operands[1];
   }
 
-  auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
+  std::optional<spirv::Scope> scope =
+      spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
   if (!scope) {
-    return emitError(unknownLoc,
-                     "OpTypeCooperativeMatrix references undefined scope <id> ")
+    return emitError(
+               unknownLoc,
+               "OpTypeCooperativeMatrixNV references undefined scope <id> ")
            << operands[2];
   }
 
   unsigned rows = getConstantInt(operands[3]).getInt();
   unsigned columns = getConstantInt(operands[4]).getInt();
 
-  typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
-      elementTy, scope.value(), rows, columns);
+  typeMap[operands[0]] =
+      spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns);
   return success();
 }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 613e4f6738df6b2..69be47851ef3c50 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -254,7 +254,9 @@ class Deserializer {
 
   LogicalResult processArrayType(ArrayRef<uint32_t> operands);
 
-  LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
+  LogicalResult processCooperativeMatrixTypeKHR(ArrayRef<uint32_t> operands);
+
+  LogicalResult processCooperativeMatrixTypeNV(ArrayRef<uint32_t> operands);
 
   LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
 
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1ef8ff043e690d4..dad085e21b42727 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -593,6 +593,28 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
+  if (auto cooperativeMatrixType =
+          dyn_cast<spirv::CooperativeMatrixType>(type)) {
+    uint32_t elementTypeID = 0;
+    if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
+                               elementTypeID, serializationCtx))) {
+      return failure();
+    }
+    typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
+    auto getConstantOp = [&](uint32_t id) {
+      auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
+      return prepareConstantInt(loc, attr);
+    };
+    operands.push_back(elementTypeID);
+    operands.push_back(
+        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
+    operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
+    operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
+    operands.push_back(
+        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
+    return success();
+  }
+
   if (auto cooperativeMatrixType =
           dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
     uint32_t elementTypeID = 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 03acb0c08b275a3..40736367520e843 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -14,7 +14,7 @@ spirv.func @cooperative_matrix_length() -> i32 "None" {
 // -----
 
 spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
-  // expected-error @+1 {{'spirv.KHR.CooperativeMatrixLength' op type attribute must be a '!spirv.coopmatrix'}}
+  // expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}}
   %0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
   spirv.ReturnValue %0 : i32
 }
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
new file mode 100644
index 000000000000000..8546172f4f797b5
--- /dev/null
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip \
+// RUN:  --split-input-file %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires
+  #spirv.vce<v1.5, [Shader, Int8, Int16, Int64, Linkage, CooperativeMatrixKHR],
+                   [SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]> {
+
+  // CHECK-LABEL: @cooperative_matrix_length
+  spirv.func @cooperative_matrix_length() "None" {
+    // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
+    %0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_load_1
+  spirv.func @cooperative_matrix_load_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+    // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
+    // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+      !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_load_2
+  spirv.func @cooperative_matrix_load_2(%ptr : !spirv.ptr<f32, StorageBuffer>, %stride : i64) "None" {
+    // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile>
+    // CHECK-SAME:   : !spirv.ptr<f32, StorageBuffer>, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
+      !spirv.ptr<f32, StorageBuffer>, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_store_1
+  spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+                                         %m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
+    // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>
+    // CHECK-SAME:   : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
+      !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_store_2
+  spirv.func @cooperative_matrix_store_2(%ptr : !spirv.ptr<f32, Workgroup>, %stride : i64,
+                                         %m : !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>) "None" {
+    // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <ColumnMajor>, <Nontemporal>
+    // CHECK-SAME:   : !spirv.ptr<f32, Workgroup>, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Nontemporal> :
+      !spirv.ptr<f32, Workgroup>, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_muladd
+  spirv.func @cooperative_matrix_muladd_1(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                          %b : !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>,
+                                          %c : !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>) "None" {
+    // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
+    // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+    // CHECK-SAME:   !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+    // CHECK-SAME:   -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+    %p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                                        !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                                        -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
+    // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+    // CHECK-SAME:   !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+    // CHECK-SAME:   -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+    %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                           <BSigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                                       !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                                       -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+    // TODO: Handle multiple matrix operands and add relevant testcases here.
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @cooperative_matrix_muladd
+  spirv.func @cooperative_matrix_muladd_2(%a : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
+                                          %b : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>,
+                                          %c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>) "None" {
+    // CHECK:      {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
+    // CHECK-SAME:   !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
+    // CHECK-SAME:   !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
+    // CHECK-SAME:   -> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>
+    %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
+                                                        !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
+                                                        -> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>
+
+    spirv.Return
+  }
+
+}
diff --git a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
similarity index 100%
rename from mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
rename to mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 8468f92600a44e9..ac00ddc6422c658 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -16,6 +16,7 @@
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
@@ -512,6 +513,14 @@ static mlir::GenRegistration
 // Serialization AutoGen
 //===----------------------------------------------------------------------===//
 
+// These enums are encoded as <id>s of constant values in the SPIR-V binary,
+// but the SPIR-V dialect stores the constant value directly as an attribute,
+// so they must be handled separately from normal enum attributes.
+constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
+    "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
+    "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
+    "SPIRV_MatrixLayoutAttr"};
+
 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
 /// generates code extracts the attribute with name `attrName` from
 /// `operandList` of `op`.
@@ -521,12 +530,7 @@ static void emitAttributeSerialization(const Attribute &attr,
                                        StringRef attrName, raw_ostream &os) {
   os << tabs
      << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
-  if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
-      attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
-      attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
-    // These two enums are encoded as <id> to constant values in SPIR-V blob,
-    // but we directly use the constant value as attribute in SPIR-V dialect. So
-    // need to handle them separately from normal enum attributes.
+  if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
     EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("  {0}.push_back(prepareConstantInt({1}.getLoc(), "
@@ -557,11 +561,18 @@ static void emitAttributeSerialization(const Attribute &attr,
               "  {0}.push_back(static_cast<uint32_t>("
               "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
               operandList);
-  } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
+  } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
+    // It may be the first time this type appears in the IR, so we need to
+    // process it.
+    StringRef attrTypeID = "attrTypeID";
+    os << tabs << formatv("  uint32_t {0} = 0;\n", attrTypeID);
     os << tabs
-       << formatv("  {0}.push_back(static_cast<uint32_t>("
-                  "getTypeID(llvm::cast<TypeAttr>(attr).getValue())));\n",
-                  operandList);
+       << formatv("  if (failed(processType({0}.getLoc(), "
+                  "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
+                  opVar, attrTypeID);
+    os << tabs << "    return failure();\n";
+    os << tabs << "  }\n";
+    os << tabs << formatv("  {0}.push_back(attrTypeID);\n", operandList);
   } else {
     PrintFatalError(
         loc,
@@ -816,12 +827,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
                                          StringRef attrList, StringRef attrName,
                                          StringRef words, StringRef wordIndex,
                                          raw_ostream &os) {
-  if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
-      attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
-      attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
-    // These two enums are encoded as <id> to constant values in SPIR-V blob,
-    // but we directly use the constant value as attribute in SPIR-V dialect. So
-    // need to handle them separately from normal enum attributes.
+  if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
     EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "

>From aa96ddbb909175890945ea05f4a568638e898bc8 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 9 Sep 2023 00:37:01 -0400
Subject: [PATCH 2/2] Use ODS to check type constraints of the
 CooperativeMatrixLength op

---
 .../Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td  |  4 +++-
 mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 14 --------------
 2 files changed, 3 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 3ce43c7e2b1fcee..b5ea0774f589d16 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -55,12 +55,14 @@ def SPIRV_KHRCooperativeMatrixLengthOp :
   ];
 
   let arguments = (ins
-    TypeAttr:$cooperative_matrix_type
+    TypeAttrOf<SPIRV_AnyCooperativeMatrix>:$cooperative_matrix_type
   );
 
   let results = (outs
     SPIRV_Int32:$result
   );
+
+  let hasVerifier = false;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 600813f361a4712..77dbf130c777857 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -19,20 +19,6 @@
 using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLength
-//===----------------------------------------------------------------------===//
-
-LogicalResult KHRCooperativeMatrixLengthOp::verify() {
-  if (!isa<CooperativeMatrixType>(getCooperativeMatrixType())) {
-    return emitOpError(
-               "type attribute must be a '!spirv.coopmatrix' type, found ")
-           << getCooperativeMatrixType() << " instead";
-  }
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.KHR.CooperativeMatrixLoad
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list