[Mlir-commits] [mlir] d9b4245 - [mlir][spirv] Add block read and write from SPV_INTEL_subgroups

Thomas Raoux llvmlistbot at llvm.org
Wed Sep 2 20:07:28 PDT 2020


Author: Artur Bialas
Date: 2020-09-02T20:06:59-07:00
New Revision: d9b4245f56a98d8ea72d6f75d5bdd5c7c8e5c88c

URL: https://github.com/llvm/llvm-project/commit/d9b4245f56a98d8ea72d6f75d5bdd5c7c8e5c88c
DIFF: https://github.com/llvm/llvm-project/commit/d9b4245f56a98d8ea72d6f75d5bdd5c7c8e5c88c.diff

LOG: [mlir][spirv] Add block read and write from SPV_INTEL_subgroups

Added support for OpSubgroupBlockReadINTEL and OpSubgroupBlockWriteINTEL

Differential Revision: https://reviews.llvm.org/D86876

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
    mlir/test/Dialect/SPIRV/group-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index ab0b76161342..6458183bdeb2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3252,6 +3252,8 @@ def SPV_OC_OpCooperativeMatrixLoadNV   : I32EnumAttrCase<"OpCooperativeMatrixLoa
 def SPV_OC_OpCooperativeMatrixStoreNV  : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
 def SPV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
 def SPV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
+def SPV_OC_OpSubgroupBlockReadINTEL    : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
+def SPV_OC_OpSubgroupBlockWriteINTEL   : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
 
 def SPV_OpcodeAttr :
     SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3308,7 +3310,8 @@ def SPV_OpcodeAttr :
       SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
       SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
       SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
-      SPV_OC_OpCooperativeMatrixLengthNV
+      SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL,
+      SPV_OC_OpSubgroupBlockWriteINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
index c9ce8be9927f..7eab3b44601e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -88,7 +88,6 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
   let assemblyFormat = [{
     $execution_scope operands attr-dict `:` type($value) `,` type($localid)
   }];
-
 }
 
 // -----
@@ -147,4 +146,104 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
 
 // -----
 
+def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
+  let summary = "See extension SPV_INTEL_subgroups";
+
+  let description = [{
+    Reads one or more components of Result data for each invocation in the
+    subgroup from the specified Ptr as a block operation.
+
+    The data is read strided, so the first value read is:
+    Ptr[ SubgroupLocalInvocationId ]
+
+    and the second value read is:
+    Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
+    etc.
+
+    Result Type may be a scalar or vector type, and its component type must be
+    equal to the type pointed to by Ptr.
+
+    The type of Ptr must be a pointer type, and must point to a scalar type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    subgroup-block-read-INTEL-op ::= ssa-id `=` `spv.SubgroupBlockReadINTEL`
+                                storage-class ssa-use `:` spirv-element-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_INTEL_subgroups]>,
+    Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr
+  );
+
+  let results = (outs
+    SPV_Type:$value
+  );
+}
+
+// -----
+
+def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
+  let summary = "See extension SPV_INTEL_subgroups";
+
+  let description = [{
+    Writes one or more components of Data for each invocation in the subgroup
+    from the specified Ptr as a block operation.
+
+    The data is written strided, so the first value is written to:
+    Ptr[ SubgroupLocalInvocationId ]
+
+    and the second value written is:
+    Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
+    etc.
+
+    The type of Ptr must be a pointer type, and must point to a scalar type.
+
+    The component type of Data must be equal to the type pointed to by Ptr.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    subgroup-block-write-INTEL-op ::= `spv.SubgroupBlockWriteINTEL`
+                      storage-class ssa-use `,` ssa-use `:` spirv-element-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_INTEL_subgroups]>,
+    Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr,
+    SPV_Type:$value
+  );
+
+  let results = (outs);
+}
+
+// -----
+
 #endif // SPIRV_GROUP_OPS

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index f729752e02a0..339f588541f6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -468,6 +468,19 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
   return success();
 }
 
+/// Checks that `val` (or, if it is a vector, its element type) has exactly the
+/// pointee type of the SPIR-V pointer `ptr`; emits an op error otherwise.
+template <typename BlockReadWriteOpTy>
+static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
+                                                        Value ptr, Value val) {
+  Type valType = val.getType();
+  if (auto valVecTy = valType.dyn_cast<VectorType>())
+    valType = valVecTy.getElementType();
+  if (valType == ptr.getType().cast<spirv::PointerType>().getPointeeType())
+    return success();
+  return op.emitOpError("mismatch in value type and pointer type");
+}
+
 static ParseResult parseVariableDecorations(OpAsmParser &parser,
                                             OperationState &state) {
   auto builtInName = llvm::convertToSnakeFromCamelCase(
@@ -2025,6 +2038,93 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
+                                                 OperationState &state) {
+  // Expected form: "<storage-class>" %ptr : <result-type>
+  spirv::StorageClass storageClass;
+  OpAsmParser::OperandType ptrInfo;
+  Type resultType;
+  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
+      parser.parseColon() || parser.parseType(resultType))
+    return failure();
+
+  // The pointee is the result type itself for a scalar result, or its
+  // element type for a vector result.
+  Type pointeeType = resultType;
+  if (auto vecTy = resultType.dyn_cast<VectorType>())
+    pointeeType = vecTy.getElementType();
+  auto ptrType = spirv::PointerType::get(pointeeType, storageClass);
+
+  if (parser.resolveOperand(ptrInfo, ptrType, state.operands))
+    return failure();
+  state.addTypes(resultType);
+  return success();
+}
+
+static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
+                  OpAsmPrinter &printer) {
+  // NOTE(review): unlike the custom parser, this omits the storage class, so
+  // the printed form cannot be parsed back as-is.
+  printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " "
+          << blockReadOp.ptr() << " : " << blockReadOp.getType();
+}
+
+/// Verifies spv.SubgroupBlockReadINTEL: the result type (or, for a vector
+/// result, its element type) must equal the pointee type of `ptr`. On
+/// mismatch, verifyBlockReadWritePtrAndValTypes emits the diagnostic.
+static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
+  return verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
+                                            blockReadOp.value());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockWriteINTEL
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
+                                                  OperationState &state) {
+  // Expected form: "<storage-class>" %ptr, %value : <value-type>
+  // The operand list must hold exactly two entries: %ptr then %value.
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
+  auto loc = parser.getCurrentLocation();
+  Type valueType;
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
+      parser.parseType(valueType))
+    return failure();
+
+  // The pointee is the value type itself for a scalar, or its element type
+  // for a vector.
+  Type pointeeType = valueType;
+  if (auto vecTy = valueType.dyn_cast<VectorType>())
+    pointeeType = vecTy.getElementType();
+  auto ptrType = spirv::PointerType::get(pointeeType, storageClass);
+
+  return parser.resolveOperands(operandInfo, {ptrType, valueType}, loc,
+                                state.operands);
+}
+
+static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
+                  OpAsmPrinter &printer) {
+  // NOTE(review): storage class is not printed, so this cannot be re-parsed.
+  Value val = blockWriteOp.value();
+  printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " "
+          << blockWriteOp.ptr() << ", " << val << " : " << val.getType();
+}
+
+/// Verifies spv.SubgroupBlockWriteINTEL: `value` (or, for a vector, its
+/// element type) must equal the pointee type of `ptr`. On mismatch,
+/// verifyBlockReadWritePtrAndValTypes emits the diagnostic.
+static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
+  return verifyBlockReadWritePtrAndValTypes(blockWriteOp, blockWriteOp.ptr(),
+                                            blockWriteOp.value());
+}
+
 //===----------------------------------------------------------------------===//
 // spv.GroupNonUniformElectOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
index 9e1e85191874..b3aaf63856a5 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
@@ -19,4 +19,28 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
     spv.ReturnValue %0: f32
   }
+  // CHECK-LABEL: @subgroup_block_read_intel
+  spv.func @subgroup_block_read_intel(%buf : !spv.ptr<i32, StorageBuffer>) -> i32 "None" {
+    // CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : i32
+    %data = spv.SubgroupBlockReadINTEL "StorageBuffer" %buf : i32
+    spv.ReturnValue %data : i32
+  }
+  // CHECK-LABEL: @subgroup_block_read_intel_vector
+  spv.func @subgroup_block_read_intel_vector(%buf : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
+    // CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : vector<3xi32>
+    %data = spv.SubgroupBlockReadINTEL "StorageBuffer" %buf : vector<3xi32>
+    spv.ReturnValue %data : vector<3xi32>
+  }
+  // CHECK-LABEL: @subgroup_block_write_intel
+  spv.func @subgroup_block_write_intel(%buf : !spv.ptr<i32, StorageBuffer>, %data : i32) -> () "None" {
+    // CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : i32
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %buf, %data : i32
+    spv.Return
+  }
+  // CHECK-LABEL: @subgroup_block_write_intel_vector
+  spv.func @subgroup_block_write_intel_vector(%buf : !spv.ptr<i32, StorageBuffer>, %data : vector<3xi32>) -> () "None" {
+    // CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : vector<3xi32>
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %buf, %data : vector<3xi32>
+    spv.Return
+  }
 }

diff  --git a/mlir/test/Dialect/SPIRV/group-ops.mlir b/mlir/test/Dialect/SPIRV/group-ops.mlir
index 93e9054050ec..55a07270a348 100644
--- a/mlir/test/Dialect/SPIRV/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/group-ops.mlir
@@ -61,3 +61,43 @@ func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> )
   %0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
   return %0: f32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL
+//===----------------------------------------------------------------------===//
+
+func @subgroup_block_read_intel(%buf : !spv.ptr<i32, StorageBuffer>) -> i32 {
+  // CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : i32
+  %data = spv.SubgroupBlockReadINTEL "StorageBuffer" %buf : i32
+  return %data : i32
+}
+
+// -----
+
+func @subgroup_block_read_intel_vector(%buf : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> {
+  // CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : vector<3xi32>
+  %data = spv.SubgroupBlockReadINTEL "StorageBuffer" %buf : vector<3xi32>
+  return %data : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockWriteINTEL
+//===----------------------------------------------------------------------===//
+
+func @subgroup_block_write_intel(%buf : !spv.ptr<i32, StorageBuffer>, %data : i32) -> () {
+  // CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : i32
+  spv.SubgroupBlockWriteINTEL "StorageBuffer" %buf, %data : i32
+  return
+}
+
+// -----
+
+func @subgroup_block_write_intel_vector(%buf : !spv.ptr<i32, StorageBuffer>, %data : vector<3xi32>) -> () {
+  // CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : vector<3xi32>
+  spv.SubgroupBlockWriteINTEL "StorageBuffer" %buf, %data : vector<3xi32>
+  return
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list