[Mlir-commits] [mlir] b845add - [mlir][gpu] Add `subgroup_reduce` operation

Ivan Butygin llvmlistbot at llvm.org
Tue Oct 11 02:49:12 PDT 2022


Author: Ivan Butygin
Date: 2022-10-11T11:47:15+02:00
New Revision: b845addae89b6940c1af3c453aab914c6d170d20

URL: https://github.com/llvm/llvm-project/commit/b845addae89b6940c1af3c453aab914c6d170d20
DIFF: https://github.com/llvm/llvm-project/commit/b845addae89b6940c1af3c453aab914c6d170d20.diff

LOG: [mlir][gpu] Add `subgroup_reduce` operation

Introduce the `subgroup_reduce` operation, which is similar to `all_reduce` but operates at subgroup scope instead of workgroup scope.
It is intended as a low-level building block for higher-level abstractions (e.g., workgroup-wide `all_reduce` ops).
For simplicity's sake, only the variant that takes a reduce operation enum is introduced.

Differential Revision: https://reviews.llvm.org/D135323

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f3d10464cbc45..f1d894a59455c 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -717,6 +717,30 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
   let hasRegionVerifier = 1;
 }
 
+def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
+    [SameOperandsAndResultType]>,
+    Arguments<(ins AnyType:$value,
+               GPU_AllReduceOperationAttr:$op)>,
+    Results<(outs AnyType)> {
+  let summary = "Reduce values among subgroup.";
+  let description = [{
+    The `subgroup_reduce` op reduces the value of every work item across a
+    subgroup. The result is equal for all work items of a subgroup.
+
+    Example:
+
+    ```mlir
+    %1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
+    ```
+
+    Either none or all work items of a subgroup need to execute this op
+    in convergence.
+  }];
+  let assemblyFormat = [{ custom<AllReduceOperation>($op) $value attr-dict
+                          `:` functional-type(operands, results) }];
+  let hasVerifier = 1;
+}
+
 def GPU_ShuffleOpXor  : I32EnumAttrCase<"XOR",  0, "xor">;
 def GPU_ShuffleOpDown : I32EnumAttrCase<"DOWN", 1, "down">;
 def GPU_ShuffleOpUp   : I32EnumAttrCase<"UP",   2, "up">;

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 83ee9bbfa384b..bfdcedfa8d771 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -309,6 +309,17 @@ static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
 // AllReduceOp
 //===----------------------------------------------------------------------===//
 
+static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
+                                  Type resType) {
+  if ((opName == gpu::AllReduceOperation::AND ||
+       opName == gpu::AllReduceOperation::OR ||
+       opName == gpu::AllReduceOperation::XOR) &&
+      !resType.isa<IntegerType>())
+    return false;
+
+  return true;
+}
+
 LogicalResult gpu::AllReduceOp::verifyRegions() {
   if (getBody().empty() != getOp().has_value())
     return emitError("expected either an op attribute or a non-empty body");
@@ -333,10 +344,7 @@ LogicalResult gpu::AllReduceOp::verifyRegions() {
       return emitError("expected gpu.yield op in region");
   } else {
     gpu::AllReduceOperation opName = *getOp();
-    if ((opName == gpu::AllReduceOperation::AND ||
-         opName == gpu::AllReduceOperation::OR ||
-         opName == gpu::AllReduceOperation::XOR) &&
-        !getType().isa<IntegerType>()) {
+    if (!verifyReduceOpAndType(opName, getType())) {
       return emitError()
              << '`' << gpu::stringifyAllReduceOperation(opName)
              << "` accumulator is only compatible with Integer type";
@@ -364,6 +372,19 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
     attr.print(printer);
 }
 
+//===----------------------------------------------------------------------===//
+// SubgroupReduceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::SubgroupReduceOp::verify() {
+  gpu::AllReduceOperation opName = getOp();
+  if (!verifyReduceOpAndType(opName, getType())) {
+    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
+                       << "` accumulator is only compatible with Integer type";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AsyncOpInterface
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index f3c81233f7b54..b029d2fa7c9a4 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -245,6 +245,14 @@ func.func @reduce_invalid_op_type(%arg0 : f32) {
 
 // -----
 
+func.func @subgroup_reduce_invalid_op_type(%arg0 : f32) {
+  // expected-error @+1 {{`and` accumulator is only compatible with Integer type}}
+  %res = gpu.subgroup_reduce and %arg0 : (f32) -> (f32)
+  return
+}
+
+// -----
+
 func.func @reduce_incorrect_region_arguments(%arg0 : f32) {
   // expected-error @+1 {{expected two region arguments}}
   %res = gpu.all_reduce %arg0 {

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 52320744d0784..9b31a326aa919 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -85,6 +85,9 @@ module attributes {gpu.container_module} {
       %one = arith.constant 1.0 : f32
       %sum = gpu.all_reduce add %one {} : (f32) -> (f32)
 
+      // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (f32) -> f32
+      %sum_subgroup = gpu.subgroup_reduce add %one : (f32) -> f32
+
       %width = arith.constant 7 : i32
       %offset = arith.constant 3 : i32
       // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32


        


More information about the Mlir-commits mailing list