[Mlir-commits] [mlir] d9b4245 - [mlir][spirv] Add block read and write from SPV_INTEL_subgroups
Thomas Raoux
llvmlistbot at llvm.org
Wed Sep 2 20:07:28 PDT 2020
Author: Artur Bialas
Date: 2020-09-02T20:06:59-07:00
New Revision: d9b4245f56a98d8ea72d6f75d5bdd5c7c8e5c88c
URL: https://github.com/llvm/llvm-project/commit/d9b4245f56a98d8ea72d6f75d5bdd5c7c8e5c88c
DIFF: https://github.com/llvm/llvm-project/commit/d9b4245f56a98d8ea72d6f75d5bdd5c7c8e5c88c.diff
LOG: [mlir][spirv] Add block read and write from SPV_INTEL_subgroups
Added support for OpSubgroupBlockReadINTEL and OpSubgroupBlockWriteINTEL
Differential Revision: https://reviews.llvm.org/D86876
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
mlir/test/Dialect/SPIRV/group-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index ab0b76161342..6458183bdeb2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3252,6 +3252,8 @@ def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoa
def SPV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
def SPV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
def SPV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
+// Opcodes of the SPV_INTEL_subgroups block read/write instructions.
+def SPV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
+def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
def SPV_OpcodeAttr :
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3308,7 +3310,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
- SPV_OC_OpCooperativeMatrixLengthNV
+ SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL,
+ SPV_OC_OpSubgroupBlockWriteINTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
index c9ce8be9927f..7eab3b44601e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -88,7 +88,6 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
let assemblyFormat = [{
$execution_scope operands attr-dict `:` type($value) `,` type($localid)
}];
-
}
// -----
@@ -147,4 +146,104 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
// -----
+def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
+  let summary = "See extension SPV_INTEL_subgroups";
+
+  let description = [{
+    Reads one or more components of Result data for each invocation in the
+    subgroup from the specified Ptr as a block operation.
+
+    The data is read strided, so the first value read is:
+    Ptr[ SubgroupLocalInvocationId ]
+
+    and the second value read is:
+    Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
+    etc.
+
+    Result Type may be a scalar or vector type, and its component type must be
+    equal to the type pointed to by Ptr.
+
+    The type of Ptr must be a pointer type, and must point to a scalar type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    subgroup-block-read-INTEL-op ::= ssa-id `=` `spv.SubgroupBlockReadINTEL`
+                                     storage-class ssa-use `:` spirv-element-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+    %1 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<2xi32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_INTEL_subgroups]>,
+    Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr
+  );
+
+  let results = (outs
+    SPV_Type:$value
+  );
+}
+
+// -----
+
+def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
+  let summary = "See extension SPV_INTEL_subgroups";
+
+  let description = [{
+    Writes one or more components of Data for each invocation in the subgroup
+    to the specified Ptr as a block operation.
+
+    The data is written strided, so the first value is written to:
+    Ptr[ SubgroupLocalInvocationId ]
+
+    and the second value is written to:
+    Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
+    etc.
+
+    The type of Ptr must be a pointer type, and must point to a scalar type.
+
+    The component type of Data must be equal to the type pointed to by Ptr.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    subgroup-block-write-INTEL-op ::= `spv.SubgroupBlockWriteINTEL`
+                                      storage-class ssa-use `,` ssa-use `:`
+                                      spirv-element-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %vec : vector<2xi32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_INTEL_subgroups]>,
+    Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr,
+    SPV_Type:$value
+  );
+
+  let results = (outs);
+}
+
+// -----
+
#endif // SPIRV_GROUP_OPS
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index f729752e02a0..339f588541f6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -468,6 +468,19 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
return success();
}
+/// Verifies that the (component) type of `val` equals the pointee type of
+/// `ptr`; emits an op error on `op` otherwise.
+template <typename BlockReadWriteOpTy>
+static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
+                                                        Value ptr, Value val) {
+  // A vector value is compared by its scalar component type.
+  Type componentType = val.getType();
+  if (auto vectorType = componentType.dyn_cast<VectorType>())
+    componentType = vectorType.getElementType();
+
+  Type pointeeType = ptr.getType().cast<spirv::PointerType>().getPointeeType();
+  if (componentType == pointeeType)
+    return success();
+  // NOTE(review): the wording says "result type" even when called for block
+  // writes, where `val` is an operand.
+  return op.emitOpError("mismatch in result type and pointer type");
+}
+
static ParseResult parseVariableDecorations(OpAsmParser &parser,
OperationState &state) {
auto builtInName = llvm::convertToSnakeFromCamelCase(
@@ -2025,6 +2038,93 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
return success();
}
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
+                                                 OperationState &state) {
+  // Accepted form: `"<storage-class>" %ptr : <result-type>`. The pointer
+  // operand's type is not written out; it is rebuilt below from the storage
+  // class and the scalar component type of the result.
+  spirv::StorageClass storageClass;
+  OpAsmParser::OperandType ptrInfo;
+  Type elementType;
+  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
+      parser.parseColon() || parser.parseType(elementType))
+    return failure();
+
+  Type pointeeType = elementType;
+  if (auto vectorType = elementType.dyn_cast<VectorType>())
+    pointeeType = vectorType.getElementType();
+  auto ptrType = spirv::PointerType::get(pointeeType, storageClass);
+
+  if (parser.resolveOperand(ptrInfo, ptrType, state.operands))
+    return failure();
+
+  state.addTypes(elementType);
+  return success();
+}
+
+/// Prints `spv.SubgroupBlockReadINTEL "<storage-class>" %ptr : <type>`, the
+/// form parseSubgroupBlockReadINTELOp accepts. The storage class must be
+/// printed because the parser rebuilds the pointer type from it.
+static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
+                  OpAsmPrinter &printer) {
+  auto ptrType = blockReadOp.ptr().getType().cast<spirv::PointerType>();
+  printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " \""
+          << spirv::stringifyStorageClass(ptrType.getStorageClass()) << "\" "
+          << blockReadOp.ptr() << " : " << blockReadOp.getType();
+}
+
+/// Checks that the result's scalar component type matches the pointee type.
+static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
+  return verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
+                                            blockReadOp.value());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockWriteINTEL
+//===----------------------------------------------------------------------===//
+
+/// Parses `"<storage-class>" %ptr, %value : <value-type>`. The pointer
+/// operand's type is rebuilt from the storage class and the scalar component
+/// type of the value.
+static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
+                                                  OperationState &state) {
+  // Parse the storage class specification
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
+  auto loc = parser.getCurrentLocation();
+  Type elementType;
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
+      parser.parseType(elementType)) {
+    return failure();
+  }
+
+  auto ptrType = spirv::PointerType::get(elementType, storageClass);
+  if (auto valVecTy = elementType.dyn_cast<VectorType>())
+    ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
+
+  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+/// Prints `spv.SubgroupBlockWriteINTEL "<storage-class>" %ptr, %value :
+/// <type>`, matching parseSubgroupBlockWriteINTELOp so the output re-parses.
+static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
+                  OpAsmPrinter &printer) {
+  auto ptrType = blockWriteOp.ptr().getType().cast<spirv::PointerType>();
+  printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " \""
+          << spirv::stringifyStorageClass(ptrType.getStorageClass()) << "\" "
+          << blockWriteOp.ptr() << ", " << blockWriteOp.value() << " : "
+          << blockWriteOp.value().getType();
+}
+
+/// Checks that the value's scalar component type matches the pointee type.
+static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
+  return verifyBlockReadWritePtrAndValTypes(blockWriteOp, blockWriteOp.ptr(),
+                                            blockWriteOp.value());
+}
+
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformElectOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
index 9e1e85191874..b3aaf63856a5 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
@@ -19,4 +19,28 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
spv.ReturnValue %0: f32
}
+  // CHECK-LABEL: @subgroup_block_read_intel
+  spv.func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 "None" {
+    // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : i32
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+    spv.ReturnValue %0: i32
+  }
+  // CHECK-LABEL: @subgroup_block_read_intel_vector
+  spv.func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
+    // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : vector<3xi32>
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
+    spv.ReturnValue %0: vector<3xi32>
+  }
+  // CHECK-LABEL: @subgroup_block_write_intel
+  spv.func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" {
+    // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : i32
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+    spv.Return
+  }
+  // CHECK-LABEL: @subgroup_block_write_intel_vector
+  spv.func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" {
+    // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : vector<3xi32>
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
+    spv.Return
+  }
}
diff --git a/mlir/test/Dialect/SPIRV/group-ops.mlir b/mlir/test/Dialect/SPIRV/group-ops.mlir
index 93e9054050ec..55a07270a348 100644
--- a/mlir/test/Dialect/SPIRV/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/group-ops.mlir
@@ -61,3 +61,43 @@ func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> )
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
return %0: f32
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL
+//===----------------------------------------------------------------------===//
+
+func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 {
+  // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : i32
+  %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+  return %0: i32
+}
+
+// -----
+
+func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> {
+  // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : vector<3xi32>
+  %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
+  return %0: vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockWriteINTEL
+//===----------------------------------------------------------------------===//
+
+func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () {
+  // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : i32
+  spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+  return
+}
+
+// -----
+
+func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () {
+  // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : vector<3xi32>
+  spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
+  return
+}
More information about the Mlir-commits
mailing list