[Mlir-commits] [mlir] a8fe40d - [mlir][spirv] Add OpGroupBroadcast

Thomas Raoux llvmlistbot at llvm.org
Mon Aug 10 09:50:26 PDT 2020


Author: Artur Bialas
Date: 2020-08-10T09:50:03-07:00
New Revision: a8fe40d9732721ed9a083cb917650f8f12b787b3

URL: https://github.com/llvm/llvm-project/commit/a8fe40d9732721ed9a083cb917650f8f12b787b3
DIFF: https://github.com/llvm/llvm-project/commit/a8fe40d9732721ed9a083cb917650f8f12b787b3.diff

LOG: [mlir][spirv] Add OpGroupBroadcast

OpGroupBroadcast added to SPIRV dialect

Differential Revision: https://reviews.llvm.org/D85435

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
    mlir/test/Dialect/SPIRV/group-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index cbff82efdfd3..ab0b76161342 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3231,6 +3231,8 @@ def SPV_OC_OpBranchConditional         : I32EnumAttrCase<"OpBranchConditional",
 def SPV_OC_OpReturn                    : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue               : I32EnumAttrCase<"OpReturnValue", 254>;
 def SPV_OC_OpUnreachable               : I32EnumAttrCase<"OpUnreachable", 255>;
+// Opcode 263 (spv.GroupBroadcast); must also be listed in SPV_OpcodeAttr.
+def SPV_OC_OpGroupBroadcast            : I32EnumAttrCase<"OpGroupBroadcast", 263>;
 def SPV_OC_OpNoLine                    : I32EnumAttrCase<"OpNoLine", 317>;
 def SPV_OC_OpModuleProcessed           : I32EnumAttrCase<"OpModuleProcessed", 330>;
 def SPV_OC_OpGroupNonUniformElect      : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
@@ -3297,8 +3298,8 @@ def SPV_OpcodeAttr :
       SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
       SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
       SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
-      SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
+      SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
       SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
       SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
       SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
index 832e03cba75c..c9ce8be9927f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -17,6 +17,82 @@
 
 // -----
 
+def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
+  [NoSideEffect, AllTypesMatch<["value", "result"]>]> {
+  let summary = [{
+    Return the Value of the invocation identified by the local id LocalId to
+    all invocations in the group.
+  }];
+
+  let description = [{
+    All invocations of this module within Execution must reach this point of
+    execution.
+
+    Behavior is undefined if this instruction is used in control flow that
+    is non-uniform within Execution.
+
+    Result Type must be a scalar or vector of floating-point type, integer
+    type, or Boolean type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The type of Value must be the same as Result Type.
+
+    LocalId must be an integer datatype. It can be a scalar, or a vector
+    with 2 components or a vector with 3 components. LocalId must be the
+    same for all invocations in the group.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    integer-float-scalar-vector-type ::= integer-type | float-type |
+                               `vector<` integer-literal `x` integer-type `>` |
+                               `vector<` integer-literal `x` float-type `>`
+    localid-type ::= integer-type |
+                   `vector<` integer-literal `x` integer-type `>`
+    group-broadcast-op ::= ssa-id `=` `spv.GroupBroadcast` scope ssa-use `,`
+                   ssa-use `:` integer-float-scalar-vector-type `,` localid-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %scalar_value = ... : f32
+    %vector_value = ... : vector<4xf32>
+    %scalar_localid = ... : i32
+    %vector_localid = ... : vector<3xi32>
+    %0 = spv.GroupBroadcast "Subgroup" %scalar_value, %scalar_localid : f32, i32
+    %1 = spv.GroupBroadcast "Workgroup" %vector_value, %vector_localid :
+      vector<4xf32>, vector<3xi32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Groups]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope,
+    SPV_Type:$value,
+    SPV_ScalarOrVectorOf<SPV_Integer>:$localid
+  );
+
+  let results = (outs
+    SPV_Type:$result
+  );
+
+  let assemblyFormat = [{
+    $execution_scope operands attr-dict `:` type($value) `,` type($localid)
+  }];
+
+}
+
+// -----
+
 def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
   let summary = "See extension SPV_KHR_shader_ballot";
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 06b06ddcc913..88ca71ac18ac 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1993,6 +1993,39 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.GroupBroadcast
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
+  // Only the Workgroup and Subgroup execution scopes are accepted.
+  switch (broadcastOp.execution_scope()) {
+  case spirv::Scope::Workgroup:
+  case spirv::Scope::Subgroup:
+    break;
+  default:
+    return broadcastOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+  }
+
+  // A scalar LocalId needs no further checks; a vector LocalId must have
+  // exactly 2 or 3 components.
+  auto localIdVecTy = broadcastOp.localid().getType().dyn_cast<VectorType>();
+  if (!localIdVecTy)
+    return success();
+
+  int64_t numComponents = localIdVecTy.getNumElements();
+  if (numComponents == 2 || numComponents == 3)
+    return success();
+
+  // NOTE(review): the message text, including the double space before "2",
+  // is matched verbatim by the negative test in group-ops.mlir; change both
+  // together.
+  return broadcastOp.emitOpError("localid is a vector and can be with only "
+                                 " 2 or 3 components, actual number is ")
+         << numComponents;
+}
+
 //===----------------------------------------------------------------------===//
 // spv.GroupNonUniformBallotOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
index 474e40b97acc..9e1e85191874 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
@@ -7,4 +7,16 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
     spv.ReturnValue %0: vector<4xi32>
   }
+  // CHECK-LABEL: @group_broadcast_1
+  spv.func @group_broadcast_1(%val: f32, %lid: i32) -> f32 "None" {
+    // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
+    %res = spv.GroupBroadcast "Workgroup" %val, %lid : f32, i32
+    spv.ReturnValue %res : f32
+  }
+  // CHECK-LABEL: @group_broadcast_2
+  spv.func @group_broadcast_2(%val: f32, %lid: vector<3xi32>) -> f32 "None" {
+    // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
+    %res = spv.GroupBroadcast "Workgroup" %val, %lid : f32, vector<3xi32>
+    spv.ReturnValue %res : f32
+  }
 }

diff  --git a/mlir/test/Dialect/SPIRV/group-ops.mlir b/mlir/test/Dialect/SPIRV/group-ops.mlir
index ba5e79209e31..93e9054050ec 100644
--- a/mlir/test/Dialect/SPIRV/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/group-ops.mlir
@@ -9,3 +9,55 @@ func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
   %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
   return %0: vector<4xi32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GroupBroadcast
+//===----------------------------------------------------------------------===//
+
+func @group_broadcast_scalar(%val: f32, %lid: i32) -> f32 {
+  // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
+  %res = spv.GroupBroadcast "Workgroup" %val, %lid : f32, i32
+  return %res : f32
+}
+
+// -----
+
+func @group_broadcast_scalar_vector(%val: f32, %lid: vector<3xi32>) -> f32 {
+  // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
+  %res = spv.GroupBroadcast "Workgroup" %val, %lid : f32, vector<3xi32>
+  return %res : f32
+}
+
+// -----
+
+func @group_broadcast_vector(%val: vector<4xf32>, %lid: vector<3xi32>) -> vector<4xf32> {
+  // CHECK: spv.GroupBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
+  %res = spv.GroupBroadcast "Subgroup" %val, %lid : vector<4xf32>, vector<3xi32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+func @group_broadcast_negative_scope(%val: f32, %lid: vector<3xi32>) -> f32 {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %res = spv.GroupBroadcast "Device" %val, %lid : f32, vector<3xi32>
+  return %res : f32
+}
+
+// -----
+
+func @group_broadcast_negative_locid_dtype(%val: f32, %lid: vector<3xf32>) -> f32 {
+  // expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+  %res = spv.GroupBroadcast "Subgroup" %val, %lid : f32, vector<3xf32>
+  return %res : f32
+}
+
+// -----
+
+func @group_broadcast_negative_locid_vec4(%val: f32, %lid: vector<4xi32>) -> f32 {
+  // expected-error @+1 {{localid is a vector and can be with only  2 or 3 components, actual number is 4}}
+  %res = spv.GroupBroadcast "Subgroup" %val, %lid : f32, vector<4xi32>
+  return %res : f32
+}


        


More information about the Mlir-commits mailing list