[Mlir-commits] [mlir] [mlir][gpu] Add patterns to break down subgroup reduce (PR #76271)

Jakub Kuderski llvmlistbot at llvm.org
Thu Dec 28 11:14:06 PST 2023

@@ -0,0 +1,139 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+using namespace mlir;
+namespace {
+/// Example, assuming `maxShuffleBitwidth` is 32 (so two f16 elements fit in a
+/// single shuffle):
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+///  ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+///   {offsets = [0], sizes = [2], strides = [1]}: vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+///   {offsets = [0], strides = [1]}: vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a  = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+                          PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+  }
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() < 2)
+      return rewriter.notifyMatchFailure(op, "not a multireduction");
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+    Type elemTy = vecTy.getElementType();
+    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+    if (elemBitwidth >= maxShuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "large element type, nothing to break down");
+    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
+    assert(elementsPerShuffle >= 1);
+    unsigned numNewReductions =
+        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
+    assert(numNewReductions >= 1);
+    if (numNewReductions == 1)
+      return rewriter.notifyMatchFailure(op, "nothing to break down");
+    Location loc = op.getLoc();
+    Value res =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+    for (unsigned i = 0; i != numNewReductions; ++i) {
+      int64_t startIdx = i * elementsPerShuffle;
+      int64_t endIdx =
+          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
+      int64_t numElems = endIdx - startIdx;
+      Value extracted;
+      if (numElems == 1) {
+        extracted =
+            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
+      } else {
+        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
+            /*strides=*/1);
+      }
+      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+          loc, extracted, op.getOp(), op.getUniform());
+      if (numElems == 1) {
+        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
+        continue;
+      }
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
kuhar wrote:

The LLVM coding style prefers early exits: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code. That said, it doesn't make much difference here, IMO.


More information about the Mlir-commits mailing list