[Mlir-commits] [mlir] 1fa9e15 - [mlir][spirv] Add cooperative matrix load op
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jul 19 08:00:36 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-19T10:55:27-04:00
New Revision: 1fa9e150b43dad347bd488bac584b3c2ed77d9f8
URL: https://github.com/llvm/llvm-project/commit/1fa9e150b43dad347bd488bac584b3c2ed77d9f8
DIFF: https://github.com/llvm/llvm-project/commit/1fa9e150b43dad347bd488bac584b3c2ed77d9f8.diff
LOG: [mlir][spirv] Add cooperative matrix load op
Implement cooperative matrix load for the `SPV_KHR_cooperative_matrix`
extension: https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_cooperative_matrix.html.
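Also make minor fixes to common custom-parsing code: `OpAsmParser::resolveOperands` now uses `llvm::range_size` and `llvm::zip_equal`, the enum-attribute parsing helpers `static_assert` that they are instantiated with an enum type, and `parseMemoryAccessAttributes` accepts the name of the attribute it populates.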
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D155616
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 885fb7138b1410..7888d6e3aa7f0a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4061,6 +4061,17 @@ def SPIRV_KHR_CooperativeMatrixUseAttr :
SPIRV_KHR_CMU_MatrixA, SPIRV_KHR_CMU_MatrixB, SPIRV_KHR_CMU_MatrixAcc
]>;
+// Cooperative Matrix Layout for the SPV_KHR_cooperative_matrix extension.
+// Each case pairs the keyword spelled in the custom assembly form (e.g.
+// `RowMajor`) with its 32-bit integer value.
+def SPIRV_KHR_CML_RowMajor : I32EnumAttrCase<"RowMajor", 0>;
+def SPIRV_KHR_CML_ColumnMajor : I32EnumAttrCase<"ColumnMajor", 1>;
+
+// I32 enum attribute over the layout cases above; it is the type of the
+// `matrix_layout` argument of spirv.KHR.CooperativeMatrixLoad.
+def SPIRV_KHR_CooperativeMatrixLayoutAttr :
+ SPIRV_I32EnumAttr<"CooperativeMatrixLayoutKHR",
+ "valid SPIR-V Cooperative Matrix Layout (KHR)",
+ "coop_matrix_layout_khr", [
+ SPIRV_KHR_CML_RowMajor, SPIRV_KHR_CML_ColumnMajor
+ ]>;
+
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
@@ -4435,6 +4446,7 @@ def SPIRV_OC_OpSDotAccSat : I32EnumAttrCase<"OpSDotAccSat", 4453
def SPIRV_OC_OpUDotAccSat : I32EnumAttrCase<"OpUDotAccSat", 4454>;
def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
+// OpCooperativeMatrixLoadKHR from SPV_KHR_cooperative_matrix (opcode 4457).
+def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
@@ -4534,7 +4546,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
- SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpCooperativeMatrixLengthKHR,
+ SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
+ SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
SPIRV_OC_OpTypeCooperativeMatrixNV,
SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index e6e2bbb26d094e..c3c1e2cd042800 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -63,6 +63,77 @@ def SPIRV_KHRCooperativeMatrixLengthOp :
);
}
+// -----
+
+def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", []> {
+  let summary = "Loads a cooperative matrix through a pointer";
+
+  let description = [{
+    Load a cooperative matrix through a pointer.
+
+    Result Type is the type of the loaded object. It must be a cooperative
+    matrix type. In this dialect that is the KHR type `!spirv.coopmatrix`;
+    the NV type `!spirv.NV.coopmatrix` is not accepted.
+
+    Pointer is a pointer. Its type must be an OpTypePointer whose Type operand is
+    a scalar or vector type. If the Shader capability was declared, Pointer must
+    point into an array and any ArrayStride decoration on Pointer is ignored.
+
+    MemoryLayout specifies how matrix elements are laid out in memory. It must
+    come from a 32-bit integer constant instruction whose value corresponds to a
+    Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
+    description of the layouts and detailed layout-specific rules.
+
+    Stride further qualifies how matrix elements are laid out in memory. It must
+    be a scalar integer type and its exact semantics depend on MemoryLayout.
+
+    Memory Operand must be a Memory Operand literal. If not present, it is the
+    same as specifying None.
+
+    NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
+    as 'Memory Access'.
+
+    For a given dynamic instance of this instruction, all operands of this
+    instruction must be the same for all invocations in a given scope instance
+    (where the scope is the scope the cooperative matrix type was created with).
+    All invocations in a given scope instance must be active or all must be
+    inactive.
+
+    ``` {.ebnf}
+    cooperative-matrix-load-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixLoad`
+                              ssa-use `,` ssa-use `,`
+                              cooperative-matrix-layout
+                              (`[` memory-operand `]`)? `:`
+                              pointer-type `as` cooperative-matrix-type
+    ```
+
+    #### Example:
+
+    ```
+    %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor
+         : !spirv.ptr<i32, StorageBuffer>
+           as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_6>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_KHR_cooperative_matrix]>,
+    Capability<[SPIRV_C_CooperativeMatrixKHR]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyPtr:$pointer,
+    SPIRV_Integer:$stride,
+    SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
+    OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
+  );
+
+  let results = (outs
+    SPIRV_AnyCooperativeMatrix:$result
+  );
+}
+
//===----------------------------------------------------------------------===//
// SPV_NV_cooperative_matrix extension ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 81e74f38307774..0eeb8bb1ec8da5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1445,13 +1445,13 @@ class OpAsmParser : public AsmParser {
std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
SmallVectorImpl<Value> &result) {
- size_t operandSize = std::distance(operands.begin(), operands.end());
- size_t typeSize = std::distance(types.begin(), types.end());
+ size_t operandSize = llvm::range_size(operands);
+ size_t typeSize = llvm::range_size(types);
if (operandSize != typeSize)
return emitError(loc)
<< operandSize << " operands present, but expected " << typeSize;
- for (auto [operand, type] : llvm::zip(operands, types))
+ for (auto [operand, type] : llvm::zip_equal(operands, types))
if (resolveOperand(operand, type, result))
return failure();
return success();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5c11fe8d0cffa6..61b084e6a56412 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -36,6 +36,7 @@
#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <numeric>
+#include <type_traits>
using namespace mlir;
@@ -53,7 +54,9 @@ constexpr char kGroupOperationAttrName[] = "group_operation";
constexpr char kIndicesAttrName[] = "indices";
constexpr char kInitializerAttrName[] = "initializer";
constexpr char kInterfaceAttrName[] = "interface";
+// Attribute name for the `$matrix_layout` argument of
+// spirv.KHR.CooperativeMatrixLoad; the layout keyword is parsed into it.
+constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
constexpr char kMemoryAccessAttrName[] = "memory_access";
+// Attribute name for the `$memory_operand` argument of
+// spirv.KHR.CooperativeMatrixLoad; passed to parseMemoryAccessAttributes in
+// place of the default "memory_access" name.
+constexpr char kMemoryOperandAttrName[] = "memory_operand";
constexpr char kMemoryScopeAttrName[] = "memory_scope";
constexpr char kPackedVectorFormatAttrName[] = "format";
constexpr char kSemanticsAttrName[] = "semantics";
@@ -176,6 +179,7 @@ template <typename EnumClass>
static ParseResult
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
StringRef attrName = spirv::attributeName<EnumClass>()) {
+ static_assert(std::is_enum_v<EnumClass>);
Attribute attrVal;
NamedAttrList attr;
auto loc = parser.getCurrentLocation();
@@ -202,6 +206,7 @@ template <typename EnumAttrClass,
static ParseResult
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
StringRef attrName = spirv::attributeName<EnumClass>()) {
+ static_assert(std::is_enum_v<EnumClass>);
if (parseEnumStrAttr(value, parser))
return failure();
state.addAttribute(attrName,
@@ -218,6 +223,7 @@ static ParseResult
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
OperationState &state,
StringRef attrName = spirv::attributeName<EnumClass>()) {
+ static_assert(std::is_enum_v<EnumClass>);
if (parseEnumKeywordAttr(value, parser))
return failure();
state.addAttribute(attrName,
@@ -246,14 +252,15 @@ parseControlAttribute(OpAsmParser &parser, OperationState &state,
return success();
}
-/// Parses optional memory access attributes attached to a memory access
-/// operand/pointer. Specifically, parses the following syntax:
+/// Parses optional memory access (a.k.a. memory operand) attributes attached to
+/// a memory access operand/pointer. Specifically, parses the following syntax:
/// (`[` memory-access `]`)?
/// where:
/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
/// integer-literal | `"NonTemporal"`
-static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
- OperationState &state) {
+static ParseResult
+parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state,
+ StringRef attrName = kMemoryAccessAttrName) {
// Parse an optional list of attributes staring with '['
if (parser.parseOptionalLSquare()) {
// Nothing to do
@@ -262,7 +269,7 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
spirv::MemoryAccess memoryAccessAttr;
if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
- kMemoryAccessAttrName))
+ attrName))
return failure();
if (spirv::bitEnumContainsAll(memoryAccessAttr,
@@ -4035,6 +4042,75 @@ LogicalResult spirv::KHRCooperativeMatrixLengthOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom form of spirv.KHR.CooperativeMatrixLoad:
+///   ssa-use `,` ssa-use `,` cooperative-matrix-layout
+///   (`[` memory-operand `]`)? `:` pointer-type (`,` stride-type)?
+///   `as` cooperative-matrix-type
+/// The stride type is optional and defaults to i32, so the short form keeps
+/// working. It is needed for any other integer type the `$stride` operand
+/// accepts; without it such ops could be printed but not parsed back.
+ParseResult spirv::KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
+                                                     OperationState &result) {
+  std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
+  if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
+    return failure();
+  if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
+    return failure();
+
+  spirv::CooperativeMatrixLayoutKHR layout;
+  if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
+          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName))
+    return failure();
+
+  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
+    return failure();
+
+  Type ptrType;
+  if (parser.parseColon() || parser.parseType(ptrType))
+    return failure();
+
+  // Optional `, stride-type`; when absent the stride is assumed to be i32.
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  if (succeeded(parser.parseOptionalComma()) && parser.parseType(strideType))
+    return failure();
+
+  Type resultType;
+  if (parser.parseKeywordType("as", resultType))
+    return failure();
+  result.addTypes(resultType);
+
+  return parser.resolveOperands(operandInfo, {ptrType, strideType},
+                                parser.getNameLoc(), result.operands);
+}
+
+/// Prints the form accepted by the parser above. The stride type is printed
+/// only when it is not the parser's signless-i32 default.
+void spirv::KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getStride() << ", "
+          << getMatrixLayout();
+  // Print optional memory operand attribute.
+  // NOTE(review): the memory-access grammar accepted by
+  // parseMemoryAccessAttributes is `"Aligned", integer-literal`, but only the
+  // enum value is printed here, so an Aligned memory operand likely does not
+  // round-trip -- confirm.
+  if (auto memOperand = getMemoryOperand())
+    printer << " [\"" << *memOperand << "\"]";
+  printer << " : " << getPointer().getType();
+  Type strideType = getStride().getType();
+  if (!strideType.isSignlessInteger(32))
+    printer << ", " << strideType;
+  printer << " as " << getType();
+}
+
+/// Verifies the pointer operand of a cooperative matrix memory op: the pointee
+/// must be a SPIR-V scalar or a vector. `pointer` must be a spirv::PointerType
+/// (it is cast unconditionally). `coopMatrix` is accepted but not inspected yet.
+static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
+                                                    Type coopMatrix) {
+  Type pointee = cast<spirv::PointerType>(pointer).getPointeeType();
+
+  // TODO: Verify the memory object behind the pointer:
+  // > If the Shader capability was declared, Pointer must point into an array
+  // > and any ArrayStride decoration on Pointer is ignored.
+
+  if (isa<spirv::ScalarType, VectorType>(pointee))
+    return success();
+
+  return op->emitError(
+             "Pointer must point to a scalar or vector type but provided ")
+         << pointee;
+}
+
+/// Op verifier: checks the pointer operand against the loaded matrix type.
+LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
+  Type ptrTy = getPointer().getType();
+  Type matrixTy = getResult().getType();
+  return verifyPointerAndCoopMatrixType(getOperation(), ptrTy, matrixTy);
+}
+
//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixLength
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 65311e1db9bb49..7e38161aebae80 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
// CooperativeMatrix (KHR)
@@ -21,6 +21,80 @@ spirv.func @cooperative_matrix_length_wrong_matrix() -> i32 "None" {
// -----
+// Parses and prints a load with no memory operand.
+// CHECK-LABEL: @cooperative_matrix_load
+spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
+ !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ spirv.Return
+}
+
+// ColumnMajor layout with a "Volatile" memory operand.
+// CHECK-LABEL: @cooperative_matrix_load_memoperand
+spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
+ // CHECK-SAME: !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor ["Volatile"] :
+ !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
+ spirv.Return
+}
+
+// The pointee may be a vector type.
+// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
+spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor ["Volatile"] :
+ // CHECK-SAME: !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor ["Volatile"] :
+ !spirv.ptr<vector<4xi32>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
+ spirv.Return
+}
+
+// Pointer in the Function storage class, loading a MatrixAcc matrix.
+// CHECK-LABEL: @cooperative_matrix_load_function
+spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
+ // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, RowMajor :
+ // CHECK-SAME: !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, RowMajor :
+ !spirv.ptr<i32, Function> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
+ spirv.Return
+}
+
+// -----
+
+// A struct pointee is rejected by the op verifier.
+spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
+ // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
+ !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+// Omitting the layout together with its preceding comma fails to parse.
+spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // expected-error @+1 {{expected ','}}
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
+ !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+ spirv.Return
+}
+
+// -----
+
+// A trailing comma after the stride with no layout keyword fails to parse.
+spirv.func @cooperative_matrix_load_missing_layout(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+  // expected-error @+1 {{expected valid keyword}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, :
+    !spirv.ptr<i32, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+// An NV cooperative matrix result type is rejected by the result constraint.
+spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
+ // expected-error @+1 {{op result #0 must be any SPIR-V cooperative matrix type}}
+ %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ColumnMajor :
+ !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// NV.CooperativeMatrix
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list