[Mlir-commits] [mlir] a8fe40d - [mlir][spirv] Add OpGroupBroadcast
Thomas Raoux
llvmlistbot at llvm.org
Mon Aug 10 09:50:26 PDT 2020
Author: Artur Bialas
Date: 2020-08-10T09:50:03-07:00
New Revision: a8fe40d9732721ed9a083cb917650f8f12b787b3
URL: https://github.com/llvm/llvm-project/commit/a8fe40d9732721ed9a083cb917650f8f12b787b3
DIFF: https://github.com/llvm/llvm-project/commit/a8fe40d9732721ed9a083cb917650f8f12b787b3.diff
LOG: [mlir][spirv] Add OpGroupBroadcast
Add OpGroupBroadcast to the SPIR-V dialect.
Differential Revision: https://reviews.llvm.org/D85435
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
mlir/test/Dialect/SPIRV/group-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index cbff82efdfd3..ab0b76161342 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3231,6 +3231,7 @@ def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional",
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
+def SPV_OC_OpGroupBroadcast : I32EnumAttrCase<"OpGroupBroadcast", 263>;
def SPV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
@@ -3297,8 +3298,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
- SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
- SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+ SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
+ SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
index 832e03cba75c..c9ce8be9927f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -17,6 +17,82 @@
// -----
+def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
+ [NoSideEffect, AllTypesMatch<["value", "result"]>]> {
+ let summary = [{
+ Return the Value of the invocation identified by the local id LocalId to
+ all invocations in the group.
+ }];
+
+ let description = [{
+ All invocations of this module within Execution must reach this point of
+ execution.
+
+ Behavior is undefined if this instruction is used in control flow that
+ is non-uniform within Execution.
+
+ Result Type must be a scalar or vector of floating-point type, integer
+ type, or Boolean type.
+
+ Execution must be Workgroup or Subgroup Scope.
+
+ The type of Value must be the same as Result Type.
+
+ LocalId must be an integer datatype. It can be a scalar, or a vector
+ with 2 components or a vector with 3 components. LocalId must be the
+ same for all invocations in the group.
+
+ <!-- End of AutoGen section -->
+
+ ```
+ scope ::= `"Workgroup"` | `"Subgroup"`
+ integer-float-scalar-vector-type ::= integer-type | float-type |
+ `vector<` integer-literal `x` integer-type `>` |
+ `vector<` integer-literal `x` float-type `>`
+ localid-type ::= integer-type |
+ `vector<` integer-literal `x` integer-type `>`
+ group-broadcast-op ::= ssa-id `=` `spv.GroupBroadcast` scope ssa-use `,`
+ ssa-use `:` integer-float-scalar-vector-type `,` localid-type
+ ```
+
+ #### Example:
+
+ ```mlir
+ %scalar_value = ... : f32
+ %vector_value = ... : vector<4xf32>
+ %scalar_localid = ... : i32
+ %vector_localid = ... : vector<3xi32>
+ %0 = spv.GroupBroadcast "Subgroup" %scalar_value, %scalar_localid : f32, i32
+ %1 = spv.GroupBroadcast "Workgroup" %vector_value, %vector_localid :
+ vector<4xf32>, vector<3xi32>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_5>,
+ Extension<[]>,
+ Capability<[SPV_C_Groups]>
+ ];
+
+ let arguments = (ins
+ SPV_ScopeAttr:$execution_scope,
+ SPV_Type:$value,
+ SPV_ScalarOrVectorOf<SPV_Integer>:$localid
+ );
+
+ let results = (outs
+ SPV_Type:$result
+ );
+
+ let assemblyFormat = [{
+ $execution_scope operands attr-dict `:` type($value) `,` type($localid)
+ }];
+
+}
+
+// -----
+
def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
let summary = "See extension SPV_KHR_shader_ballot";
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 06b06ddcc913..88ca71ac18ac 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1993,6 +1993,33 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
return success();
}
+//===----------------------------------------------------------------------===//
+// spv.GroupBroadcast
+//===----------------------------------------------------------------------===//
+
+/// Verifies the spv.GroupBroadcast constraints that its ODS definition does
+/// not express: the execution scope must be Workgroup or Subgroup, and a
+/// vector `localid` must have exactly 2 or 3 components (a scalar `localid`
+/// is accepted as is).
+static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
+ spirv::Scope scope = broadcastOp.execution_scope();
+ if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+ return broadcastOp.emitOpError(
+ "execution scope must be 'Workgroup' or 'Subgroup'");
+
+ if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>()) {
+ int64_t numElements = localIdTy.getNumElements();
+ // The two literal pieces below must join with exactly one space: the
+ // message is matched by `expected-error` in the dialect tests.
+ if (numElements != 2 && numElements != 3)
+ return broadcastOp.emitOpError("localid is a vector and can be with only "
+ "2 or 3 components, actual number is ")
+ << numElements;
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformBallotOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
index 474e40b97acc..9e1e85191874 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
@@ -7,4 +7,23 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
+ // CHECK-LABEL: @group_broadcast_1
+ spv.func @group_broadcast_1(%value: f32, %localid: i32) -> f32 "None" {
+ // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
+ %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, i32
+ spv.ReturnValue %0: f32
+ }
+ // CHECK-LABEL: @group_broadcast_2
+ spv.func @group_broadcast_2(%value: f32, %localid: vector<3xi32>) -> f32 "None" {
+ // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
+ %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
+ spv.ReturnValue %0: f32
+ }
+ // Vector Value with a 2-component LocalId, at Subgroup scope.
+ // CHECK-LABEL: @group_broadcast_3
+ spv.func @group_broadcast_3(%value: vector<4xf32>, %localid: vector<2xi32>) -> vector<4xf32> "None" {
+ // CHECK: spv.GroupBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, vector<2xi32>
+ %0 = spv.GroupBroadcast "Subgroup" %value, %localid : vector<4xf32>, vector<2xi32>
+ spv.ReturnValue %0: vector<4xf32>
+ }
}
diff --git a/mlir/test/Dialect/SPIRV/group-ops.mlir b/mlir/test/Dialect/SPIRV/group-ops.mlir
index ba5e79209e31..93e9054050ec 100644
--- a/mlir/test/Dialect/SPIRV/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/group-ops.mlir
@@ -9,3 +9,64 @@ func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
return %0: vector<4xi32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GroupBroadcast
+//===----------------------------------------------------------------------===//
+
+func @group_broadcast_scalar(%value: f32, %localid: i32) -> f32 {
+ // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
+ %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, i32
+ return %0: f32
+}
+
+// -----
+
+func @group_broadcast_scalar_vector(%value: f32, %localid: vector<3xi32>) -> f32 {
+ // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
+ %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
+ return %0: f32
+}
+
+// -----
+
+func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32>) -> vector<4xf32> {
+ // CHECK: spv.GroupBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
+ %0 = spv.GroupBroadcast "Subgroup" %value, %localid : vector<4xf32>, vector<3xi32>
+ return %0: vector<4xf32>
+}
+
+// -----
+
+// A 2-component vector LocalId is also accepted, as is an integer Value.
+func @group_broadcast_vector_localid_vec2(%value: vector<4xi32>, %localid: vector<2xi32>) -> vector<4xi32> {
+ // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : vector<4xi32>, vector<2xi32>
+ %0 = spv.GroupBroadcast "Workgroup" %value, %localid : vector<4xi32>, vector<2xi32>
+ return %0: vector<4xi32>
+}
+
+// -----
+
+func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32>) -> f32 {
+ // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ %0 = spv.GroupBroadcast "Device" %value, %localid : f32, vector<3xi32>
+ return %0: f32
+}
+
+// -----
+
+func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32>) -> f32 {
+ // expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+ %0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<3xf32>
+ return %0: f32
+}
+
+// -----
+
+func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32>) -> f32 {
+ // expected-error @+1 {{localid is a vector and can be with only 2 or 3 components, actual number is 4}}
+ %0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
+ return %0: f32
+}
More information about the Mlir-commits
mailing list