[Mlir-commits] [mlir] [mlir][gpu] Add 'cluster_stride' attribute to gpu.subgroup_reduce (PR #107142)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Sep 3 14:16:07 PDT 2024
================
@@ -140,44 +137,75 @@ struct ScalarizeSingleElementReduce final
Location loc = op.getLoc();
Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
- loc, extracted, op.getOp(), op.getUniform(), clusterSize);
+ loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
+ op.getClusterStride());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
return success();
}
};
-/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
-/// and `unpackFn` to convert to the native shuffle type and to the reduction
-/// type, respectively. For example, with `input` of type `f16`, `packFn` could
-/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
-/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
-/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
-/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
-static Value createSubgroupShuffleReduction(
- OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
- unsigned clusterSize, unsigned subgroupSize,
- function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
- assert(llvm::isPowerOf2_32(clusterSize));
- assert(llvm::isPowerOf2_32(subgroupSize));
- assert(clusterSize <= subgroupSize);
- // Lane value always stays in the original type. We use it to perform arith
- // reductions.
- Value laneVal = input;
- // Parallel reduction using butterfly shuffles.
- for (unsigned i = 1; i < clusterSize; i <<= 1) {
- Value shuffled = builder
- .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
- /*width=*/subgroupSize,
- /*mode=*/gpu::ShuffleMode::XOR)
- .getShuffleResult();
- laneVal = vector::makeArithReduction(builder, loc,
- gpu::convertReductionKind(mode),
- laneVal, unpackFn(shuffled));
- assert(laneVal.getType() == input.getType());
+struct ClusterInfo {
+ unsigned clusterStride;
+ unsigned clusterSize;
+ unsigned subgroupSize;
+ LogicalResult getAndValidate(gpu::SubgroupReduceOp op,
+ unsigned subgroupSize) {
+ this->subgroupSize = subgroupSize;
+
+ std::optional<uint32_t> clusterSize = op.getClusterSize();
+ if (clusterSize && *clusterSize > subgroupSize)
+ return op.emitOpError()
+ << "cluster size " << *clusterSize
+ << " is greater than subgroup size " << subgroupSize;
+ this->clusterSize = clusterSize.value_or(subgroupSize); // effective size
+
+ clusterStride = op.getClusterStride();
+ if (clusterStride >= subgroupSize)
+ return op.emitOpError()
+ << "cluster stride " << clusterStride
+ << " is not less than subgroup size " << subgroupSize;
+
+ return success();
}
+ /// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
+ /// and `unpackFn` to convert to the native shuffle type and to the reduction
+ /// type, respectively. For example, with `input` of type `f16`, `packFn`
+ /// could build ops to cast the value to `i32` to perform shuffles, while
+ /// `unpackFn` would cast it back to `f16` to perform arithmetic reduction on.
+ /// Assumes that the subgroup is `subgroupSize` lanes wide and divides it into
+ /// clusters of `clusterSize` lanes starting at lane 0 with a stride of
+ /// `clusterStride` for lanes within a cluster, reducing all lanes in each
+ /// cluster in parallel.
+ Value
+ createSubgroupShuffleReduction(OpBuilder &builder, Location loc, Value input,
+ gpu::AllReduceOperation mode,
+ function_ref<Value(Value)> packFn,
+ function_ref<Value(Value)> unpackFn) const {
+ assert(llvm::isPowerOf2_32(clusterStride));
+ assert(llvm::isPowerOf2_32(clusterSize));
+ assert(llvm::isPowerOf2_32(subgroupSize));
+ assert(clusterStride < subgroupSize);
+ assert(clusterSize <= subgroupSize);
----------------
kuhar wrote:
I think it would be better to check these invariants (the asserts above) in the constructor.
https://github.com/llvm/llvm-project/pull/107142
More information about the Mlir-commits mailing list