[Mlir-commits] [mlir] [mlir][gpu] Add patterns to break down subgroup reduce (PR #76271)

Jakub Kuderski llvmlistbot at llvm.org
Thu Dec 28 11:34:32 PST 2023


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/76271

From b6148e207468f304d674daa30fac418c1f3b75f1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 22 Dec 2023 17:03:03 -0500
Subject: [PATCH 1/4] [mlir][gpu] Add patterns to break down subgroup reduce

The new patterns break down `gpu.subgroup_reduce` ops with vector values into
a sequence of smaller subgroup reductions that fit within the native shuffle
bitwidth. The maximum/native shuffle bitwidth is parameterized.

The overall goal is to be able to perform multi-element reductions with
a sequence of `gpu.shuffle` ops.
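
To make the intended rewrite concrete, here is a sketch of what the
`BreakDownSubgroupReduce` pattern produces for a `vector<5xf16>` reduction with
the default 32-bit shuffle bitwidth (i.e., two f16 elements per shuffle). It
follows the pattern's doc comment and the new `@kernel0` test below; the SSA
value names and constant formatting are illustrative only:

```mlir
// Before: one multi-element reduction that does not fit in a 32-bit shuffle.
%a = gpu.subgroup_reduce add %x : (vector<5xf16>) -> vector<5xf16>

// ==> After BreakDownSubgroupReduce: two vector<2xf16> reductions plus a
// scalar f16 reduction, stitched back into the original vector<5xf16> shape.
%v0 = arith.constant dense<0.0> : vector<5xf16>
%e0 = vector.extract_strided_slice %x {offsets = [0], sizes = [2], strides = [1]}
        : vector<5xf16> to vector<2xf16>
%r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
%v1 = vector.insert_strided_slice %r0, %v0 {offsets = [0], strides = [1]}
        : vector<2xf16> into vector<5xf16>
%e1 = vector.extract_strided_slice %x {offsets = [2], sizes = [2], strides = [1]}
        : vector<5xf16> to vector<2xf16>
%r1 = gpu.subgroup_reduce add %e1 : (vector<2xf16>) -> vector<2xf16>
%v2 = vector.insert_strided_slice %r1, %v1 {offsets = [2], strides = [1]}
        : vector<2xf16> into vector<5xf16>
%e2 = vector.extract %x[4] : f16 from vector<5xf16>
%r2 = gpu.subgroup_reduce add %e2 : (f16) -> f16
%a  = vector.insert %r2, %v2 [4] : f16 into vector<5xf16>
```

Single-element vector reductions (e.g. `vector<1xf32>`) are handled by a
separate pattern that scalarizes to a plain element reduction followed by a
`vector.broadcast`, as exercised by `@kernel1` in the new test file.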
---
 .../mlir/Dialect/GPU/Transforms/Passes.h      |   7 +
 mlir/lib/Dialect/GPU/CMakeLists.txt           |   5 +-
 .../GPU/Transforms/SubgroupReduceLowering.cpp | 139 ++++++++++++++++++
 .../Dialect/GPU/subgroup-redule-lowering.mlir |  71 +++++++++
 mlir/test/lib/Dialect/GPU/CMakeLists.txt      |   1 +
 mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp  |  28 +++-
 mlir/tools/mlir-opt/mlir-opt.cpp              |   4 +-
 7 files changed, 250 insertions(+), 5 deletions(-)
 create mode 100644 mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
 create mode 100644 mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index c6c02ccaafbcf4..b905ef2e02aee0 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -15,6 +15,7 @@
 
 #include "Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
 
@@ -62,6 +63,12 @@ void populateGpuShufflePatterns(RewritePatternSet &patterns);
 /// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
 void populateGpuAllReducePatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to break down subgroup_reduce ops into smaller
+/// ones supported by the target of size <= `maxShuffleBitwidth`.
+void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns,
+                                               unsigned maxShuffleBitwidth = 32,
+                                               PatternBenefit benefit = 1);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index ab6834cb262fb5..8383e06e6d2478 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -50,19 +50,20 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/AsyncRegionRewriter.cpp
   Transforms/BufferDeallocationOpInterfaceImpl.cpp
   Transforms/DecomposeMemrefs.cpp
+  Transforms/EliminateBarriers.cpp
   Transforms/GlobalIdRewriter.cpp
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ModuleToBinary.cpp
   Transforms/NVVMAttachTarget.cpp
   Transforms/ParallelLoopMapper.cpp
+  Transforms/ROCDLAttachTarget.cpp
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
-  Transforms/ROCDLAttachTarget.cpp
-  Transforms/EliminateBarriers.cpp
+  Transforms/SubgroupReduceLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
new file mode 100644
index 00000000000000..07700cfa3c2a22
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -0,0 +1,139 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace {
+
+/// Example:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+///  ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a  = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+                          PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+  }
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() < 2)
+      return rewriter.notifyMatchFailure(op, "not a multireduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+
+    Type elemTy = vecTy.getElementType();
+    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+    if (elemBitwidth >= maxShuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "large element type, nothing to break down");
+
+    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
+    assert(elementsPerShuffle >= 1);
+
+    unsigned numNewReductions =
+        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
+    assert(numNewReductions >= 1);
+    if (numNewReductions == 1)
+      return rewriter.notifyMatchFailure(op, "nothing to break down");
+
+    Location loc = op.getLoc();
+    Value res =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+
+    for (unsigned i = 0; i != numNewReductions; ++i) {
+      int64_t startIdx = i * elementsPerShuffle;
+      int64_t endIdx =
+          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
+      int64_t numElems = endIdx - startIdx;
+
+      Value extracted;
+      if (numElems == 1) {
+        extracted =
+            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
+      } else {
+        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
+            /*strides=*/1);
+      }
+
+      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+          loc, extracted, op.getOp(), op.getUniform());
+      if (numElems == 1) {
+        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
+        continue;
+      }
+
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+  private:
+  unsigned maxShuffleBitwidth = 0;
+};
+
+struct ScalarizeSignleElementReduce final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(op, "not a single-element reduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+    Location loc = op.getLoc();
+    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
+    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+        loc, extracted, op.getOp(), op.getUniform());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuBreakDownSubgrupReducePatterns(
+    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
+    PatternBenefit benefit) {
+  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
+                                        maxShuffleBitwidth, benefit);
+  patterns.add<ScalarizeSignleElementReduce>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
new file mode 100644
index 00000000000000..b7146071bf2fd8
--- /dev/null
+++ b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+
+// CHECK: gpu.module @kernels {
+gpu.module @kernels {
+
+  // CHECK-LABEL: gpu.func @kernel0(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+  gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
+    // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+    // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+    // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+    // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+    // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
+    "test.consume"(%sum0) : (vector<5xf16>) -> ()
+
+
+    // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+    // CHECK: "test.consume"
+    %sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
+    "test.consume"(%sum1) : (vector<5xf16>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+  // CHECK-LABEL: gpu.func @kernel1(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+  gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
+    // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+    // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+    // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
+    "test.consume"(%sum0) : (vector<1xf32>) -> ()
+
+    // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+    // CHECK: "test.consume"
+    %sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
+    "test.consume"(%sum1) : (vector<1xf32>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+  // These vectors fit the native shuffle size and should not be broken down.
+  //
+  // CHECK-LABEL: gpu.func @kernel2(
+  // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+  gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
+    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
+    // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
+    "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
+    // CHECK: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
+    %sum1 = gpu.subgroup_reduce add %arg1 : (vector<4xi8>) -> (vector<4xi8>)
+    "test.consume"(%sum1) : (vector<4xi8>) -> ()
+
+    // CHECK: gpu.return
+    gpu.return
+  }
+
+}
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index aa94bce275eafb..48cbc4ad5505b0 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -27,6 +27,7 @@ set(LIBS
   MLIRTransforms
   MLIRTransformUtils
   MLIRTranslateLib
+  MLIRVectorDialect
   MLIRVectorToLLVMPass
   )
 
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index db65f3bccec52d..4e8f0cc6667524 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -39,10 +40,35 @@ struct TestGpuRewritePass
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
+
+struct TestGpuSubgroupReduceLoweringPass
+    : public PassWrapper<TestGpuSubgroupReduceLoweringPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGpuSubgroupReduceLoweringPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect,
+                    memref::MemRefDialect, vector::VectorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-gpu-subgroup-reduce-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Applies gpu.subgroup_reduce lowering patterns.";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateGpuBreakDownSubgrupReducePatterns(patterns,
+                                              /*maxShuffleBitwidth=*/32);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
-void registerTestAllReduceLoweringPass() {
+void registerTestGpuLoweringPasses() {
   PassRegistration<TestGpuRewritePass>();
+  PassRegistration<TestGpuSubgroupReduceLoweringPass>();
 }
 } // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index eedade691c6c39..dc4121dc46bb9b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -47,7 +47,7 @@ void registerTestAffineReifyValueBoundsPass();
 void registerTestBytecodeRoundtripPasses();
 void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
-void registerTestAllReduceLoweringPass();
+void registerTestGpuLoweringPasses();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestLoopPermutationPass();
@@ -167,7 +167,7 @@ void registerTestPasses() {
   registerTestAffineReifyValueBoundsPass();
   registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
-  registerTestAllReduceLoweringPass();
+  registerTestGpuLoweringPasses();
   registerTestBytecodeRoundtripPasses();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();

From 93aa0a13128ccc094be13a34eb4e6a4aec755091 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 22 Dec 2023 17:27:34 -0500
Subject: [PATCH 2/4] Format

---
 mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 07700cfa3c2a22..de17a05eb15a43 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -103,7 +103,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
     return success();
   }
 
-  private:
+private:
   unsigned maxShuffleBitwidth = 0;
 };
 

From ae8478190b948ffb650a1202c180ecc657d13add Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 28 Dec 2023 14:29:37 -0500
Subject: [PATCH 3/4] Improve documentation

---
 .../mlir/Dialect/GPU/Transforms/Passes.h      |  3 ++-
 .../GPU/Transforms/SubgroupReduceLowering.cpp | 21 ++++++++++++++-----
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index b905ef2e02aee0..6c5bf75d212478 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -64,7 +64,8 @@ void populateGpuShufflePatterns(RewritePatternSet &patterns);
 void populateGpuAllReducePatterns(RewritePatternSet &patterns);
 
 /// Collect a set of patterns to break down subgroup_reduce ops into smaller
-/// ones supported by the target of size <= `maxShuffleBitwidth`.
+/// ones supported by the target of `size <= maxShuffleBitwidth`, where `size`
+/// is the subgroup_reduce value bitwidth.
 void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns,
                                                unsigned maxShuffleBitwidth = 32,
                                                PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index de17a05eb15a43..61edce5e2a0862 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
 
@@ -24,7 +25,7 @@ using namespace mlir;
 
 namespace {
 
-/// Example:
+/// Example (assuming `maxShuffleBitwidth` equal to 32):
 /// ```
 /// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
 ///  ==>
@@ -48,7 +49,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
                                 PatternRewriter &rewriter) const override {
     auto vecTy = dyn_cast<VectorType>(op.getType());
     if (!vecTy || vecTy.getNumElements() < 2)
-      return rewriter.notifyMatchFailure(op, "not a multireduction");
+      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
 
     assert(vecTy.getRank() == 1 && "Unexpected vector type");
     assert(!vecTy.isScalable() && "Unexpected vector type");
@@ -57,7 +58,9 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
     unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
     if (elemBitwidth >= maxShuffleBitwidth)
       return rewriter.notifyMatchFailure(
-          op, "large element type, nothing to break down");
+          op, llvm::formatv("element type too large {0}, cannot break down "
+                            "into vectors of bitwidth {1} or less",
+                            elemBitwidth, maxShuffleBitwidth));
 
     unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
     assert(elementsPerShuffle >= 1);
@@ -107,7 +110,15 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
   unsigned maxShuffleBitwidth = 0;
 };
 
-struct ScalarizeSignleElementReduce final
+/// Example:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
+///  ==>
+/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
+/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
+/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
+/// ```
+struct ScalarizeSingleElementReduce final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -135,5 +146,5 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns(
     PatternBenefit benefit) {
   patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                         maxShuffleBitwidth, benefit);
-  patterns.add<ScalarizeSignleElementReduce>(patterns.getContext(), benefit);
+  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
 }

From 63fe83c9ea75091330524ad7be35d5cc43a97cc1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 28 Dec 2023 14:31:21 -0500
Subject: [PATCH 4/4] Remove unnecessary dependent dialects from the pass

---
 mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 4e8f0cc6667524..21cc89c0d89b0f 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -48,8 +48,7 @@ struct TestGpuSubgroupReduceLoweringPass
       TestGpuSubgroupReduceLoweringPass)
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect,
-                    memref::MemRefDialect, vector::VectorDialect>();
+    registry.insert<arith::ArithDialect, vector::VectorDialect>();
   }
   StringRef getArgument() const final {
     return "test-gpu-subgroup-reduce-lowering";



More information about the Mlir-commits mailing list