[Mlir-commits] [mlir] [mlir][gpu] Add patterns to break down subgroup reduce (PR #76271)

Kunwar Grover llvmlistbot at llvm.org
Thu Dec 28 10:38:36 PST 2023


================
@@ -0,0 +1,139 @@
+//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements gradual lowering of `gpu.subgroup_reduce` ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace {
+
+/// Breaks a `gpu.subgroup_reduce` over a 1-D vector into several smaller
+/// `gpu.subgroup_reduce` ops. Each new op works on a sub-vector (or a single
+/// scalar) of at most `maxShuffleBitwidth` bits. The partial results are
+/// reassembled into a vector of the original type.
+///
+/// The pattern does not match in these cases:
+///   - the operand is a scalar or a single-element vector;
+///   - one element is already at least `maxShuffleBitwidth` bits wide;
+///   - the element type has zero width;
+///   - the whole vector already fits in one reduction.
+///
+/// Example, assuming `maxShuffleBitwidth` equal to 32:
+/// ```
+/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
+///  ==>
+/// %v0 = arith.constant dense<0.0> : vector<3xf16>
+/// %e0 = vector.extract_strided_slice %x
+///   {offsets = [0], sizes = [2], strides = [1]}
+///   : vector<3xf16> to vector<2xf16>
+/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
+/// %v1 = vector.insert_strided_slice %r0, %v0
+///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
+/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
+/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
+/// %a  = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
+/// ```
+struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
+  /// `maxShuffleBitwidth` is the largest total bitwidth allowed for the
+  /// operand of each emitted `gpu.subgroup_reduce`.
+  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
+                          PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
+  }
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy || vecTy.getNumElements() < 2)
+      return rewriter.notifyMatchFailure(op, "not a multireduction");
+
+    assert(vecTy.getRank() == 1 && "Unexpected vector type");
+    assert(!vecTy.isScalable() && "Unexpected vector type");
+
+    Type elemTy = vecTy.getElementType();
+    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
+    if (elemBitwidth >= maxShuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "large element type, nothing to break down");
+    // With a zero-width element type, the division below would divide by
+    // zero. There is nothing meaningful to pack, so bail out.
+    if (elemBitwidth == 0)
+      return rewriter.notifyMatchFailure(op, "zero-width element type");
+
+    // Maximum number of elements one emitted reduction may cover.
+    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
+    assert(elementsPerShuffle >= 1);
+
+    unsigned numNewReductions =
+        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
+    assert(numNewReductions >= 1);
+    if (numNewReductions == 1)
+      return rewriter.notifyMatchFailure(op, "nothing to break down");
+
+    Location loc = op.getLoc();
+    // Placeholder accumulator. The chunks below tile [0, numElements), so
+    // every element of this zero constant gets overwritten.
+    Value res =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
+
+    for (unsigned i = 0; i != numNewReductions; ++i) {
+      // Widen before multiplying so the offset is computed in 64 bits rather
+      // than in 32-bit `unsigned` arithmetic.
+      int64_t startIdx = static_cast<int64_t>(i) * elementsPerShuffle;
+      int64_t endIdx =
+          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
+      int64_t numElems = endIdx - startIdx;
+
+      // A one-element chunk is extracted as a scalar; wider chunks are
+      // extracted as sub-vectors.
+      Value extracted;
+      if (numElems == 1) {
+        extracted =
+            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
+      } else {
+        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
+            /*strides=*/1);
+      }
+
+      // Reuse the original op's reduction operation and `uniform` flag.
+      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
+          loc, extracted, op.getOp(), op.getUniform());
+      if (numElems == 1) {
+        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
+        continue;
+      }
+
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  /// Upper bound, in bits, on the operand of each emitted reduction.
+  unsigned maxShuffleBitwidth = 0;
+};
+
+struct ScalarizeSignleElementReduce final
----------------
Groverkss wrote:

nit: Single?

https://github.com/llvm/llvm-project/pull/76271


More information about the Mlir-commits mailing list