[Mlir-commits] [mlir] [mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce (PR #109158)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 18 08:17:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
Author: Andrea Faulds (andfau-amd)
<details>
<summary>Changes</summary>
Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes controlling the lowering much more straightforward.
---
Full diff: https://github.com/llvm/llvm-project/pull/109158.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/Transforms/Passes.h (+10)
- (modified) mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp (+38-5)
- (modified) mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp (+4-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 67baa8777a6fcc..8eb711962583da 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
+///
+/// The patterns populated by this function will ignore ops with the
+/// `cluster_size` attribute.
+/// `populateGpuLowerClusteredSubgroupReduceToShufflePatterns` is the opposite.
void populateGpuLowerSubgroupReduceToShufflePatterns(
RewritePatternSet &patterns, unsigned subgroupSize,
unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
+/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
+void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+ RewritePatternSet &patterns, unsigned subgroupSize,
+ unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index b166f1cd469a4d..56e53a806843ed 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -210,13 +210,24 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
struct ScalarSubgroupReduceToShuffles final
: OpRewritePattern<gpu::SubgroupReduceOp> {
ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
- unsigned shuffleBitwidth,
+ unsigned shuffleBitwidth, bool matchClustered,
PatternBenefit benefit)
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
- shuffleBitwidth(shuffleBitwidth) {}
+ shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ if (op.getClusterSize().has_value() != matchClustered) {
+ if (matchClustered)
+ return rewriter.notifyMatchFailure(
+ op, "op is non-clustered but pattern is configured to only match "
+ "clustered ops");
+ else
+ return rewriter.notifyMatchFailure(
+ op, "op is clustered but pattern is configured to only match "
+ "non-clustered ops");
+ }
+
auto ci = getAndValidateClusterInfo(op, subgroupSize);
if (failed(ci))
return failure();
@@ -262,19 +273,31 @@ struct ScalarSubgroupReduceToShuffles final
private:
unsigned subgroupSize = 0;
unsigned shuffleBitwidth = 0;
+ bool matchClustered;
};
/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
: OpRewritePattern<gpu::SubgroupReduceOp> {
VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
- unsigned shuffleBitwidth,
+ unsigned shuffleBitwidth, bool matchClustered,
PatternBenefit benefit)
: OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
- shuffleBitwidth(shuffleBitwidth) {}
+ shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ if (op.getClusterSize().has_value() != matchClustered) {
+ if (matchClustered)
+ return rewriter.notifyMatchFailure(
+ op, "op is non-clustered but pattern is configured to only match "
+ "clustered ops");
+ else
+ return rewriter.notifyMatchFailure(
+ op, "op is clustered but pattern is configured to only match "
+ "non-clustered ops");
+ }
+
auto ci = getAndValidateClusterInfo(op, subgroupSize);
if (failed(ci))
return failure();
@@ -343,6 +366,7 @@ struct VectorSubgroupReduceToShuffles final
private:
unsigned subgroupSize = 0;
unsigned shuffleBitwidth = 0;
+ bool matchClustered;
};
} // namespace
@@ -358,5 +382,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
RewritePatternSet &patterns, unsigned subgroupSize,
unsigned shuffleBitwidth, PatternBenefit benefit) {
patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
- patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+ patterns.getContext(), subgroupSize, shuffleBitwidth,
+ /*matchClustered=*/false, benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+ RewritePatternSet &patterns, unsigned subgroupSize,
+ unsigned shuffleBitwidth, PatternBenefit benefit) {
+ patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+ patterns.getContext(), subgroupSize, shuffleBitwidth,
+ /*matchClustered=*/true, benefit);
}
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 99a914506b011a..74d057c0b7b6cb 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
populateGpuBreakDownSubgroupReducePatterns(patterns,
/*maxShuffleBitwidth=*/32,
PatternBenefit(2));
- if (expandToShuffles)
+ if (expandToShuffles) {
populateGpuLowerSubgroupReduceToShufflePatterns(
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+ populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+ patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+ }
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/109158
More information about the Mlir-commits
mailing list