[Mlir-commits] [mlir] [mlir][gpu] Add patterns to break down subgroup reduce (PR #76271)
Kunwar Grover
llvmlistbot at llvm.org
Thu Dec 28 10:38:36 PST 2023
================
@@ -0,0 +1,139 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace {
+
+/// Example:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+/// ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+/// {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+/// {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+ BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+ PatternBenefit benefit)
+ : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+ }
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ auto vecTy = dyn_cast<VectorType>(op.getType());
+ if (!vecTy || vecTy.getNumElements() < 2)
+ return rewriter.notifyMatchFailure(op, "not a multireduction");
+
+ assert(vecTy.getRank() == 1 && "Unexpected vector type");
+ assert(!vecTy.isScalable() && "Unexpected vector type");
+
+ Type elemTy = vecTy.getElementType();
+ unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+ if (elemBitwidth >= maxShuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op, "large element type, nothing to break down");
----------------
Groverkss wrote:
Can we give a more descriptive error message, like "element type is larger than maxShuffleBitwidth, nothing to break down"?
https://github.com/llvm/llvm-project/pull/76271
More information about the Mlir-commits
mailing list