[Mlir-commits] [mlir] 4ce84b0 - [mlir][spirv] Add GroupNonUniformBroadcastOp

Thomas Raoux llvmlistbot at llvm.org
Wed Sep 16 23:14:37 PDT 2020


Author: Artur Bialas
Date: 2020-09-16T23:13:06-07:00
New Revision: 4ce84b0e704ee7b8b13e236e65b3bf49da27a91c

URL: https://github.com/llvm/llvm-project/commit/4ce84b0e704ee7b8b13e236e65b3bf49da27a91c
DIFF: https://github.com/llvm/llvm-project/commit/4ce84b0e704ee7b8b13e236e65b3bf49da27a91c.diff

LOG: [mlir][spirv] Add GroupNonUniformBroadcastOp

Added GroupNonUniformBroadcastOp to spirv dialect.

Differential Revision: https://reviews.llvm.org/D87688

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
    mlir/test/Dialect/SPIRV/non-uniform-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 1fa72bf4dcab..83150dad514d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3256,6 +3256,7 @@ def SPV_OC_OpGroupBroadcast            : I32EnumAttrCase<"OpGroupBroadcast", 263
 def SPV_OC_OpNoLine                    : I32EnumAttrCase<"OpNoLine", 317>;
 def SPV_OC_OpModuleProcessed           : I32EnumAttrCase<"OpModuleProcessed", 330>;
 def SPV_OC_OpGroupNonUniformElect      : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
+def SPV_OC_OpGroupNonUniformBroadcast  : I32EnumAttrCase<"OpGroupNonUniformBroadcast", 337>;
 def SPV_OC_OpGroupNonUniformBallot     : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
 def SPV_OC_OpGroupNonUniformIAdd       : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>;
 def SPV_OC_OpGroupNonUniformFAdd       : I32EnumAttrCase<"OpGroupNonUniformFAdd", 350>;
@@ -3323,16 +3324,16 @@ def SPV_OpcodeAttr :
       SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
       SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
       SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
-      SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
-      SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
-      SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
-      SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin,
-      SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax,
-      SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
-      SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
-      SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
-      SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL,
-      SPV_OC_OpSubgroupBlockWriteINTEL
+      SPV_OC_OpGroupNonUniformBroadcast, SPV_OC_OpGroupNonUniformBallot,
+      SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
+      SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
+      SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin,
+      SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax,
+      SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax,
+      SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV,
+      SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV,
+      SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV,
+      SPV_OC_OpSubgroupBlockReadINTEL, SPV_OC_OpSubgroupBlockWriteINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
index 34be336bb2a5..da3da3050efc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -105,6 +105,77 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
 
 // -----
 
+def SPV_GroupNonUniformBroadcastOp : SPV_Op<"GroupNonUniformBroadcast",
+  [NoSideEffect, AllTypesMatch<["value", "result"]>]> {
+  let summary = [{
+    Return the Value of the invocation identified by the id Id to all active
+    invocations in the group.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type, integer
+    type, or Boolean type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The type of Value must be the same as Result Type.
+
+    Id must be a scalar of integer type, whose Signedness operand is 0.
+
+    Before version 1.5, Id must come from a constant instruction. Starting
+    with version 1.5, Id must be dynamically uniform.
+
+    The resulting value is undefined if Id is an inactive invocation, or is
+    greater than or equal to the size of the group.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    integer-float-scalar-vector-type ::= integer-type | float-type |
+                               `vector<` integer-literal `x` integer-type `>` |
+                               `vector<` integer-literal `x` float-type `>`
+    group-non-uniform-broadcast-op ::= ssa-id `=`
+                `spv.GroupNonUniformBroadcast` scope ssa-use `,`
+                ssa-use `:` integer-float-scalar-vector-type `,` integer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %scalar_value = ... : f32
+    %vector_value = ... : vector<4xf32>
+    %id = ... : i32
+    %0 = spv.GroupNonUniformBroadcast "Subgroup" %scalar_value, %id : f32, i32
+    %1 = spv.GroupNonUniformBroadcast "Workgroup" %vector_value, %id :
+      vector<4xf32>, i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_3>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_GroupNonUniformBallot]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope,
+    SPV_Type:$value,
+    SPV_Integer:$id
+  );
+
+  let results = (outs
+    SPV_Type:$result
+  );
+
+  let assemblyFormat = [{
+    $execution_scope operands attr-dict `:` type($value) `,` type($id)
+  }];
+}
+
+// -----
+
 def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> {
   let summary = [{
     Result is true only in the active invocation with the lowest id in the
@@ -368,8 +439,8 @@ def SPV_GroupNonUniformFMulOp :
 def SPV_GroupNonUniformIAddOp :
     SPV_GroupNonUniformArithmeticOp<"GroupNonUniformIAdd", SPV_Integer, []> {
   let summary = [{
-    An integer add group operation of all Value operands contributed active
-    by invocations in the group.
+    An integer add group operation of all Value operands contributed by
+    active invocations in the group.
   }];
 
   let description = [{

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index a16dc1c8bc35..a01177132b27 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/FunctionImplementation.h"
@@ -2043,6 +2044,44 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformBroadcast
+//===----------------------------------------------------------------------===//
+
+/// Verifies a spv.GroupNonUniformBroadcast op:
+///   * the execution scope must be Workgroup or Subgroup;
+///   * `id` must not be an explicitly signed (si*) integer, because the SPIR-V
+///     spec requires its Signedness operand to be 0;
+///   * when targeting a SPIR-V version older than 1.5, `id` must be defined by
+///     a spirv::ConstantOp or a spirv::ReferenceOfOp (spec constant).
+static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) {
+  spirv::Scope scope = broadcastOp.execution_scope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return broadcastOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  // SPIR-V spec: "Id must be a scalar of integer type, whose Signedness
+  // operand is 0." Reject explicitly signed integer types.
+  if (broadcastOp.id().getType().isSignedInteger())
+    return broadcastOp.emitOpError("id must be a signless or unsigned integer");
+
+  // SPIR-V spec: "Before version 1.5, Id must come from a constant
+  // instruction." Use the target environment looked up from the enclosing
+  // spv.module when there is one; otherwise fall back to the default.
+  auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext());
+  if (auto spirvModule = broadcastOp.getParentOfType<spirv::ModuleOp>())
+    targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
+
+  if (targetEnv.getVersion() < spirv::Version::V_1_5) {
+    auto *idOp = broadcastOp.id().getDefiningOp();
+    if (!idOp || !isa<spirv::ConstantOp,           // for normal constant
+                      spirv::ReferenceOfOp>(idOp)) // for spec constant
+      return broadcastOp.emitOpError("id must be the result of a constant op");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.SubgroupBlockReadINTEL
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
index ab714dfbaa00..f7b8f6cfc185 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
@@ -8,6 +8,14 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     spv.ReturnValue %0: vector<4xi32>
   }
 
+  // CHECK-LABEL: @group_non_uniform_broadcast
+  spv.func @group_non_uniform_broadcast(%value: f32) -> f32 "None" {
+    %id = spv.constant 1 : i32
+    // CHECK: spv.GroupNonUniformBroadcast "Subgroup" %{{.*}}, %{{.*}} : f32, i32
+    %result = spv.GroupNonUniformBroadcast "Subgroup" %value, %id : f32, i32
+    spv.ReturnValue %result : f32
+  }
+
   // CHECK-LABEL: @group_non_uniform_elect
   spv.func @group_non_uniform_elect() -> i1 "None" {
     // CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1

diff  --git a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
index 86c3c2886a4f..5839ee7c5627 100644
--- a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
@@ -28,6 +28,45 @@ func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformBroadcast
+//===----------------------------------------------------------------------===//
+
+func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
+  %id = spv.constant 1 : i32
+  // CHECK: spv.GroupNonUniformBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
+  %result = spv.GroupNonUniformBroadcast "Workgroup" %value, %id : f32, i32
+  return %result : f32
+}
+
+// -----
+
+func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4xf32> {
+  %id = spv.constant 1 : i32
+  // CHECK: spv.GroupNonUniformBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, i32
+  %result = spv.GroupNonUniformBroadcast "Subgroup" %value, %id : vector<4xf32>, i32
+  return %result : vector<4xf32>
+}
+
+// -----
+
+func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32) -> f32 {
+  %one = spv.constant 1 : i32
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spv.GroupNonUniformBroadcast "Device" %value, %one : f32, i32
+  return %0: f32
+}
+
+// -----
+
+func @group_non_uniform_broadcast_negative_non_const(%value: f32, %dynamic_id: i32) -> f32 {
+  // expected-error @+1 {{id must be the result of a constant op}}
+  %result = spv.GroupNonUniformBroadcast "Subgroup" %value, %dynamic_id : f32, i32
+  return %result : f32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.GroupNonUniformElect
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list