[Mlir-commits] [mlir] [mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce (PR #109158)

Andrea Faulds llvmlistbot at llvm.org
Wed Sep 18 08:17:11 PDT 2024


https://github.com/andfau-amd created https://github.com/llvm/llvm-project/pull/109158

Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes controlling the lowering much more straightforward.

>From 3e1a9fe455d2c11c9fc6b9aabc66a5a70bb7f1c6 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Wed, 18 Sep 2024 17:14:46 +0200
Subject: [PATCH] [mlir][gpu] Disjoint patterns for lowering clustered subgroup
 reduce

Making the existing populateGpuLowerSubgroupReduceToShufflePatterns()
function also cover the new "clustered" subgroup reductions is proving
to be inconvenient, because certain backends may have more specific
lowerings that only cover the non-clustered type, and this creates pass
ordering constraints. This commit removes coverage of clustered
reductions from this function in favour of a new separate function,
which makes controlling the lowering much more straightforward.
---
 .../mlir/Dialect/GPU/Transforms/Passes.h      | 10 +++++
 .../GPU/Transforms/SubgroupReduceLowering.cpp | 43 ++++++++++++++++---
 mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp  |  5 ++-
 3 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 67baa8777a6fcc..8eb711962583da 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
 /// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
 /// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
 /// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
+///
+/// The patterns populated by this function ignore ops that have the
+/// `cluster_size` attribute; use
+/// `populateGpuLowerClusteredSubgroupReduceToShufflePatterns` for those.
 void populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
 
+/// Same lowering as `populateGpuLowerSubgroupReduceToShufflePatterns`, but
+/// only matches `gpu.subgroup_reduce` ops that have a `cluster_size`.
+void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index b166f1cd469a4d..56e53a806843ed 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -210,13 +210,24 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
 struct ScalarSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      // Reject the other kind of op so the two pattern sets stay disjoint.
+      if (matchClustered)
+        return rewriter.notifyMatchFailure(
+            op, "op is non-clustered but pattern is configured to only match "
+                "clustered ops");
+      return rewriter.notifyMatchFailure(
+          op, "op is clustered but pattern is configured to only match "
+              "non-clustered ops");
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -262,19 +273,31 @@ struct ScalarSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered = false; // If true, match only clustered ops.
 };
 
 /// Lowers vector gpu subgroup reductions to a series of shuffles.
 struct VectorSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      // Reject the other kind of op so the two pattern sets stay disjoint.
+      if (matchClustered)
+        return rewriter.notifyMatchFailure(
+            op, "op is non-clustered but pattern is configured to only match "
+                "clustered ops");
+      return rewriter.notifyMatchFailure(
+          op, "op is clustered but pattern is configured to only match "
+              "non-clustered ops");
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -343,6 +366,7 @@ struct VectorSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered = false; // If true, match only clustered ops.
 };
 } // namespace
 
@@ -358,5 +382,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth, PatternBenefit benefit) {
   patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
-      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/false, benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth, PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+      ctx, subgroupSize, shuffleBitwidth, /*matchClustered=*/true, benefit);
 }
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 99a914506b011a..74d057c0b7b6cb 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
     populateGpuBreakDownSubgroupReducePatterns(patterns,
                                                /*maxShuffleBitwidth=*/32,
                                                PatternBenefit(2));
-    if (expandToShuffles)
+    if (expandToShuffles) {
       populateGpuLowerSubgroupReduceToShufflePatterns(
           patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+      populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+          patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+    }
 
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }



More information about the Mlir-commits mailing list