[Mlir-commits] [mlir] 3d01f0a - [mlir][gpu] Add 'cluster_stride' attribute to gpu.subgroup_reduce (#107142)

llvmlistbot at llvm.org
Thu Sep 5 06:03:26 PDT 2024


Author: Andrea Faulds
Date: 2024-09-05T09:03:22-04:00
New Revision: 3d01f0a33b9a14545217938fbd2475226ade2719

URL: https://github.com/llvm/llvm-project/commit/3d01f0a33b9a14545217938fbd2475226ade2719
DIFF: https://github.com/llvm/llvm-project/commit/3d01f0a33b9a14545217938fbd2475226ade2719.diff

LOG: [mlir][gpu] Add 'cluster_stride' attribute to gpu.subgroup_reduce (#107142)

Follow-up to 7aa22f013e24d20291aad745368ff907baa9dfa4, adding an
additional `cluster_stride` attribute that some applications need.
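
For a quick look at the new syntax (a sketch mirroring the examples
added to GPUOps.td below; `%a` is any scalar or 1-d vector value):

```mlir
// Reduce across the whole subgroup (unchanged behavior).
%0 = gpu.subgroup_reduce add %a : (f32) -> f32
// Reduce within clusters of 4 contiguous lanes (stride defaults to 1).
%1 = gpu.subgroup_reduce add %a cluster(size = 4) : (f32) -> f32
// Reduce within clusters of 4 lanes spaced 2 apart.
%2 = gpu.subgroup_reduce add %a cluster(size = 4, stride = 2) : (f32) -> f32
```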

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
    mlir/test/Dialect/GPU/canonicalize.mlir
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index d2a5e5d77ad843..6098eb34d04d52 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1200,10 +1200,12 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
     The `subgroup_reduce` op reduces the values of lanes (work items) across a
     subgroup.
 
-    The subgroup is divided into clusters of `cluster_size` contiguous lanes
-    each, and a reduction is done for every lane of each cluster (in parallel).
-    The result is equal for all lanes in a cluster. When `cluster_size` is
-    omitted, there is a single cluster covering the entire subgroup.
+    The subgroup is divided into clusters starting at lane index 0. Within each
+    cluster, there are `size` lanes, and the lane index advances by `stride`.
+    A reduction is done for each cluster in parallel: every lane in the cluster
+    is reduced, and the result is equal for all lanes in the cluster. If `size`
+    is omitted, there is a single cluster covering the entire subgroup. If
+    `stride` is omitted, the stride is 1 (the cluster's lanes are contiguous).
 
     When the reduced value is of a vector type, each vector element is reduced
     independently. Only 1-d vector types are allowed.
@@ -1213,7 +1215,8 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
     ```mlir
     %1 = gpu.subgroup_reduce add %a : (f32) -> f32
     %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> vector<4xf16>
-    %3 = gpu.subgroup_reduce add %c cluster_size(4) : (f32) -> f32
+    %3 = gpu.subgroup_reduce add %c cluster(size = 4) : (f32) -> f32
+    %4 = gpu.subgroup_reduce add %c cluster(size = 4, stride = 2) : (f32) -> f32
     ```
 
     If `uniform` flag is set either none or all lanes of a subgroup need to execute
@@ -1230,7 +1233,8 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
     AnyIntegerOrFloatOr1DVector:$value,
     GPU_AllReduceOperationAttr:$op,
     UnitAttr:$uniform,
-    OptionalAttr<I32Attr>:$cluster_size
+    OptionalAttr<I32Attr>:$cluster_size,
+    DefaultValuedAttr<I32Attr,"1">:$cluster_stride
   );
   let results = (outs AnyIntegerOrFloatOr1DVector:$result);
 
@@ -1238,19 +1242,29 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
     OpBuilder<(ins "Value":$value,
                "::mlir::gpu::AllReduceOperation":$op,
                "bool":$uniform), [{
-      build($_builder, $_state, value, op, uniform, /*cluster_size=*/ nullptr);
+      build($_builder, $_state, value, op, uniform, std::nullopt);
     }]>,
     OpBuilder<(ins "Value":$value,
                "::mlir::gpu::AllReduceOperation":$op,
                "bool":$uniform,
                "std::optional<uint32_t>":$cluster_size), [{
-      build($_builder, $_state, value, op, uniform, cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr);
+      build($_builder, $_state, value, op, uniform,
+            cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr);
+    }]>,
+    OpBuilder<(ins "Value":$value,
+               "::mlir::gpu::AllReduceOperation":$op,
+               "bool":$uniform,
+               "std::optional<uint32_t>":$cluster_size,
+               "uint32_t":$cluster_stride), [{
+      build($_builder, $_state, value, op, uniform,
+            cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr,
+            cluster_stride);
     }]>
   ];
 
   let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
                           (`uniform` $uniform^)?
-                          (`cluster_size` `(` $cluster_size^ `)`)?
+                          (`cluster` `(` `size` `=` $cluster_size^ (`,` `stride` `=` $cluster_stride^)? `)`)?
                           attr-dict
                           `:` functional-type(operands, results) }];
 

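To make the clustering rule concrete: with `cluster(size = 4, stride = 2)`
on a 32-lane subgroup, lanes are grouped as sketched below (this follows
from the description added above; every lane of a cluster receives the
same result):

```mlir
%r = gpu.subgroup_reduce add %x cluster(size = 4, stride = 2) : (f32) -> f32
// Clusters start at lane 0 and lane indices advance by the stride:
//   {0, 2, 4, 6}, {1, 3, 5, 7}, {8, 10, 12, 14}, {9, 11, 13, 15}, ...
```
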
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index e45ba7838b453c..f822c11aeec008 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -621,7 +621,8 @@ LogicalResult gpu::SubgroupReduceOp::verify() {
                        << getType();
   }
 
-  if (auto clusterSize = getClusterSize()) {
+  auto clusterSize = getClusterSize();
+  if (clusterSize) {
     uint32_t size = *clusterSize;
     if (!llvm::isPowerOf2_32(size)) {
       return emitOpError() << "cluster size " << size
@@ -629,6 +630,16 @@ LogicalResult gpu::SubgroupReduceOp::verify() {
     }
   }
 
+  uint32_t stride = getClusterStride();
+  if (stride != 1 && !clusterSize) {
+    return emitOpError() << "cluster stride can only be specified if cluster "
+                            "size is specified";
+  }
+  if (!llvm::isPowerOf2_32(stride)) {
+    return emitOpError() << "cluster stride " << stride
+                         << " is not a power of two";
+  }
+
   return success();
 }
 

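The verifier thus enforces that the stride is a power of two and is only
given together with a cluster size. A sketch of accepted and rejected
forms (the rejected ones match the invalid.mlir tests below):

```mlir
// Accepted: both size and stride are powers of two.
%ok = gpu.subgroup_reduce add %x cluster(size = 4, stride = 2) : (f32) -> f32
// Rejected: "cluster stride 0 is not a power of two".
// %e0 = gpu.subgroup_reduce add %x cluster(size = 4, stride = 0) : (f32) -> f32
// Rejected: "cluster stride can only be specified if cluster size is specified".
// %e1 = gpu.subgroup_reduce add %x { cluster_stride = 2 : i32 } : (f32) -> f32
```
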
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 288f7ab9f30222..2cf5d12588a73c 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -50,8 +50,6 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<uint32_t> clusterSize = op.getClusterSize();
-
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() < 2)
       return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -97,7 +95,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
       }
 
       Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-          loc, extracted, op.getOp(), op.getUniform(), clusterSize);
+          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
+          op.getClusterStride());
       if (numElems == 1) {
         res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
         continue;
@@ -129,8 +128,6 @@ struct ScalarizeSingleElementReduce final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<uint32_t> clusterSize = op.getClusterSize();
-
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() != 1)
       return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -140,34 +137,64 @@ struct ScalarizeSingleElementReduce final
     Location loc = op.getLoc();
     Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
     Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
-        loc, extracted, op.getOp(), op.getUniform(), clusterSize);
+        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
+        op.getClusterStride());
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
     return success();
   }
 };
 
+struct ClusterInfo {
+  unsigned clusterStride;
+  unsigned clusterSize;
+  unsigned subgroupSize;
+};
+
+static FailureOr<ClusterInfo>
+getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
+  assert(llvm::isPowerOf2_32(subgroupSize));
+
+  std::optional<uint32_t> clusterSize = op.getClusterSize();
+  assert(!clusterSize ||
+         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
+  if (clusterSize && *clusterSize > subgroupSize)
+    return op.emitOpError()
+           << "cluster size " << *clusterSize
+           << " is greater than subgroup size " << subgroupSize;
+  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
+  auto clusterStride = op.getClusterStride();
+  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
+  if (clusterStride >= subgroupSize)
+    return op.emitOpError()
+           << "cluster stride " << clusterStride
+           << " is not less than subgroup size " << subgroupSize;
+
+  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
+}
+
 /// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
 /// and `unpackFn` to convert to the native shuffle type and to the reduction
 /// type, respectively. For example, with `input` of type `f16`, `packFn` could
 /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
 /// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
 /// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
-/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
-static Value createSubgroupShuffleReduction(
-    OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
-    unsigned clusterSize, unsigned subgroupSize,
-    function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
-  assert(llvm::isPowerOf2_32(clusterSize));
-  assert(llvm::isPowerOf2_32(subgroupSize));
-  assert(clusterSize <= subgroupSize);
+/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
+/// lanes within a cluster, reducing all lanes in each cluster in parallel.
+Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
+                                     Value input, gpu::AllReduceOperation mode,
+                                     const ClusterInfo &ci,
+                                     function_ref<Value(Value)> packFn,
+                                     function_ref<Value(Value)> unpackFn) {
   // Lane value always stays in the original type. We use it to perform arith
   // reductions.
   Value laneVal = input;
   // Parallel reduction using butterfly shuffles.
-  for (unsigned i = 1; i < clusterSize; i <<= 1) {
+  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
+       i <<= 1) {
     Value shuffled = builder
                          .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
-                                                 /*width=*/subgroupSize,
+                                                 /*width=*/ci.subgroupSize,
                                                  /*mode=*/gpu::ShuffleMode::XOR)
                          .getShuffleResult();
     laneVal = vector::makeArithReduction(builder, loc,
@@ -190,12 +217,9 @@ struct ScalarSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<uint32_t> clusterSize = op.getClusterSize();
-    if (clusterSize && *clusterSize > subgroupSize)
-      return op.emitOpError()
-             << "cluster size " << *clusterSize
-             << " is greater than subgroup size " << subgroupSize;
-    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
 
     Type valueTy = op.getType();
     unsigned elemBitwidth =
@@ -209,9 +233,8 @@ struct ScalarSubgroupReduceToShuffles final
     if (elemBitwidth == shuffleBitwidth) {
       auto identityFn = [](Value v) { return v; };
       rewriter.replaceOp(op, createSubgroupShuffleReduction(
-                                 rewriter, loc, op.getValue(), op.getOp(),
-                                 effectiveClusterSize, subgroupSize, identityFn,
-                                 identityFn));
+                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
+                                 identityFn, identityFn));
       return success();
     }
 
@@ -232,8 +255,7 @@ struct ScalarSubgroupReduceToShuffles final
 
     rewriter.replaceOp(
         op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
-                                           op.getOp(), effectiveClusterSize,
-                                           subgroupSize, packFn, unpackFn));
+                                           op.getOp(), *ci, packFn, unpackFn));
     return success();
   }
 
@@ -253,12 +275,9 @@ struct VectorSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<uint32_t> clusterSize = op.getClusterSize();
-    if (clusterSize && *clusterSize > subgroupSize)
-      return op.emitOpError()
-             << "cluster size " << *clusterSize
-             << " is greater than subgroup size " << subgroupSize;
-    unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
 
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy)
@@ -308,9 +327,8 @@ struct VectorSubgroupReduceToShuffles final
       return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
     };
 
-    Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
-                                               op.getOp(), effectiveClusterSize,
-                                               subgroupSize, packFn, unpackFn);
+    Value res = createSubgroupShuffleReduction(
+        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);
 
     if (vecBitwidth < shuffleBitwidth) {
       res = rewriter.create<vector::ExtractStridedSliceOp>(

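The only change to the butterfly reduction itself is the loop bounds: the
XOR shuffle offsets now run over `stride * {1, 2, ..., size/2}` instead of
`{1, 2, ..., size/2}`. For example, `cluster(size = 4, stride = 2)` on a
32-lane subgroup uses offsets 2 and 4, expanding roughly as follows (a
sketch in the style of the lowering tests below; `%x : i32`):

```mlir
%c2  = arith.constant 2 : i32
%c4  = arith.constant 4 : i32
%c32 = arith.constant 32 : i32
%s0, %p0 = gpu.shuffle xor %x, %c2, %c32 : i32
%a0 = arith.addi %x, %s0 : i32
%s1, %p1 = gpu.shuffle xor %a0, %c4, %c32 : i32
// %r holds the sum over the lane's cluster, e.g. lanes {0, 2, 4, 6}.
%r  = arith.addi %a0, %s1 : i32
```
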
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
index 469c03c9460df1..d342ae9df10eea 100644
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -255,7 +255,7 @@ func.func @subgroup_reduce_cluster_size_1() {
   gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2)
     threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) {
     %1 = "test.test2"() : () -> i32
-    %2 = gpu.subgroup_reduce add %1 cluster_size(1) : (i32) -> (i32)
+    %2 = gpu.subgroup_reduce add %1 cluster(size=1) : (i32) -> (i32)
     "test.test3"(%2) : (i32) -> ()
     gpu.terminator
   }

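The test above relies on a cluster of size 1 containing only the lane's
own value, so the reduction is the identity and presumably folds away:

```mlir
%2 = gpu.subgroup_reduce add %1 cluster(size = 1) : (i32) -> i32
// A one-lane cluster reduces nothing; %2 canonicalizes to %1.
```
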
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index fd7618020b5d89..2a0f7e8c6b10c2 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -335,7 +335,7 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
 
 func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
  // expected-error @+1 {{cluster size 0 is not a power of two}}
-  %res = gpu.subgroup_reduce add %arg0 cluster_size(0) : (vector<4xf32>) -> vector<4xf32>
+  %res = gpu.subgroup_reduce add %arg0 cluster(size = 0) : (vector<4xf32>) -> vector<4xf32>
   return
 }
 
@@ -343,10 +343,27 @@ func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
 
 func.func @subgroup_reduce_npot_cluster_size(%arg0 : vector<4xf32>) {
  // expected-error @+1 {{cluster size 3 is not a power of two}}
-  %res = gpu.subgroup_reduce add %arg0 cluster_size(3) : (vector<4xf32>) -> vector<4xf32>
+  %res = gpu.subgroup_reduce add %arg0 cluster(size = 3) : (vector<4xf32>) -> vector<4xf32>
   return
 }
 
+// -----
+
+func.func @subgroup_reduce_zero_cluster_stride(%arg0 : vector<4xf32>) {
+  // expected-error @+1 {{cluster stride 0 is not a power of two}}
+  %res = gpu.subgroup_reduce add %arg0 cluster(size = 4, stride = 0) : (vector<4xf32>) -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @subgroup_reduce_cluster_stride_without_size(%arg0 : vector<4xf32>) {
+  // expected-error @+1 {{cluster stride can only be specified if cluster size is specified}}
+  %res = gpu.subgroup_reduce add %arg0 { cluster_stride = 2 : i32 } : (vector<4xf32>) -> vector<4xf32>
+  return
+}
+
+
 // -----
 
 func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {

diff --git a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
index 37608ce4cfed76..9f2aa1be52fc37 100644
--- a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
@@ -34,14 +34,14 @@ gpu.module @kernels {
     %sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum1) : (vector<5xf16>) -> ()
 
-    // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} cluster_size(4)
+    // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} cluster(size = 4)
     // CHECK-SUB: "test.consume"
-    %sum2 = gpu.subgroup_reduce mul %arg0 cluster_size(4) : (vector<5xf16>) -> (vector<5xf16>)
+    %sum2 = gpu.subgroup_reduce mul %arg0 cluster(size = 4) : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum2) : (vector<5xf16>) -> ()
 
-    // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform cluster_size(4)
+    // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform cluster(size = 4, stride = 2)
     // CHECK-SUB: "test.consume"
-    %sum3 = gpu.subgroup_reduce mul %arg0 uniform cluster_size(4) : (vector<5xf16>) -> (vector<5xf16>)
+    %sum3 = gpu.subgroup_reduce mul %arg0 uniform cluster(size = 4, stride = 2) : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum3) : (vector<5xf16>) -> ()
 
     // CHECK-SUB: gpu.return
@@ -65,14 +65,15 @@ gpu.module @kernels {
     %sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum1) : (vector<1xf32>) -> ()
 
-    // CHECK-SUB: gpu.subgroup_reduce add {{.+}} cluster_size(8) : (f32) -> f32
+    // Note stride is dropped because it is == 1.
+    // CHECK-SUB: gpu.subgroup_reduce add {{.+}} cluster(size = 8) : (f32) -> f32
     // CHECK-SUB: "test.consume"
-    %sum2 = gpu.subgroup_reduce add %arg0 cluster_size(8) : (vector<1xf32>) -> (vector<1xf32>)
+    %sum2 = gpu.subgroup_reduce add %arg0 cluster(size = 8, stride = 1) : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum2) : (vector<1xf32>) -> ()
 
-    // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform cluster_size(8) : (f32) -> f32
+    // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform cluster(size = 8, stride = 4) : (f32) -> f32
     // CHECK-SUB: "test.consume"
-    %sum3 = gpu.subgroup_reduce add %arg0 uniform cluster_size(8) : (vector<1xf32>) -> (vector<1xf32>)
+    %sum3 = gpu.subgroup_reduce add %arg0 uniform cluster(size = 8, stride = 4) : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum3) : (vector<1xf32>) -> ()
 
     // CHECK-SUB: gpu.return
@@ -143,7 +144,29 @@ gpu.module @kernels {
     // CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C4]], %[[C32]] : i32
     // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
     // CHECK-SHFL: "test.consume"(%[[A2]]) : (i32) -> ()
-    %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(8) : (i32) -> i32
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 8) : (i32) -> i32
+    "test.consume"(%sum0) : (i32) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel3_clustered_strided(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
+  gpu.func @kernel3_clustered_strided(%arg0: i32) kernel {
+    // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 4 : i32
+    // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 8 : i32
+    // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 16 : i32
+    // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[ARG0]], %[[C1]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[S0]] : i32
+    // CHECK-SHFL: %[[S1:.+]], %{{.+}} = gpu.shuffle xor %[[A0]], %[[C2]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A1:.+]] = arith.addi %[[A0]], %[[S1]] : i32
+    // CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C4]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
+    // CHECK-SHFL: "test.consume"(%[[A2]]) : (i32) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 8, stride = 4) : (i32) -> i32
     "test.consume"(%sum0) : (i32) -> ()
 
     // CHECK-SHFL: gpu.return
@@ -194,7 +217,7 @@ gpu.module @kernels {
     // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
 
     // CHECK-SHFL-COUNT-2: gpu.shuffle xor
-    %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(4) : (vector<2xf16>) -> (vector<2xf16>)
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 4) : (vector<2xf16>) -> (vector<2xf16>)
     "test.consume"(%sum0) : (vector<2xf16>) -> ()
 
     // CHECK-SHFL: gpu.return
@@ -234,7 +257,7 @@ gpu.module @kernels {
     // CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
     // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
     // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
-    %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(16) : (i16) -> i16
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 16) : (i16) -> i16
     "test.consume"(%sum0) : (i16) -> ()
 
     // CHECK-SHFL: gpu.return
@@ -268,7 +291,7 @@ gpu.module @kernels {
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<3xi8>)
   gpu.func @kernel_cluster_size_is_subgroup_size(%arg0: vector<3xi8>) kernel {
     // CHECK-SHFL-COUNT-5: gpu.shuffle xor
-    %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(32) : (vector<3xi8>) -> (vector<3xi8>)
+    %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 32) : (vector<3xi8>) -> (vector<3xi8>)
     "test.consume"(%sum0) : (vector<3xi8>) -> ()
 
     // CHECK-SHFL: gpu.return

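One more consequence worth noting: `cluster(size = N)` where `N` equals
the subgroup size (with the default stride of 1) covers the whole
subgroup, so it lowers to the same shuffle sequence as the unclustered
op; that is what the five XOR shuffles in the last test above check for
a 32-lane subgroup:

```mlir
// These two forms lower identically on a 32-lane subgroup:
%a = gpu.subgroup_reduce add %v : (vector<3xi8>) -> vector<3xi8>
%b = gpu.subgroup_reduce add %v cluster(size = 32) : (vector<3xi8>) -> vector<3xi8>
```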
