[Mlir-commits] [mlir] [mlir][gpu] Disjoint patterns for lowering clustered subgroup reduce (PR #109158)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 18 08:17:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Andrea Faulds (andfau-amd)

<details>
<summary>Changes</summary>

Making the existing populateGpuLowerSubgroupReduceToShufflePatterns() function also cover the new "clustered" subgroup reductions is proving to be inconvenient, because certain backends may have more specific lowerings that only cover the non-clustered type, and this creates pass ordering constraints. This commit removes coverage of clustered reductions from this function in favour of a new separate function, which makes controlling the lowering much more straightforward.

---
Full diff: https://github.com/llvm/llvm-project/pull/109158.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/GPU/Transforms/Passes.h (+10) 
- (modified) mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp (+38-5) 
- (modified) mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp (+4-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 67baa8777a6fcc..8eb711962583da 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -73,10 +73,20 @@ void populateGpuBreakDownSubgroupReducePatterns(
 /// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
 /// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
 /// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
+///
+/// The patterns populated by this function will ignore ops with the
+/// `cluster_size` attribute.
+/// `populateGpuLowerClusteredSubgroupReduceToShufflePatterns` is the opposite.
 void populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
 
+/// Disjoint counterpart of `populateGpuLowerSubgroupReduceToShufflePatterns`
+/// that only matches `gpu.subgroup_reduce` ops with a `cluster_size`.
+void populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index b166f1cd469a4d..56e53a806843ed 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -210,13 +210,24 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
 struct ScalarSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      if (matchClustered)
+        return rewriter.notifyMatchFailure(
+            op, "op is non-clustered but pattern is configured to only match "
+                "clustered ops");
+      else
+        return rewriter.notifyMatchFailure(
+            op, "op is clustered but pattern is configured to only match "
+                "non-clustered ops");
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -262,19 +273,31 @@ struct ScalarSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered;
 };
 
 /// Lowers vector gpu subgroup reductions to a series of shuffles.
 struct VectorSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
-                                 unsigned shuffleBitwidth,
+                                 unsigned shuffleBitwidth, bool matchClustered,
                                  PatternBenefit benefit)
       : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
-        shuffleBitwidth(shuffleBitwidth) {}
+        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      if (matchClustered)
+        return rewriter.notifyMatchFailure(
+            op, "op is non-clustered but pattern is configured to only match "
+                "clustered ops");
+      else
+        return rewriter.notifyMatchFailure(
+            op, "op is clustered but pattern is configured to only match "
+                "non-clustered ops");
+    }
+
     auto ci = getAndValidateClusterInfo(op, subgroupSize);
     if (failed(ci))
       return failure();
@@ -343,6 +366,7 @@ struct VectorSubgroupReduceToShuffles final
 private:
   unsigned subgroupSize = 0;
   unsigned shuffleBitwidth = 0;
+  bool matchClustered;
 };
 } // namespace
 
@@ -358,5 +382,14 @@ void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
     RewritePatternSet &patterns, unsigned subgroupSize,
     unsigned shuffleBitwidth, PatternBenefit benefit) {
   patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
-      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/false, benefit);
+}
+
+void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth, PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+      patterns.getContext(), subgroupSize, shuffleBitwidth,
+      /*matchClustered=*/true, benefit);
 }
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 99a914506b011a..74d057c0b7b6cb 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -78,9 +78,12 @@ struct TestGpuSubgroupReduceLoweringPass
     populateGpuBreakDownSubgroupReducePatterns(patterns,
                                                /*maxShuffleBitwidth=*/32,
                                                PatternBenefit(2));
-    if (expandToShuffles)
+    if (expandToShuffles) {
       populateGpuLowerSubgroupReduceToShufflePatterns(
           patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+      populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
+          patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+    }
 
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/109158


More information about the Mlir-commits mailing list