[Mlir-commits] [mlir] [mlir][gpu] Add patterns to break down subgroup reduce (PR #76271)

Kunwar Grover llvmlistbot at llvm.org
Thu Dec 28 10:38:36 PST 2023


================
@@ -0,0 +1,139 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace {
+
+/// Example:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+///  ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+///   {offsets = [0], sizes = [2], strides = [1]}
+///   : vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+///   {offsets = [0], strides = [1]}: vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a  = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+                          PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+  }
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() < 2)
+      return rewriter.notifyMatchFailure(op, "not a multireduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+
+    Type elemTy = vecTy.getElementType();
+    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+    if (elemBitwidth >= maxShuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "large element type, nothing to break down");
----------------
Groverkss wrote:

Can we give a more descriptive error message, like "element type is larger than maxShuffleBitwidth, nothing to break down"?

https://github.com/llvm/llvm-project/pull/76271


More information about the Mlir-commits mailing list