[Mlir-commits] [mlir] [mlir][spirv] Support `spirv.coopmatrix` type (de-)serialization (PR #65831)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Sep 8 21:37:14 PDT 2023
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/65831:
>From 8b5439685f8ebb05bc433632e648a382037127bd Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 8 Sep 2023 17:07:06 -0400
Subject: [PATCH 1/2] [mlir][spirv] Support `spirv.coopmatrix` type
(de-)serialization
Extend SPIR-V target serialization and deserialization to handle
cooperative matrix types. Add a roundtrip test. In addition to passing the
`FileCheck` checks, the resulting SPIR-V binary also passes `spirv-val`
(external tool).
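Also fix a type attribute serialization bug surfaced by the
`CooperativeMatrixLength` op: the generated serializer now processes the
attribute's type (emitting it if this is its first use) instead of assuming
it already has an <id>.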
Support for multiple matrix operand attributes is deferred to a future
patch to keep the scope of this one small.
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 4 +
.../SPIRV/Deserialization/DeserializeOps.cpp | 1 +
.../SPIRV/Deserialization/Deserializer.cpp | 66 +++++++++++--
.../SPIRV/Deserialization/Deserializer.h | 4 +-
.../Target/SPIRV/Serialization/Serializer.cpp | 22 +++++
.../SPIRV/IR/cooperative-matrix-ops.mlir | 2 +-
.../SPIRV/khr-cooperative-matrix-ops.mlir | 93 +++++++++++++++++++
...ps.mlir => nv-cooperative-matrix-ops.mlir} | 0
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 38 ++++----
9 files changed, 202 insertions(+), 28 deletions(-)
create mode 100644 mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
rename mlir/test/Target/SPIRV/{cooperative-matrix-ops.mlir => nv-cooperative-matrix-ops.mlir} (100%)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index cc4417077d459c4..2ce3ad875fa45d1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4053,6 +4053,8 @@ def SPIRV_KHR_CMU_MatrixA : I32EnumAttrCase<"MatrixA", 0>;
def SPIRV_KHR_CMU_MatrixB : I32EnumAttrCase<"MatrixB", 1>;
def SPIRV_KHR_CMU_MatrixAcc : I32EnumAttrCase<"MatrixAcc", 2>;
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
def SPIRV_KHR_CooperativeMatrixUseAttr :
SPIRV_I32EnumAttr<"CooperativeMatrixUseKHR",
"valid SPIR-V Cooperative Matrix Use (KHR)",
@@ -4064,6 +4066,8 @@ def SPIRV_KHR_CooperativeMatrixUseAttr :
def SPIRV_KHR_CML_RowMajor : I32EnumAttrCase<"RowMajor", 0>;
def SPIRV_KHR_CML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>;
+// NOTE: This is an attribute in the SPIR-V *dialect* but a constant (<id>) in
+// SPIR-V proper.
def SPIRV_KHR_CooperativeMatrixLayoutAttr :
SPIRV_I32EnumAttr<"CooperativeMatrixLayoutKHR",
"valid SPIR-V Cooperative Matrix Layout (KHR)",
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 78afcc7003effa2..7510e1e2eb9b6f2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
+ case spirv::Opcode::OpTypeCooperativeMatrixKHR:
case spirv::Opcode::OpTypeCooperativeMatrixNV:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b84d1d9c2187932..ce8b3ab3894606c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -765,8 +765,10 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
} break;
case spirv::Opcode::OpTypeArray:
return processArrayType(operands);
+ case spirv::Opcode::OpTypeCooperativeMatrixKHR:
+ return processCooperativeMatrixTypeKHR(operands);
case spirv::Opcode::OpTypeCooperativeMatrixNV:
- return processCooperativeMatrixType(operands);
+ return processCooperativeMatrixTypeNV(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeJointMatrixINTEL:
@@ -900,32 +902,76 @@ spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult
-spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
+LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
+ ArrayRef<uint32_t> operands) { // Deserializes OpTypeCooperativeMatrixKHR into a spirv::CooperativeMatrixType.
+ if (operands.size() != 6) { // Result <id> plus five type operands.
+ return emitError(unknownLoc,
+ "OpTypeCooperativeMatrixKHR must have element type, "
+ "scope, row and column parameters, and use");
+ }
+ // Operands: [0] result <id>, [1] element type <id>, [2..5] constant <id>s.
+ Type elementTy = getType(operands[1]);
+ if (!elementTy) {
+ return emitError(unknownLoc,
+ "OpTypeCooperativeMatrixKHR references undefined <id> ")
+ << operands[1];
+ }
+ // NOTE(review): getConstantInt() is not null-checked before .getInt() below.
+ std::optional<spirv::Scope> scope =
+ spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
+ if (!scope) {
+ return emitError(
+ unknownLoc,
+ "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
+ << operands[2];
+ }
+ // Rows/columns are taken as-is; no range validation is done here.
+ unsigned rows = getConstantInt(operands[3]).getInt();
+ unsigned columns = getConstantInt(operands[4]).getInt();
+ // The use operand names a constant holding a CooperativeMatrixUseKHR value.
+ std::optional<spirv::CooperativeMatrixUseKHR> use =
+ spirv::symbolizeCooperativeMatrixUseKHR(
+ getConstantInt(operands[5]).getInt());
+ if (!use) {
+ return emitError(
+ unknownLoc,
+ "OpTypeCooperativeMatrixKHR references undefined use <id> ")
+ << operands[5];
+ }
+ // Record the new type under its result <id>.
+ typeMap[operands[0]] =
+ spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
+ return success();
+}
+
+LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
+ ArrayRef<uint32_t> operands) {
if (operands.size() != 5) {
- return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
+ return emitError(unknownLoc, "OpTypeCooperativeMatrixNV must have element "
"type and row x column parameters");
}
Type elementTy = getType(operands[1]);
if (!elementTy) {
return emitError(unknownLoc,
- "OpTypeCooperativeMatrix references undefined <id> ")
+ "OpTypeCooperativeMatrixNV references undefined <id> ")
<< operands[1];
}
- auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
+ std::optional<spirv::Scope> scope =
+ spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
if (!scope) {
- return emitError(unknownLoc,
- "OpTypeCooperativeMatrix references undefined scope <id> ")
+ return emitError(
+ unknownLoc,
+ "OpTypeCooperativeMatrixNV references undefined scope <id> ")
<< operands[2];
}
unsigned rows = getConstantInt(operands[3]).getInt();
unsigned columns = getConstantInt(operands[4]).getInt();
- typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
- elementTy, scope.value(), rows, columns);
+ typeMap[operands[0]] =
+ spirv::CooperativeMatrixNVType::get(elementTy, *scope, rows, columns);
return success();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 613e4f6738df6b2..69be47851ef3c50 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -254,7 +254,9 @@ class Deserializer {
LogicalResult processArrayType(ArrayRef<uint32_t> operands);
- LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
+ LogicalResult processCooperativeMatrixTypeKHR(ArrayRef<uint32_t> operands);
+
+ LogicalResult processCooperativeMatrixTypeNV(ArrayRef<uint32_t> operands);
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1ef8ff043e690d4..dad085e21b42727 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -593,6 +593,28 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
+ if (auto cooperativeMatrixType =
+ dyn_cast<spirv::CooperativeMatrixType>(type)) {
+ uint32_t elementTypeID = 0;
+ if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
+ elementTypeID, serializationCtx))) {
+ return failure();
+ }
+ typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
+ auto getConstantOp = [&](uint32_t id) {
+ auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
+ return prepareConstantInt(loc, attr);
+ };
+ operands.push_back(elementTypeID);
+ operands.push_back(
+ getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
+ operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
+ operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
+ operands.push_back(
+ getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
+ return success();
+ }
+
if (auto cooperativeMatrixType =
dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
uint32_t elementTypeID = 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 03acb0c08b275a3..40736367520e843 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -14,7 +14,7 @@ spirv.func @cooperative_matrix_length() -> i32 "None" {
// -----
spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
- // expected-error @+1 {{'spirv.KHR.CooperativeMatrixLength' op type attribute must be a '!spirv.coopmatrix'}}
+ // expected-error @+1 {{'cooperative_matrix_type' failed to satisfy constraint: type attribute of any SPIR-V cooperative matrix type}}
%0 = spirv.KHR.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
spirv.ReturnValue %0 : i32
}
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
new file mode 100644
index 000000000000000..8546172f4f797b5
--- /dev/null
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip \
+// RUN: --split-input-file %s | FileCheck %s
+// Serialization roundtrip tests for SPV_KHR_cooperative_matrix types and ops.
+spirv.module Logical GLSL450 requires
+ #spirv.vce<v1.5, [Shader, Int8, Int16, Int64, Linkage, CooperativeMatrixKHR],
+ [SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]> {
+
+ // CHECK-LABEL: @cooperative_matrix_length
+ spirv.func @cooperative_matrix_length() "None" {
+ // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
+ %0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @cooperative_matrix_load_1
+ spirv.func @cooperative_matrix_load_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
+ // CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
+ !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @cooperative_matrix_load_2
+ spirv.func @cooperative_matrix_load_2(%ptr : !spirv.ptr<f32, StorageBuffer>, %stride : i64) "None" {
+ // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile>
+ // CHECK-SAME: : !spirv.ptr<f32, StorageBuffer>, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
+ !spirv.ptr<f32, StorageBuffer>, i64 -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixAcc>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @cooperative_matrix_store_1
+ spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
+ %m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor>
+ // CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
+ !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @cooperative_matrix_store_2
+ spirv.func @cooperative_matrix_store_2(%ptr : !spirv.ptr<f32, Workgroup>, %stride : i64,
+ %m : !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>) "None" {
+ // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor>, <Nontemporal>
+ // CHECK-SAME: : !spirv.ptr<f32, Workgroup>, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64
+ spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Nontemporal> :
+ !spirv.ptr<f32, Workgroup>, !spirv.coopmatrix<4x8xf32, Subgroup, MatrixB>, i64
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @cooperative_matrix_muladd_1
+ spirv.func @cooperative_matrix_muladd_1(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ %b : !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>) "None" {
+ // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
+ // CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ // CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+ // CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+ %p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+ -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+ // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
+ // CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ // CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+ // CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+ %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+ <BSigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+ !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+ -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+ // TODO: Handle multiple matrix operands and add relevant testcases here.
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @cooperative_matrix_muladd_2
+ spirv.func @cooperative_matrix_muladd_2(%a : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
+ %b : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>,
+ %c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>) "None" {
+ // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
+ // CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
+ // CHECK-SAME: !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
+ // CHECK-SAME: -> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>
+ %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x8xf32, Workgroup, MatrixA>,
+ !spirv.coopmatrix<8x8xf32, Workgroup, MatrixB>
+ -> !spirv.coopmatrix<8x8xf32, Workgroup, MatrixAcc>
+
+ spirv.Return
+ }
+
+}
diff --git a/mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
similarity index 100%
rename from mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
rename to mlir/test/Target/SPIRV/nv-cooperative-matrix-ops.mlir
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 8468f92600a44e9..ac00ddc6422c658 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -16,6 +16,7 @@
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
@@ -512,6 +513,14 @@ static mlir::GenRegistration
// Serialization AutoGen
//===----------------------------------------------------------------------===//
+// These enums are encoded in the SPIR-V blob as <id>s of constant values, but
+// the SPIR-V dialect uses the constant value directly as the attribute, so the
+// (de)serialization emitters must handle them separately from other enums.
+constexpr llvm::StringLiteral constantIdEnumAttrs[] = { // Matched against getAttrDefName().
+ "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
+ "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
+ "SPIRV_MatrixLayoutAttr"};
+
/// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
/// generates code extracts the attribute with name `attrName` from
/// `operandList` of `op`.
@@ -521,12 +530,7 @@ static void emitAttributeSerialization(const Attribute &attr,
StringRef attrName, raw_ostream &os) {
os << tabs
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
- if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
- attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
- // These two enums are encoded as <id> to constant values in SPIR-V blob,
- // but we directly use the constant value as attribute in SPIR-V dialect. So
- // need to handle them separately from normal enum attributes.
+ if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
@@ -557,11 +561,18 @@ static void emitAttributeSerialization(const Attribute &attr,
" {0}.push_back(static_cast<uint32_t>("
"llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
operandList);
- } else if (attr.isEnumAttr() || attr.getAttrDefName() == "TypeAttr") {
+ } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
+ // It may be the first time this type appears in the IR, so we need to
+ // process it.
+ StringRef attrTypeID = "attrTypeID";
+ os << tabs << formatv(" uint32_t {0} = 0;\n", attrTypeID);
os << tabs
- << formatv(" {0}.push_back(static_cast<uint32_t>("
- "getTypeID(llvm::cast<TypeAttr>(attr).getValue())));\n",
- operandList);
+ << formatv(" if (failed(processType({0}.getLoc(), "
+ "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
+ opVar, attrTypeID);
+ os << tabs << " return failure();\n";
+ os << tabs << " }\n";
+ os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList);
} else {
PrintFatalError(
loc,
@@ -816,12 +827,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
StringRef attrList, StringRef attrName,
StringRef words, StringRef wordIndex,
raw_ostream &os) {
- if (attr.getAttrDefName() == "SPIRV_ScopeAttr" ||
- attr.getAttrDefName() == "SPIRV_MemorySemanticsAttr" ||
- attr.getAttrDefName() == "SPIRV_MatrixLayoutAttr") {
- // These two enums are encoded as <id> to constant values in SPIR-V blob,
- // but we directly use the constant value as attribute in SPIR-V dialect. So
- // need to handle them separately from normal enum attributes.
+ if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
>From aa96ddbb909175890945ea05f4a568638e898bc8 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 9 Sep 2023 00:37:01 -0400
Subject: [PATCH 2/2] Use ODS to check type constraints of the
CooperativeMatrixLength op
---
.../Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td | 4 +++-
mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp | 14 --------------
2 files changed, 3 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 3ce43c7e2b1fcee..b5ea0774f589d16 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -55,12 +55,14 @@ def SPIRV_KHRCooperativeMatrixLengthOp :
];
let arguments = (ins
- TypeAttr:$cooperative_matrix_type
+ TypeAttrOf<SPIRV_AnyCooperativeMatrix>:$cooperative_matrix_type
);
let results = (outs
SPIRV_Int32:$result
);
+
+ let hasVerifier = false;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 600813f361a4712..77dbf130c777857 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -19,20 +19,6 @@
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLength
-//===----------------------------------------------------------------------===//
-
-LogicalResult KHRCooperativeMatrixLengthOp::verify() {
- if (!isa<CooperativeMatrixType>(getCooperativeMatrixType())) {
- return emitOpError(
- "type attribute must be a '!spirv.coopmatrix' type, found ")
- << getCooperativeMatrixType() << " instead";
- }
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list