[Mlir-commits] [mlir] [mlir][gpu] Add 'cluster_size' attribute to gpu.subgroup_reduce (PR #104851)

Andrea Faulds llvmlistbot at llvm.org
Tue Aug 20 06:17:07 PDT 2024


https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/104851

From 24fd0e5f65159fe2341b51dd555d4afa0126abdc Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Mon, 19 Aug 2024 22:32:40 +0200
Subject: [PATCH] [mlir][gpu] Add 'cluster_size' attribute to
 gpu.subgroup_reduce

This enables performing several reductions in parallel, each smaller
than the size of the subgroup.

One potential application is flash attention with subgroup-wide matrix
multiplication and reduction combined in one kernel. The multiplication
operation requires a 2D matrix to be distributed over the lanes of the
subgroup, which then constrains the shape the following reduction can
have if we want to keep data in registers.
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 44 ++++++++--
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  4 +
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp |  4 +
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 11 +++
 .../GPU/Transforms/SubgroupReduceLowering.cpp | 49 ++++++++---
 mlir/test/Dialect/GPU/canonicalize.mlir       | 18 ++++
 mlir/test/Dialect/GPU/invalid.mlir            | 16 ++++
 .../Dialect/GPU/subgroup-reduce-lowering.mlir | 87 +++++++++++++++++++
 8 files changed, 213 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index a024c3018eb8d3..37dfb5343c951e 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1198,7 +1198,7 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
   let summary = "Reduce values among subgroup.";
   let description = [{
     The `subgroup_reduce` op reduces the value of every lane (work item) across
-    a subgroup. The result is equal for all lanes.
+    a subgroup.
 
     When the reduced value is of a vector type, each vector element is reduced
     independently. Only 1-d vector types are allowed.
@@ -1206,13 +1206,23 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
     Example:
 
     ```mlir
-    %1 = gpu.subgroup_reduce add %a : (f32) -> (f32)
-    %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> (vector<4xf16>)
+    %1 = gpu.subgroup_reduce add %a : (f32) -> f32
+    %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> vector<4xf16>
+    %3 = gpu.subgroup_reduce add %c cluster_size(4) : (f32) -> f32
     ```
 
     If `uniform` flag is set either none or all lanes of a subgroup need to execute
-    this op in convergence. The reduction operation must be one
-    of:
+    this op in convergence.
+
+    If a `cluster_size` is not provided, the reduction covers all lanes of the
+    subgroup and the result is equal for all lanes.
+
+    If a `cluster_size` is provided, the subgroup is divided into clusters of
+    `cluster_size` contiguous lanes each, a reduction is done for all lanes of
+    each cluster (in parallel), and the result is equal for all lanes in a
+    cluster.
+
+    The reduction operation must be one of:
     *  Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
        `or`, `xor`
     *  Floating point types: `add`, `mul`, `minnumf`, `maxnumf`, `minimumf`,
@@ -1222,12 +1232,32 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
   let arguments = (ins
     AnyIntegerOrFloatOr1DVector:$value,
     GPU_AllReduceOperationAttr:$op,
-    UnitAttr:$uniform
+    UnitAttr:$uniform,
+    OptionalAttr<I32Attr>:$cluster_size
   );
   let results = (outs AnyIntegerOrFloatOr1DVector:$result);
 
+  let builders = [
+    OpBuilder<(ins "Value":$value,
+               "::mlir::gpu::AllReduceOperation":$op,
+               "bool":$uniform), [{
+      build($_builder, $_state, value, op, uniform, /*cluster_size=*/ nullptr);
+    }]>,
+    OpBuilder<(ins "Value":$value,
+               "::mlir::gpu::AllReduceOperation":$op,
+               "bool":$uniform,
+               "std::optional<uint32_t>":$cluster_size), [{
+      if (cluster_size)
+        build($_builder, $_state, value, op, uniform, $_builder.getI32IntegerAttr(*cluster_size));
+      else
+        build($_builder, $_state, value, op, uniform, nullptr);
+    }]>
+  ];
+
   let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
-                          (`uniform` $uniform^)? attr-dict
+                          (`uniform` $uniform^)?
+                          (`cluster_size` `(` $cluster_size^ `)`)?
+                          attr-dict
                           `:` functional-type(operands, results) }];
 
   let hasFolder = 1;
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9957a5804c0b65..9b1be198f77a82 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -102,6 +102,10 @@ struct GPUSubgroupReduceOpLowering
 
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op.getClusterSize())
+      return rewriter.notifyMatchFailure(
+          op, "lowering for clustered reduce not implemented");
+
     if (!op.getUniform())
       return rewriter.notifyMatchFailure(
           op, "cannot be lowered to redux as the op must be run "
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index b18b6344732eeb..a8ff9247e796ab 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -579,6 +579,10 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op.getClusterSize())
+      return rewriter.notifyMatchFailure(
+          op, "lowering for clustered reduce not implemented");
+
     if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
       return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index eeffe829446cf9..174d82615abe1c 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -620,10 +620,21 @@ LogicalResult gpu::SubgroupReduceOp::verify() {
                        << "` reduction operation is not compatible with type "
                        << getType();
   }
+
+  if (auto clusterSize = getClusterSize()) {
+    uint32_t size = *clusterSize;
+    if (!llvm::isPowerOf2_32(size)) {
+      return emitError() << "cluster size " << size << " is not a power of two";
+    }
+  }
+
   return success();
 }
 
 OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
+  if (getClusterSize() == 1)
+    return getValue();
+
   if (!getUniform() && canMakeGroupOpUniform(*this)) {
     setUniform(true);
     return getResult();
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 561a7e569eb2fa..45895d77561362 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -50,6 +50,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() < 2)
       return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -95,7 +97,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
       }
 
       Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-          loc, extracted, op.getOp(), op.getUniform());
+          loc, extracted, op.getOp(), op.getUniform(), clusterSize);
       if (numElems == 1) {
         res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
         continue;
@@ -127,6 +129,8 @@ struct ScalarizeSingleElementReduce final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() != 1)
       return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -136,7 +140,7 @@ struct ScalarizeSingleElementReduce final
     Location loc = op.getLoc();
     Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
     Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-        loc, extracted, op.getOp(), op.getUniform());
+        loc, extracted, op.getOp(), op.getUniform(), clusterSize);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
     return success();
   }
@@ -147,17 +151,20 @@ struct ScalarizeSingleElementReduce final
 /// type, respectively. For example, with `input` of type `f16`, `packFn` could
 /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
 /// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
-/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
+/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
+/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
 static Value createSubgroupShuffleReduction(
     OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
-    unsigned subgroupSize, function_ref<Value(Value)> packFn,
-    function_ref<Value(Value)> unpackFn) {
+    unsigned clusterSize, unsigned subgroupSize,
+    function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
+  assert(llvm::isPowerOf2_32(clusterSize));
   assert(llvm::isPowerOf2_32(subgroupSize));
+  assert(clusterSize <= subgroupSize);
   // Lane value always stays in the original type. We use it to perform arith
   // reductions.
   Value laneVal = input;
   // Parallel reduction using butterfly shuffles.
-  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+  for (unsigned i = 1; i < clusterSize; i <<= 1) {
     Value shuffled = builder
                          .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                  /*width=*/subgroupSize,
@@ -183,6 +190,13 @@ struct ScalarSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+    if (clusterSize && *clusterSize > subgroupSize)
+      return op.emitError()
+             << "cluster size " << *clusterSize
+             << " is greater than subgroup size " << subgroupSize;
+    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
     Type valueTy = op.getType();
     unsigned elemBitwidth =
         getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
@@ -196,7 +210,8 @@ struct ScalarSubgroupReduceToShuffles final
       auto identityFn = [](Value v) { return v; };
       rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                  rewriter, loc, op.getValue(), op.getOp(),
-                                 subgroupSize, identityFn, identityFn));
+                                 effectiveClusterSize, subgroupSize, identityFn,
+                                 identityFn));
       return success();
     }
 
@@ -215,9 +230,10 @@ struct ScalarSubgroupReduceToShuffles final
       return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
     };
 
-    rewriter.replaceOp(op, createSubgroupShuffleReduction(
-                               rewriter, loc, op.getValue(), op.getOp(),
-                               subgroupSize, packFn, unpackFn));
+    rewriter.replaceOp(
+        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
+                                           op.getOp(), effectiveClusterSize,
+                                           subgroupSize, packFn, unpackFn));
     return success();
   }
 
@@ -237,6 +253,13 @@ struct VectorSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    std::optional<uint32_t> clusterSize = op.getClusterSize();
+    if (clusterSize && *clusterSize > subgroupSize)
+      return op.emitError()
+             << "cluster size " << *clusterSize
+             << " is greater than subgroup size " << subgroupSize;
+    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy)
       return rewriter.notifyMatchFailure(op, "value type is not a vector");
@@ -285,9 +308,9 @@ struct VectorSubgroupReduceToShuffles final
       return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
     };
 
-    Value res =
-        createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
-                                       subgroupSize, packFn, unpackFn);
+    Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
+                                               op.getOp(), effectiveClusterSize,
+                                               subgroupSize, packFn, unpackFn);
 
     if (vecBitwidth < shuffleBitwidth) {
       res = rewriter.create<vector::ExtractStridedSliceOp>(
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
index 372dd78790276c..469c03c9460df1 100644
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -246,6 +246,24 @@ func.func @make_subgroup_reduce_uniform() {
 
 // -----
 
+// CHECK-LABEL: func @subgroup_reduce_cluster_size_1
+//       CHECK: gpu.launch blocks
+//       CHECK: %[[V1:.*]] = "test.test2"() : () -> i32
+//       CHECK: "test.test3"(%[[V1]]) : (i32) -> ()
+func.func @subgroup_reduce_cluster_size_1() {
+  %0:6 = "test.test1"() : () -> (index, index, index, index, index, index)
+  gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2)
+    threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) {
+    %1 = "test.test2"() : () -> i32
+    %2 = gpu.subgroup_reduce add %1 cluster_size(1) : (i32) -> (i32)
+    "test.test3"(%2) : (i32) -> ()
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
 // The GPU kernel does not have any side effecting ops, so the entire
 // gpu.launch op can fold away.
 
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index e9d8f329be8ede..ce09190e1b7280 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -333,6 +333,22 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
 
 // -----
 
+func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{cluster size 0 is not a power of two}}
+  %res = gpu.subgroup_reduce add %arg0 cluster_size(0) : (vector<4xf32>) -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @subgroup_reduce_npot_cluster_size(%arg0 : vector<4xf32>) {
+  // expected-error@+1 {{cluster size 3 is not a power of two}}
+  %res = gpu.subgroup_reduce add %arg0 cluster_size(3) : (vector<4xf32>) -> vector<4xf32>
+  return
+}
+
+// -----
+
 func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
   // expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
   %res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
diff --git a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
index f04a01ffe75d3c..37608ce4cfed76 100644
--- a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
@@ -34,6 +34,16 @@ gpu.module @kernels {
     %sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum1) : (vector<5xf16>) -> ()
 
+    // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} cluster_size(4)
+    // CHECK-SUB: "test.consume"
+    %sum2 = gpu.subgroup_reduce mul %arg0 cluster_size(4) : (vector<5xf16>) -> (vector<5xf16>)
+    "test.consume"(%sum2) : (vector<5xf16>) -> ()
+
+    // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform cluster_size(4)
+    // CHECK-SUB: "test.consume"
+    %sum3 = gpu.subgroup_reduce mul %arg0 uniform cluster_size(4) : (vector<5xf16>) -> (vector<5xf16>)
+    "test.consume"(%sum3) : (vector<5xf16>) -> ()
+
     // CHECK-SUB: gpu.return
     gpu.return
   }
@@ -55,6 +65,16 @@ gpu.module @kernels {
     %sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum1) : (vector<1xf32>) -> ()
 
+    // CHECK-SUB: gpu.subgroup_reduce add {{.+}} cluster_size(8) : (f32) -> f32
+    // CHECK-SUB: "test.consume"
+    %sum2 = gpu.subgroup_reduce add %arg0 cluster_size(8) : (vector<1xf32>) -> (vector<1xf32>)
+    "test.consume"(%sum2) : (vector<1xf32>) -> ()
+
+    // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform cluster_size(8) : (f32) -> f32
+    // CHECK-SUB: "test.consume"
+    %sum3 = gpu.subgroup_reduce add %arg0 uniform cluster_size(8) : (vector<1xf32>) -> (vector<1xf32>)
+    "test.consume"(%sum3) : (vector<1xf32>) -> ()
+
     // CHECK-SUB: gpu.return
     gpu.return
   }
@@ -108,6 +128,28 @@ gpu.module @kernels {
     gpu.return
   }
 
+  // CHECK-SHFL-LABEL: gpu.func @kernel3_clustered(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
+  gpu.func @kernel3_clustered(%arg0: i32) kernel {
+    // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+    // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+    // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+    // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[ARG0]], %[[C1]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[S0]] : i32
+    // CHECK-SHFL: %[[S1:.+]], %{{.+}} = gpu.shuffle xor %[[A0]], %[[C2]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A1:.+]] = arith.addi %[[A0]], %[[S1]] : i32
+    // CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C4]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
+    // CHECK-SHFL: "test.consume"(%[[A2]]) : (i32) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(8) : (i32) -> i32
+    "test.consume"(%sum0) : (i32) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
   // CHECK-SHFL-LABEL: gpu.func @kernel4(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<2xf16>)
   gpu.func @kernel4(%arg0: vector<2xf16>) kernel {
@@ -144,6 +186,21 @@ gpu.module @kernels {
     gpu.return
   }
 
+  // CHECK-SHFL-LABEL: gpu.func @kernel4_clustered(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<2xf16>)
+  gpu.func @kernel4_clustered(%arg0: vector<2xf16>) kernel {
+    // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+    // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+    // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+    // CHECK-SHFL-COUNT-2: gpu.shuffle xor
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(4) : (vector<2xf16>) -> (vector<2xf16>)
+    "test.consume"(%sum0) : (vector<2xf16>) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
   // CHECK-SHFL-LABEL: gpu.func @kernel5(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i16)
   gpu.func @kernel5(%arg0: i16) kernel {
@@ -164,6 +221,26 @@ gpu.module @kernels {
     gpu.return
   }
 
+  // CHECK-SHFL-LABEL: gpu.func @kernel5_clustered(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i16)
+  gpu.func @kernel5_clustered(%arg0: i16) kernel {
+    // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
+    // CHECK-SHFL: %[[T0:.+]] = arith.trunci %[[S0]] : i32 to i16
+    // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[T0]] : i16
+    // CHECK-SHFL: %[[E1:.+]] = arith.extui %[[A0]] : i16 to i32
+    // CHECK-SHFL: %{{.+}}, %{{.+}} = gpu.shuffle xor %[[E1]], {{.+}} : i32
+    // CHECK-SHFL-COUNT-2: gpu.shuffle xor
+    // CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
+    // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
+    // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(16) : (i16) -> i16
+    "test.consume"(%sum0) : (i16) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
   // CHECK-SHFL-LABEL: gpu.func @kernel6(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<3xi8>)
   gpu.func @kernel6(%arg0: vector<3xi8>) kernel {
@@ -187,5 +264,15 @@ gpu.module @kernels {
     gpu.return
   }
 
+  // CHECK-SHFL-LABEL: gpu.func @kernel_cluster_size_is_subgroup_size(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<3xi8>)
+  gpu.func @kernel_cluster_size_is_subgroup_size(%arg0: vector<3xi8>) kernel {
+    // CHECK-SHFL-COUNT-5: gpu.shuffle xor
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(32) : (vector<3xi8>) -> (vector<3xi8>)
+    "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
 }
 



More information about the Mlir-commits mailing list