[Mlir-commits] [mlir] [mlir][gpu] Allow subgroup reductions over 1-d vector types (PR #76015)

Jakub Kuderski llvmlistbot at llvm.org
Wed Dec 20 12:39:56 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/76015

>From 06485807404c6c398fb4e1856cd1284c5a2c54ce Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 20 Dec 2023 01:08:56 -0500
Subject: [PATCH 1/3] [mlir][gpu] Allow subgroup reductions over 1-d vector
 types

Each vector element is reduced independently, which is a form of
multi-reduction.

The plan is to allow for gradual lowering of multi-reductions that
results in fewer `gpu.shuffle` ops at the end:
1d `vector.multi_reduction` --> 1d `gpu.subgroup_reduce` --> smaller
1d `gpu.subgroup_reduce` --> packed `gpu.shuffle` over i32 or i64

For example we can perform 4 independent f16 reductions with a series of
`gpu.shuffles` over i64, reducing the final number of `gpu.shuffles` by
4x.
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 16 ++++++--
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 11 ++++--
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 11 +++++-
 .../Conversion/GPUToSPIRV/reductions.mlir     | 38 +++++++++++++++++++
 mlir/test/Dialect/GPU/invalid.mlir            | 14 +++++--
 mlir/test/Dialect/GPU/ops.mlir                |  5 +++
 6 files changed, 83 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 2e21cd77d2d83b..7777dd58eba1f7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -19,10 +19,11 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/GPU/IR/CompilationAttrs.td"
 include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td"
 include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td"
+include "mlir/IR/CommonTypeConstraints.td"
 include "mlir/IR/EnumAttr.td"
-include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -1022,16 +1023,23 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
   let hasRegionVerifier = 1;
 }
 
+def AnyIntegerOrFloatOr1DVector :
+  AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
+
 def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
   let summary = "Reduce values among subgroup.";
   let description = [{
     The `subgroup_reduce` op reduces the value of every work item across a
     subgroup. The result is equal for all work items of a subgroup.
 
+    When the reduced value is of a vector type, each vector element is reduced
+    independently. Only 1-d vector types are allowed.
+
     Example:
 
     ```mlir
-    %1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
+    %1 = gpu.subgroup_reduce add %a : (f32) -> (f32)
+    %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> (vector<4xf16>)
     ```
 
     If `uniform` flag is set either none or all work items of a subgroup
@@ -1044,11 +1052,11 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
   }];
 
   let arguments = (ins
-    AnyIntegerOrFloat:$value,
+    AnyIntegerOrFloatOr1DVector:$value,
     GPU_AllReduceOperationAttr:$op,
     UnitAttr:$uniform
   );
-  let results = (outs AnyIntegerOrFloat:$result);
+  let results = (outs AnyIntegerOrFloatOr1DVector:$result);
 
   let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
                           (`uniform` $uniform^)? attr-dict
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index d383c16949f0ef..f1378fc998dadf 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <optional>
 
@@ -591,10 +592,12 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto opType = op.getOp();
-    auto result =
-        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), opType,
-                            /*isGroup*/ false, op.getUniform());
+    if (isa<VectorType>(adaptor.getValue().getType()))
+      return failure();
+
+    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
+                                      adaptor.getOp(),
+                                      /*isGroup*/ false, op.getUniform());
     if (!result)
       return failure();
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7c3330f4c238f8..dd482f305fcbc8 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
@@ -588,8 +589,16 @@ static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
 //===----------------------------------------------------------------------===//
 
 LogicalResult gpu::SubgroupReduceOp::verify() {
+  Type elemType = getType();
+  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
+    if (vecTy.isScalable())
+      return emitOpError() << "is not compatible with scalable vector types";
+
+    elemType = vecTy.getElementType();
+  }
+
   gpu::AllReduceOperation opName = getOp();
-  if (failed(verifyReduceOpAndType(opName, getType()))) {
+  if (failed(verifyReduceOpAndType(opName, elemType))) {
     return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                        << "` reduction operation is not compatible with type "
                        << getType();
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index af58f4173136f8..44f85f68587f1a 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -655,6 +655,26 @@ gpu.module @kernels {
 
 // -----
 
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @test
+  //  CHECK-SAME: (%[[ARG:.*]]: i32)
+  gpu.func @test(%arg : vector<1xi32>) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformSMax "Subgroup" "Reduce" %[[ARG]] : i32
+    %r0 = gpu.subgroup_reduce maxsi %arg : (vector<1xi32>) -> (vector<1xi32>)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
 // TODO: Handle boolean reductions.
 
 module attributes {
@@ -751,3 +771,21 @@ gpu.module @kernels {
   }
 }
 }
+
+// -----
+
+// Vector reductions need to be lowered to scalar reductions first.
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+gpu.module @kernels {
+  gpu.func @maxui(%arg : vector<2xi32>) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %r0 = gpu.subgroup_reduce maxui %arg : (vector<2xi32>) -> (vector<2xi32>)
+    gpu.return
+  }
+}
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index d8a40f89f80ac2..8a34d64326072b 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -333,9 +333,17 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
 
 // -----
 
-func.func @subgroup_reduce_bad_type(%arg0 : vector<2xf32>) {
-  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float}}
-  %res = gpu.subgroup_reduce add %arg0 : (vector<2xf32>) -> vector<2xf32>
+func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
+  // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
+  %res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
+  return
+}
+
+// -----
+
+func.func @subgroup_reduce_bad_type_scalable(%arg0 : vector<[2]xf32>) {
+  // expected-error@+1 {{is not compatible with scalable vector types}}
+  %res = gpu.subgroup_reduce add %arg0 : (vector<[2]xf32>) -> vector<[2]xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 48193436415637..c3b548062638f1 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -84,6 +84,8 @@ module attributes {gpu.container_module} {
 
       %one = arith.constant 1.0 : f32
 
+      %vec = vector.broadcast %arg0 : f32 to vector<4xf32>
+
       // CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} {
       // CHECK-NEXT: } : (f32) -> f32
       %sum = gpu.all_reduce add %one {} : (f32) -> (f32)
@@ -98,6 +100,9 @@ module attributes {gpu.container_module} {
       // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} uniform : (f32) -> f32
       %sum_subgroup1 = gpu.subgroup_reduce add %one uniform : (f32) -> f32
 
+      // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (vector<4xf32>) -> vector<4xf32
+      %sum_subgroup2 = gpu.subgroup_reduce add %vec : (vector<4xf32>) -> vector<4xf32>
+
       %width = arith.constant 7 : i32
       %offset = arith.constant 3 : i32
       // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32

>From 05cfdb45efa59292d10d0c031d7d349d34fbaf02 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 20 Dec 2023 11:00:59 -0500
Subject: [PATCH 2/3] Improve spirv pattern

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index f1378fc998dadf..a1d534042debdd 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -592,12 +593,12 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (isa<VectorType>(adaptor.getValue().getType()))
-      return failure();
+    if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
+      return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
 
     auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
                                       adaptor.getOp(),
-                                      /*isGroup*/ false, op.getUniform());
+                                      /*isGroup*/ false, adaptor.getUniform());
     if (!result)
       return failure();
 

>From b49bd81974523172e85f1954f9344cda850fdace Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 20 Dec 2023 15:39:44 -0500
Subject: [PATCH 3/3] Fix typo

---
 mlir/test/Dialect/GPU/ops.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c3b548062638f1..60512424383052 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -100,7 +100,7 @@ module attributes {gpu.container_module} {
       // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} uniform : (f32) -> f32
       %sum_subgroup1 = gpu.subgroup_reduce add %one uniform : (f32) -> f32
 
-      // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (vector<4xf32>) -> vector<4xf32
+      // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (vector<4xf32>) -> vector<4xf32>
       %sum_subgroup2 = gpu.subgroup_reduce add %vec : (vector<4xf32>) -> vector<4xf32>
 
       %width = arith.constant 7 : i32



More information about the Mlir-commits mailing list