[Mlir-commits] [mlir] [mlir][gpu] Add 'cluster_stride' attribute to gpu.subgroup_reduce (PR #107142)
Andrea Faulds
llvmlistbot at llvm.org
Tue Sep 3 11:56:50 PDT 2024
https://github.com/andfau-amd created https://github.com/llvm/llvm-project/pull/107142
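Follow-up to 7aa22f013e24d20291aad745368ff907baa9dfa4, adding an optional `cluster_stride` attribute (needed by some applications) so that the lanes of each cluster can be spaced `stride` apart instead of being contiguous.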
>From 897a2b822bcdff04e9292ed96b1309b4fc6ef892 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Mon, 26 Aug 2024 17:47:50 +0200
Subject: [PATCH] [mlir][gpu] Add 'cluster_stride' attribute to
gpu.subgroup_reduce
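Follow-up to 7aa22f013e24d20291aad745368ff907baa9dfa4, adding an
optional `cluster_stride` attribute (needed by some applications) so that
the lanes of each cluster can be spaced `stride` apart instead of being
contiguous.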
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 32 +++--
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +-
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 2 +-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 6 +
.../GPU/Transforms/SubgroupReduceLowering.cpp | 133 ++++++++++--------
mlir/test/Dialect/GPU/canonicalize.mlir | 2 +-
mlir/test/Dialect/GPU/invalid.mlir | 20 ++-
.../Dialect/GPU/subgroup-reduce-lowering.mlir | 47 +++++--
8 files changed, 161 insertions(+), 83 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index d2a5e5d77ad843..6098eb34d04d52 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1200,10 +1200,12 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
The `subgroup_reduce` op reduces the values of lanes (work items) across a
subgroup.
- The subgroup is divided into clusters of `cluster_size` contiguous lanes
- each, and a reduction is done for every lane of each cluster (in parallel).
- The result is equal for all lanes in a cluster. When `cluster_size` is
- omitted, there is a single cluster covering the entire subgroup.
+ The subgroup is divided into clusters starting at lane index 0. Within each
+ cluster, there are `size` lanes, and the lane index advances by `stride`.
+ A reduction is done for each cluster in parallel: every lane in the cluster
+ is reduced, and the result is equal for all lanes in the cluster. If `size`
+ is omitted, there is a single cluster covering the entire subgroup. If
+ `stride` is omitted, the stride is 1 (the cluster's lanes are contiguous).
When the reduced value is of a vector type, each vector element is reduced
independently. Only 1-d vector types are allowed.
@@ -1213,7 +1215,8 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
```mlir
%1 = gpu.subgroup_reduce add %a : (f32) -> f32
%2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> vector<4xf16>
- %3 = gpu.subgroup_reduce add %c cluster_size(4) : (f32) -> f32
+ %3 = gpu.subgroup_reduce add %c cluster(size = 4) : (f32) -> f32
+ %4 = gpu.subgroup_reduce add %c cluster(size = 4, stride = 2) : (f32) -> f32
```
If `uniform` flag is set either none or all lanes of a subgroup need to execute
@@ -1230,7 +1233,8 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
AnyIntegerOrFloatOr1DVector:$value,
GPU_AllReduceOperationAttr:$op,
UnitAttr:$uniform,
- OptionalAttr<I32Attr>:$cluster_size
+ OptionalAttr<I32Attr>:$cluster_size,
+ DefaultValuedAttr<I32Attr,"1">:$cluster_stride
);
let results = (outs AnyIntegerOrFloatOr1DVector:$result);
@@ -1238,19 +1242,29 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
OpBuilder<(ins "Value":$value,
"::mlir::gpu::AllReduceOperation":$op,
"bool":$uniform), [{
- build($_builder, $_state, value, op, uniform, /*cluster_size=*/ nullptr);
+ build($_builder, $_state, value, op, uniform, std::nullopt);
}]>,
OpBuilder<(ins "Value":$value,
"::mlir::gpu::AllReduceOperation":$op,
"bool":$uniform,
"std::optional<uint32_t>":$cluster_size), [{
- build($_builder, $_state, value, op, uniform, cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr);
+ build($_builder, $_state, value, op, uniform,
+ cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr);
+ }]>,
+ OpBuilder<(ins "Value":$value,
+ "::mlir::gpu::AllReduceOperation":$op,
+ "bool":$uniform,
+ "std::optional<uint32_t>":$cluster_size,
+ "uint32_t":$cluster_stride), [{
+ build($_builder, $_state, value, op, uniform,
+ cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr,
+ cluster_stride);
}]>
];
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
(`uniform` $uniform^)?
- (`cluster_size` `(` $cluster_size^ `)`)?
+ (`cluster` `(` `size` `=` $cluster_size^ (`,` `stride` `=` $cluster_stride^)? `)`)?
attr-dict
`:` functional-type(operands, results) }];
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9b1be198f77a82..b013826c02b97b 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -102,7 +102,7 @@ struct GPUSubgroupReduceOpLowering
matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (op.getClusterSize())
+ // Any clustering (an explicit size or a non-default stride) is unsupported.
+ if (op.getClusterSize() || op.getClusterStride() != 1)
return rewriter.notifyMatchFailure(
op, "lowering for clustered reduce not implemented");
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index a8ff9247e796ab..f775e29d2738a9 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -579,7 +579,7 @@ class GPUSubgroupReduceConversion final
LogicalResult
matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (op.getClusterSize())
+ // Any clustering (an explicit size or a non-default stride) is unsupported.
+ if (op.getClusterSize() || op.getClusterStride() != 1)
return rewriter.notifyMatchFailure(
op, "lowering for clustered reduce not implemented");
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a59952228ef6ea..165e92482b1ec6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -629,6 +629,19 @@ LogicalResult gpu::SubgroupReduceOp::verify() {
}
}
+ uint32_t stride = getClusterStride();
+ if (!llvm::isPowerOf2_32(stride)) {
+ return emitOpError() << "cluster stride " << stride
+ << " is not a power of two";
+ }
+
+ // Without a cluster size there is a single cluster covering the whole
+ // subgroup, so a stride other than the default of 1 is contradictory.
+ if (stride != 1 && !getClusterSize()) {
+ return emitOpError() << "cluster stride can only be specified if cluster "
+ "size is specified";
+ }
+
return success();
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 288f7ab9f30222..2117c30c5ddd94 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -50,8 +50,6 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
- std::optional<uint32_t> clusterSize = op.getClusterSize();
-
auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy || vecTy.getNumElements() < 2)
return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -97,7 +95,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
}
+ // Each smaller reduction keeps the original op's cluster size and stride.
Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
- loc, extracted, op.getOp(), op.getUniform(), clusterSize);
+ loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
+ op.getClusterStride());
if (numElems == 1) {
res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
continue;
@@ -129,8 +128,6 @@ struct ScalarizeSingleElementReduce final
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
- std::optional<uint32_t> clusterSize = op.getClusterSize();
-
auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy || vecTy.getNumElements() != 1)
return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -140,44 +137,93 @@ struct ScalarizeSingleElementReduce final
Location loc = op.getLoc();
Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
- loc, extracted, op.getOp(), op.getUniform(), clusterSize);
+ loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
+ op.getClusterStride());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
return success();
}
};
-/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
-/// and `unpackFn` to convert to the native shuffle type and to the reduction
-/// type, respectively. For example, with `input` of type `f16`, `packFn` could
-/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
-/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
-/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
-/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
-static Value createSubgroupShuffleReduction(
- OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
- unsigned clusterSize, unsigned subgroupSize,
- function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
- assert(llvm::isPowerOf2_32(clusterSize));
- assert(llvm::isPowerOf2_32(subgroupSize));
- assert(clusterSize <= subgroupSize);
- // Lane value always stays in the original type. We use it to perform arith
- // reductions.
- Value laneVal = input;
- // Parallel reduction using butterfly shuffles.
- for (unsigned i = 1; i < clusterSize; i <<= 1) {
- Value shuffled = builder
- .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
- /*width=*/subgroupSize,
- /*mode=*/gpu::ShuffleMode::XOR)
- .getShuffleResult();
- laneVal = vector::makeArithReduction(builder, loc,
- gpu::convertReductionKind(mode),
- laneVal, unpackFn(shuffled));
- assert(laneVal.getType() == input.getType());
+/// Cluster layout of a `gpu.subgroup_reduce` op, validated against a concrete
+/// subgroup size, together with the shuffle-based lowering that consumes it.
+struct ClusterInfo {
+ unsigned clusterStride;
+ unsigned clusterSize;
+ unsigned subgroupSize;
+ /// Fills in the fields from `op` for a subgroup of `subgroupSize` lanes. A
+ /// missing cluster size means one cluster spanning the whole subgroup. Emits
+ /// an op error and fails if the cluster size exceeds the subgroup size, if
+ /// the stride is not less than the subgroup size, or if a cluster spans more
+ /// lanes (`clusterSize * clusterStride`) than the subgroup has.
+ LogicalResult getAndValidate(gpu::SubgroupReduceOp op,
+ unsigned subgroupSize) {
+ this->subgroupSize = subgroupSize;
+
+ std::optional<uint32_t> clusterSize = op.getClusterSize();
+ if (clusterSize && *clusterSize > subgroupSize)
+ return op.emitOpError()
+ << "cluster size " << *clusterSize
+ << " is greater than subgroup size " << subgroupSize;
+ this->clusterSize = clusterSize.value_or(subgroupSize); // effective size
+
+ clusterStride = op.getClusterStride();
+ if (clusterStride >= subgroupSize)
+ return op.emitOpError()
+ << "cluster stride " << clusterStride
+ << " is not less than subgroup size " << subgroupSize;
+
+ // The shuffle loop XORs lane indices with offsets below
+ // `clusterStride * clusterSize`; that span must fit in the subgroup, or the
+ // shuffles would target lanes outside it.
+ uint64_t clusterSpan = uint64_t(this->clusterSize) * clusterStride;
+ if (clusterSpan > subgroupSize)
+ return op.emitOpError()
+ << "cluster of size " << this->clusterSize << " with stride "
+ << clusterStride << " spans " << clusterSpan
+ << " lanes, more than subgroup size " << subgroupSize;
+
+ return success();
}
+ /// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
+ /// and `unpackFn` to convert to the native shuffle type and to the reduction
+ /// type, respectively. For example, with `input` of type `f16`, `packFn`
+ /// could build ops to cast the value to `i32` to perform shuffles, while
+ /// `unpackFn` would cast it back to `f16` to perform arithmetic reduction on.
+ /// Assumes that the subgroup is `subgroupSize` lanes wide and divides it into
+ /// clusters of `clusterSize` lanes starting at lane 0 with a stride of
+ /// `clusterStride` for lanes within a cluster, reducing all lanes in each
+ /// cluster in parallel.
+ Value
+ createSubgroupShuffleReduction(OpBuilder &builder, Location loc, Value input,
+ gpu::AllReduceOperation mode,
+ function_ref<Value(Value)> packFn,
+ function_ref<Value(Value)> unpackFn) const {
+ assert(llvm::isPowerOf2_32(clusterStride));
+ assert(llvm::isPowerOf2_32(clusterSize));
+ assert(llvm::isPowerOf2_32(subgroupSize));
+ assert(clusterStride < subgroupSize);
+ assert(clusterSize <= subgroupSize);
+ assert(uint64_t(clusterStride) * clusterSize <= subgroupSize);
+ // Lane value always stays in the original type. We use it to perform arith
+ // reductions.
+ Value laneVal = input;
+ // Parallel reduction using butterfly shuffles.
+ for (unsigned i = clusterStride; i < clusterStride * clusterSize; i <<= 1) {
+ Value shuffled =
+ builder
+ .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
+ /*width=*/subgroupSize,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
+ laneVal = vector::makeArithReduction(builder, loc,
+ gpu::convertReductionKind(mode),
+ laneVal, unpackFn(shuffled));
+ assert(laneVal.getType() == input.getType());
+ }
- return laneVal;
-}
+ return laneVal;
+ }
+};
/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
@@ -190,12 +218,9 @@ struct ScalarSubgroupReduceToShuffles final
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
- std::optional<uint32_t> clusterSize = op.getClusterSize();
- if (clusterSize && *clusterSize > subgroupSize)
- return op.emitOpError()
- << "cluster size " << *clusterSize
- << " is greater than subgroup size " << subgroupSize;
- unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+ // Validation emits an op error when the cluster layout does not fit the
+ // subgroup; only the failure needs to be propagated here.
+ ClusterInfo clusterInfo;
+ if (clusterInfo.getAndValidate(op, subgroupSize).failed())
+ return failure();
Type valueTy = op.getType();
unsigned elemBitwidth =
@@ -208,10 +233,9 @@ struct ScalarSubgroupReduceToShuffles final
// Since this is already a native shuffle scalar, no packing is necessary.
if (elemBitwidth == shuffleBitwidth) {
auto identityFn = [](Value v) { return v; };
- rewriter.replaceOp(op, createSubgroupShuffleReduction(
+ rewriter.replaceOp(op, clusterInfo.createSubgroupShuffleReduction(
rewriter, loc, op.getValue(), op.getOp(),
- effectiveClusterSize, subgroupSize, identityFn,
- identityFn));
+ identityFn, identityFn));
return success();
}
@@ -231,9 +255,8 @@ struct ScalarSubgroupReduceToShuffles final
};
rewriter.replaceOp(
- op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
- op.getOp(), effectiveClusterSize,
- subgroupSize, packFn, unpackFn));
+ op, clusterInfo.createSubgroupShuffleReduction(
+ rewriter, loc, op.getValue(), op.getOp(), packFn, unpackFn));
return success();
}
@@ -253,12 +276,9 @@ struct VectorSubgroupReduceToShuffles final
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
- std::optional<uint32_t> clusterSize = op.getClusterSize();
- if (clusterSize && *clusterSize > subgroupSize)
- return op.emitOpError()
- << "cluster size " << *clusterSize
- << " is greater than subgroup size " << subgroupSize;
- unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+ // Validation emits an op error when the cluster layout does not fit the
+ // subgroup; only the failure needs to be propagated here.
+ ClusterInfo clusterInfo;
+ if (clusterInfo.getAndValidate(op, subgroupSize).failed())
+ return failure();
auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy)
@@ -308,9 +328,8 @@ struct VectorSubgroupReduceToShuffles final
return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
};
- Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
- op.getOp(), effectiveClusterSize,
- subgroupSize, packFn, unpackFn);
+ Value res = clusterInfo.createSubgroupShuffleReduction(
+ rewriter, loc, extendedInput, op.getOp(), packFn, unpackFn);
if (vecBitwidth < shuffleBitwidth) {
res = rewriter.create<vector::ExtractStridedSliceOp>(
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
index 469c03c9460df1..d342ae9df10eea 100644
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -255,7 +255,7 @@ func.func @subgroup_reduce_cluster_size_1() {
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2)
threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) {
%1 = "test.test2"() : () -> i32
- %2 = gpu.subgroup_reduce add %1 cluster_size(1) : (i32) -> (i32)
+ %2 = gpu.subgroup_reduce add %1 cluster(size = 1) : (i32) -> (i32)
"test.test3"(%2) : (i32) -> ()
gpu.terminator
}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 20c1c4cf8a2d0b..a6fa202c7ce3c2 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -335,7 +335,7 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {
func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
// expected-error at +1 {{cluster size 0 is not a power of two}}
- %res = gpu.subgroup_reduce add %arg0 cluster_size(0) : (vector<4xf32>) -> vector<4xf32>
+ %res = gpu.subgroup_reduce add %arg0 cluster(size = 0) : (vector<4xf32>) -> vector<4xf32>
return
}
@@ -343,7 +343,23 @@ func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
func.func @subgroup_reduce_npot_cluster_size(%arg0 : vector<4xf32>) {
// expected-error at +1 {{cluster size 3 is not a power of two}}
- %res = gpu.subgroup_reduce add %arg0 cluster_size(3) : (vector<4xf32>) -> vector<4xf32>
+ %res = gpu.subgroup_reduce add %arg0 cluster(size = 3) : (vector<4xf32>) -> vector<4xf32>
+ return
+}
+
+// -----
+
+func.func @subgroup_reduce_zero_cluster_stride(%arg0 : vector<4xf32>) {
+ // expected-error at +1 {{cluster stride 0 is not a power of two}}
+ %res = gpu.subgroup_reduce add %arg0 cluster(size = 4, stride = 0) : (vector<4xf32>) -> vector<4xf32>
+ return
+}
+
+// -----
+
+func.func @subgroup_reduce_npot_cluster_stride(%arg0 : vector<4xf32>) {
+ // The custom `cluster(...)` syntax only accepts a stride together with a
+ // size, so the stride is set through the attribute dictionary here.
+ // expected-error at +1 {{cluster stride 3 is not a power of two}}
+ %res = gpu.subgroup_reduce add %arg0 { cluster_stride = 3 : i32 } : (vector<4xf32>) -> vector<4xf32>
return
}
diff --git a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
index 37608ce4cfed76..9f2aa1be52fc37 100644
--- a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
@@ -34,14 +34,14 @@ gpu.module @kernels {
%sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum1) : (vector<5xf16>) -> ()
- // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} cluster_size(4)
+ // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} cluster(size = 4)
// CHECK-SUB: "test.consume"
- %sum2 = gpu.subgroup_reduce mul %arg0 cluster_size(4) : (vector<5xf16>) -> (vector<5xf16>)
+ %sum2 = gpu.subgroup_reduce mul %arg0 cluster(size = 4) : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum2) : (vector<5xf16>) -> ()
- // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform cluster_size(4)
+ // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform cluster(size = 4, stride = 2)
// CHECK-SUB: "test.consume"
- %sum3 = gpu.subgroup_reduce mul %arg0 uniform cluster_size(4) : (vector<5xf16>) -> (vector<5xf16>)
+ %sum3 = gpu.subgroup_reduce mul %arg0 uniform cluster(size = 4, stride = 2) : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum3) : (vector<5xf16>) -> ()
// CHECK-SUB: gpu.return
@@ -65,14 +65,15 @@ gpu.module @kernels {
%sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum1) : (vector<1xf32>) -> ()
- // CHECK-SUB: gpu.subgroup_reduce add {{.+}} cluster_size(8) : (f32) -> f32
+ // The printed form omits `stride` because it equals the default value of 1.
+ // CHECK-SUB: gpu.subgroup_reduce add {{.+}} cluster(size = 8) : (f32) -> f32
// CHECK-SUB: "test.consume"
- %sum2 = gpu.subgroup_reduce add %arg0 cluster_size(8) : (vector<1xf32>) -> (vector<1xf32>)
+ %sum2 = gpu.subgroup_reduce add %arg0 cluster(size = 8, stride = 1) : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum2) : (vector<1xf32>) -> ()
- // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform cluster_size(8) : (f32) -> f32
+ // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform cluster(size = 8, stride = 4) : (f32) -> f32
// CHECK-SUB: "test.consume"
- %sum3 = gpu.subgroup_reduce add %arg0 uniform cluster_size(8) : (vector<1xf32>) -> (vector<1xf32>)
+ %sum3 = gpu.subgroup_reduce add %arg0 uniform cluster(size = 8, stride = 4) : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum3) : (vector<1xf32>) -> ()
// CHECK-SUB: gpu.return
@@ -143,7 +144,30 @@ gpu.module @kernels {
// CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C4]], %[[C32]] : i32
// CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
// CHECK-SHFL: "test.consume"(%[[A2]]) : (i32) -> ()
- %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(8) : (i32) -> i32
+ %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 8) : (i32) -> i32
+ "test.consume"(%sum0) : (i32) -> ()
+
+ // CHECK-SHFL: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel3_clustered_strided(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: i32)
+ gpu.func @kernel3_clustered_strided(%arg0: i32) kernel {
+ // With stride 4 and size 8, the XOR shuffle offsets are 4, 8 and 16.
+ // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+ // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+ // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+ // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[ARG0]], %[[C4]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[S0]] : i32
+ // CHECK-SHFL: %[[S1:.+]], %{{.+}} = gpu.shuffle xor %[[A0]], %[[C8]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A1:.+]] = arith.addi %[[A0]], %[[S1]] : i32
+ // CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C16]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
+ // CHECK-SHFL: "test.consume"(%[[A2]]) : (i32) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 8, stride = 4) : (i32) -> i32
"test.consume"(%sum0) : (i32) -> ()
// CHECK-SHFL: gpu.return
@@ -194,7 +217,7 @@ gpu.module @kernels {
// CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
// CHECK-SHFL-COUNT-2: gpu.shuffle xor
- %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(4) : (vector<2xf16>) -> (vector<2xf16>)
+ %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 4) : (vector<2xf16>) -> (vector<2xf16>)
"test.consume"(%sum0) : (vector<2xf16>) -> ()
// CHECK-SHFL: gpu.return
@@ -234,7 +257,7 @@ gpu.module @kernels {
// CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
// CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
// CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
- %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(16) : (i16) -> i16
+ %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 16) : (i16) -> i16
"test.consume"(%sum0) : (i16) -> ()
// CHECK-SHFL: gpu.return
@@ -268,7 +291,7 @@ gpu.module @kernels {
// CHECK-SHFL-SAME: %[[ARG0:.+]]: vector<3xi8>)
gpu.func @kernel_cluster_size_is_subgroup_size(%arg0: vector<3xi8>) kernel {
// CHECK-SHFL-COUNT-5: gpu.shuffle xor
- %sum0 = gpu.subgroup_reduce add %arg0 cluster_size(32) : (vector<3xi8>) -> (vector<3xi8>)
+ %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 32) : (vector<3xi8>) -> (vector<3xi8>)
"test.consume"(%sum0) : (vector<3xi8>) -> ()
// CHECK-SHFL: gpu.return
More information about the Mlir-commits
mailing list