[Mlir-commits] [mlir] [mlir][gpu] Add patterns to break down subgroup reduce (PR #76271)
Kunwar Grover
llvmlistbot at llvm.org
Thu Dec 28 10:38:36 PST 2023
================
@@ -0,0 +1,139 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace {
+
+/// Breaks down a `gpu.subgroup_reduce` over a 1-D vector into several smaller
+/// `gpu.subgroup_reduce` ops, each operating on a slice of at most
+/// `maxShuffleBitwidth` bits. A slice with a single element is reduced as a
+/// scalar. The partial results are inserted back into a zero-initialized
+/// vector of the original type, and the `uniform` attribute is propagated to
+/// every new reduction.
+///
+/// Example, assuming `maxShuffleBitwidth` equal to 32:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+///  ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+///   {offsets = [0], sizes = [2], strides = [1]}: vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+///   {offsets = [0], strides = [1]}: vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+  /// `maxShuffleBitwidth` is the maximum number of bits each newly created
+  /// reduction may cover; `benefit` is forwarded to the base pattern.
+  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+                          PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+  }
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() < 2)
+      return rewriter.notifyMatchFailure(op, "not a multireduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+
+    Type elemTy = vecTy.getElementType();
+    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+    // Zero-width element types (e.g. `i0`) would make the computation of
+    // `elementsPerShuffle` below divide by zero; decline to rewrite them.
+    if (elemBitwidth == 0)
+      return rewriter.notifyMatchFailure(op, "zero-width element type");
+    if (elemBitwidth >= maxShuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "large element type, nothing to break down");
+
+    // Number of elements that fit into a single reduction of at most
+    // `maxShuffleBitwidth` bits.
+    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
+    assert(elementsPerShuffle >= 1);
+
+    int64_t numElements = vecTy.getNumElements();
+    unsigned numNewReductions =
+        llvm::divideCeil(numElements, elementsPerShuffle);
+    assert(numNewReductions >= 1);
+    // A single new reduction would just recreate the original op.
+    if (numNewReductions == 1)
+      return rewriter.notifyMatchFailure(op, "nothing to break down");
+
+    Location loc = op.getLoc();
+    // Accumulator that the partial results are inserted into; every element is
+    // overwritten by the loop below.
+    Value res =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+
+    for (unsigned i = 0; i != numNewReductions; ++i) {
+      // Compute in 64-bit to avoid any intermediate `unsigned` wrap-around.
+      int64_t startIdx = static_cast<int64_t>(i) * elementsPerShuffle;
+      int64_t endIdx =
+          std::min<int64_t>(startIdx + elementsPerShuffle, numElements);
+      int64_t numElems = endIdx - startIdx;
+
+      // A one-element slice is reduced as a scalar; larger slices stay vectors.
+      Value extracted;
+      if (numElems == 1) {
+        extracted =
+            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
+      } else {
+        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
+            /*strides=*/1);
+      }
+
+      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+          loc, extracted, op.getOp(), op.getUniform());
+      if (numElems == 1) {
+        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
+        continue;
+      }
+
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  /// Upper bound, in bits, on the size of each newly created reduction.
+  unsigned maxShuffleBitwidth = 0;
+};
+
+struct ScalarizeSignleElementReduce final
----------------
Groverkss wrote:
nit: typo — `Signle` should be `Single` (i.e. `ScalarizeSingleElementReduce`)?
https://github.com/llvm/llvm-project/pull/76271
More information about the Mlir-commits
mailing list