[Mlir-commits] [mlir] cfb955a - [mlir][spirv] Relax restriction on pointer type for CooperativeMatrix load/store
Thomas Raoux
llvmlistbot at llvm.org
Fri Jul 31 08:16:30 PDT 2020
Author: Thomas Raoux
Date: 2020-07-31T08:02:21-07:00
New Revision: cfb955ac370cb724c51423a05694aaf5b70903a4
URL: https://github.com/llvm/llvm-project/commit/cfb955ac370cb724c51423a05694aaf5b70903a4
DIFF: https://github.com/llvm/llvm-project/commit/cfb955ac370cb724c51423a05694aaf5b70903a4.diff
LOG: [mlir][spirv] Relax restriction on pointer type for CooperativeMatrix load/store
This change allows CooperativeMatrix Load/Store operations to use a pointer type
that may not match the matrix element type. This allows us to declare a buffer
with a larger type size than the matrix element type. This follows the SPIR-V spec
and is needed to be able to use cooperative matrices in combination with
shared local memory efficiently.
Differential Revision: https://reviews.llvm.org/D84993
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
index 9c3462a2e5bf..720cfd697c24 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
@@ -101,16 +101,17 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
``` {.ebnf}
cooperative-matrixload-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
- storage-class ssa-use `,` ssa-use `,` ssa-use
+ ssa-use `,` ssa-use `,` ssa-use
(`[` memory-access `]`)? ` : `
+ pointer-type `as`
cooperative-matrix-type
```
For example:
```
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor
- : !spv.coopmatrix<i32, Workgroup, 16, 8>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %colMajor
+ : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<i32, Workgroup, 16, 8>
```
}];
@@ -243,16 +244,17 @@ def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
``` {.ebnf}
coop-matrix-store-op ::= `spv.CooperativeMatrixStoreNV `
- storage-class ssa-use `, ` ssa-use `, `
ssa-use `, ` ssa-use `, `
- (`[` memory-access `]`)? `:` spirv-element-type
+ ssa-use `, ` ssa-use `, `
+ (`[` memory-access `]`)? `:`
+ pointer-type `,` spirv-element-type
```
For example:
```
- spv.CooperativeMatrixStoreNV "StorageBuffer" %arg0, %arg2, %arg1, %arg3 :
- !spv.coopmatrix<Workgroup, i32, 16, 8>
+ spv.CooperativeMatrixStoreNV %arg0, %arg2, %arg1, %arg3 :
+ !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<Workgroup, i32, 16, 8>
```
}];
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index b0235d419ebe..bac65a02f63d 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2793,21 +2793,16 @@ static LogicalResult verify(spirv::VariableOp varOp) {
static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
OperationState &state) {
- spirv::StorageClass storageClass;
SmallVector<OpAsmParser::OperandType, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Type columnMajorType = parser.getBuilder().getIntegerType(1);
+ Type ptrType;
Type elementType;
- if (parseEnumStrAttr(storageClass, parser) ||
- parser.parseOperandList(operandInfo, 3) ||
+ if (parser.parseOperandList(operandInfo, 3) ||
parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
- parser.parseType(elementType)) {
+ parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
return failure();
}
-
- auto ptrType = spirv::PointerType::get(
- elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
- storageClass);
SmallVector<Type, 3> OperandType = {ptrType, strideType, columnMajorType};
if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
state.operands)) {
@@ -2819,25 +2814,30 @@ static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
}
static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
- StringRef sc = stringifyStorageClass(
- M.pointer().getType().cast<spirv::PointerType>().getStorageClass());
- printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc
- << "\" " << M.pointer() << ", " << M.stride() << ", "
- << M.columnmajor();
+ printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " "
+ << M.pointer() << ", " << M.stride() << ", " << M.columnmajor();
// Print optional memory access attribute.
if (auto memAccess = M.memory_access())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << M.getType();
+ printer << " : " << M.pointer().getType() << " as " << M.getType();
}
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
Type coopMatrix) {
- if (pointer.cast<spirv::PointerType>().getPointeeType() !=
- coopMatrix.cast<spirv::CooperativeMatrixNVType>().getElementType())
+ Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
+ if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
return op->emitError(
- "expected the same type for pointer and the cooperative matrix"
- "element, bu provided ")
- << pointer << " and " << coopMatrix;
+ "Pointer must point to a scalar or vector type but provided ")
+ << pointeeType;
+ spirv::StorageClass storage =
+ pointer.cast<spirv::PointerType>().getStorageClass();
+ if (storage != spirv::StorageClass::Workgroup &&
+ storage != spirv::StorageClass::StorageBuffer &&
+ storage != spirv::StorageClass::PhysicalStorageBuffer)
+ return op->emitError(
+ "Pointer storage class must be Workgroup, StorageBuffer or "
+ "PhysicalStorageBufferEXT but provided ")
+ << stringifyStorageClass(storage);
return success();
}
@@ -2847,21 +2847,17 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
OperationState &state) {
- spirv::StorageClass storageClass;
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Type columnMajorType = parser.getBuilder().getIntegerType(1);
+ Type ptrType;
Type elementType;
- if (parseEnumStrAttr(storageClass, parser) ||
- parser.parseOperandList(operandInfo, 4) ||
+ if (parser.parseOperandList(operandInfo, 4) ||
parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
+ parser.parseType(ptrType) || parser.parseComma() ||
parser.parseType(elementType)) {
return failure();
}
-
- auto ptrType = spirv::PointerType::get(
- elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
- storageClass);
SmallVector<Type, 4> OperandType = {ptrType, elementType, strideType,
columnMajorType};
if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
@@ -2874,17 +2870,14 @@ static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
OpAsmPrinter &printer) {
- StringRef sc = stringifyStorageClass(coopMatrix.pointer()
- .getType()
- .cast<spirv::PointerType>()
- .getStorageClass());
- printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " \""
- << sc << "\" " << coopMatrix.pointer() << ", " << coopMatrix.object()
- << ", " << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
+ printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " "
+ << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
+ << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
// Print optional memory access attribute.
if (auto memAccess = coopMatrix.memory_access())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << coopMatrix.getOperand(1).getType();
+ printer << " : " << coopMatrix.pointer().getType() << ", "
+ << coopMatrix.getOperand(1).getType();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
index ad913dfb1624..3f8c4bff4738 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -3,29 +3,29 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
// CHECK-LABEL: @cooperative_matrix_load
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_store
spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
- spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
+ spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_store_memaccess
spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
- spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index 523bd6bfb030..f0bb50d10f58 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -2,30 +2,37 @@
// CHECK-LABEL: @cooperative_matrix_load
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}
// -----
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
+spv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_store
spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Workgroup>
- spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<8x16xi32, Workgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
+ spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_store_memaccess
spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
- spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
@@ -134,3 +141,18 @@ spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b
spv.Return
}
+// -----
+
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.Return
+}
+
+// -----
+
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
+ // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, Function> as !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.Return
+}
More information about the Mlir-commits
mailing list