[Mlir-commits] [mlir] c0345b4 - [mlir][gpu] Add subgroup_reduce to shuffle lowering (#76530)

llvmlistbot at llvm.org
Tue Jan 2 13:14:27 PST 2024


Author: Jakub Kuderski
Date: 2024-01-02T16:14:22-05:00
New Revision: c0345b4648608b44071bbb6b012b24bd07c3d29e

URL: https://github.com/llvm/llvm-project/commit/c0345b4648608b44071bbb6b012b24bd07c3d29e
DIFF: https://github.com/llvm/llvm-project/commit/c0345b4648608b44071bbb6b012b24bd07c3d29e.diff

LOG: [mlir][gpu] Add subgroup_reduce to shuffle lowering (#76530)

This supports both the scalar and the vector multi-reduction cases.

Added: 
    mlir/lib/Dialect/GPU/Transforms/Utils.cpp

Modified: 
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
    mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
    mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
    mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
    mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 6c5bf75d212478..5885facd07541e 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -70,6 +70,13 @@ void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns,
                                                unsigned maxShuffleBitwidth = 32,
                                                PatternBenefit benefit = 1);
 
+/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
+/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
+/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
+void populateGpuLowerSubgroupReduceToShufflePattenrs(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+
 /// Collect all patterns to rewrite ops within the GPU dialect.
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
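
The new entry point is used like any other pattern-population helper. A minimal
sketch, assuming a downstream pass that already pulls in Passes.h and the
greedy rewrite driver (the wrapper function below is illustrative, not part of
this patch):

  // Illustrative only; mirrors the test pass added later in this patch.
  using namespace mlir;

  static LogicalResult lowerSubgroupReducesToShuffles(Operation *root) {
    RewritePatternSet patterns(root->getContext());
    // Same parameters as the test pass: 32-wide subgroups, 32-bit shuffles.
    populateGpuLowerSubgroupReduceToShufflePattenrs(
        patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
    return applyPatternsAndFoldGreedily(root, std::move(patterns));
  }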

diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
index a426bee7686dbc..f25c506fd638d8 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
@@ -13,6 +13,8 @@
 #ifndef MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
 #define MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
 
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Support/LLVM.h"
 
 #include <string>
@@ -28,6 +30,9 @@ class LaunchOp;
 
 /// Returns the default annotation name for GPU binary blobs.
 std::string getDefaultGpuBinaryAnnotation();
+
+/// Returns the matching vector combining kind.
+vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode);
 } // namespace gpu
 
 /// Get a gpu.func created from outlining the region of a gpu.launch op with the

diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 8383e06e6d2478..8f289ce9452e80 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -64,6 +64,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
   Transforms/SubgroupReduceLowering.cpp
+  Transforms/Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU

diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 608d801ee9bbbe..a75598afe8c72d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -27,33 +27,6 @@ using namespace mlir;
 
 namespace {
 
-static vector::CombiningKind
-convertReductionKind(gpu::AllReduceOperation mode) {
-  switch (mode) {
-#define MAP_CASE(X)                                                            \
-  case gpu::AllReduceOperation::X:                                             \
-    return vector::CombiningKind::X
-
-    MAP_CASE(ADD);
-    MAP_CASE(MUL);
-    MAP_CASE(MINUI);
-    MAP_CASE(MINSI);
-    MAP_CASE(MINNUMF);
-    MAP_CASE(MAXSI);
-    MAP_CASE(MAXUI);
-    MAP_CASE(MAXNUMF);
-    MAP_CASE(AND);
-    MAP_CASE(OR);
-    MAP_CASE(XOR);
-    MAP_CASE(MINIMUMF);
-    MAP_CASE(MAXIMUMF);
-
-#undef MAP_CASE
-  }
-
-  llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
-}
-
 struct GpuAllReduceRewriter {
   using AccumulatorFactory = std::function<Value(Value, Value)>;
 

diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 61edce5e2a0862..b00c65c4c62483 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -13,13 +13,17 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
+#include <cstdint>
 
 using namespace mlir;
 
@@ -58,7 +62,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
     unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
     if (elemBitwidth >= maxShuffleBitwidth)
       return rewriter.notifyMatchFailure(
-          op, llvm::formatv("element type too large {0}, cannot break down "
+          op, llvm::formatv("element type too large ({0}), cannot break down "
                             "into vectors of bitwidth {1} or less",
                             elemBitwidth, maxShuffleBitwidth));
 
@@ -139,6 +143,167 @@ struct ScalarizeSingleElementReduce final
   }
 };
 
+/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
+/// and `unpackFn` to convert to the native shuffle type and to the reduction
+/// type, respectively. For example, with `input` of type `f16`, `packFn` could
+/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
+/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
+/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
+static Value createSubgroupShuffleReduction(
+    OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
+    unsigned subgroupSize, function_ref<Value(Value)> packFn,
+    function_ref<Value(Value)> unpackFn) {
+  assert(llvm::isPowerOf2_32(subgroupSize));
+  // Lane value always stays in the original type. We use it to perform arith
+  // reductions.
+  Value laneVal = input;
+  // Parallel reduction using butterfly shuffles.
+  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+    Value shuffled = builder
+                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
+                                                 /*width=*/subgroupSize,
+                                                 /*mode=*/gpu::ShuffleMode::XOR)
+                         .getShuffleResult();
+    laneVal = vector::makeArithReduction(builder, loc,
+                                         gpu::convertReductionKind(mode),
+                                         laneVal, unpackFn(shuffled));
+    assert(laneVal.getType() == input.getType());
+  }
+
+  return laneVal;
+}
+
+/// Lowers scalar gpu subgroup reductions to a series of shuffles.
+struct ScalarSubgroupReduceToShuffles final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+                                 unsigned shuffleBitwidth,
+                                 PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        shuffleBitwidth(shuffleBitwidth) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    Type valueTy = op.getType();
+    unsigned elemBitwidth =
+        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
+    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "value type is not a compatible scalar");
+
+    Location loc = op.getLoc();
+    // Since this is already a native shuffle scalar, no packing is necessary.
+    if (elemBitwidth == shuffleBitwidth) {
+      auto identityFn = [](Value v) { return v; };
+      rewriter.replaceOp(op, createSubgroupShuffleReduction(
+                                 rewriter, loc, op.getValue(), op.getOp(),
+                                 subgroupSize, identityFn, identityFn));
+      return success();
+    }
+
+    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
+    auto packFn = [loc, &rewriter, equivIntType,
+                   shuffleIntType](Value unpackedVal) -> Value {
+      auto asInt =
+          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
+      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
+    };
+    auto unpackFn = [loc, &rewriter, equivIntType,
+                     valueTy](Value packedVal) -> Value {
+      auto asInt =
+          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
+      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
+    };
+
+    rewriter.replaceOp(op, createSubgroupShuffleReduction(
+                               rewriter, loc, op.getValue(), op.getOp(),
+                               subgroupSize, packFn, unpackFn));
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  unsigned shuffleBitwidth = 0;
+};
+
+/// Lowers vector gpu subgroup reductions to a series of shuffles.
+struct VectorSubgroupReduceToShuffles final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+                                 unsigned shuffleBitwidth,
+                                 PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        shuffleBitwidth(shuffleBitwidth) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto vecTy = dyn_cast<VectorType>(op.getType());
+    if (!vecTy)
+      return rewriter.notifyMatchFailure(op, "value type is not a vector");
+
+    unsigned vecBitwidth =
+        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
+    if (vecBitwidth > shuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op,
+          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
+                        "to shuffles of size {1}",
+                        vecBitwidth, shuffleBitwidth));
+
+    unsigned elementsPerShuffle =
+        shuffleBitwidth / vecTy.getElementTypeBitWidth();
+    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
+      return rewriter.notifyMatchFailure(
+          op, "shuffle bitwidth is not a multiple of the element bitwidth");
+
+    Location loc = op.getLoc();
+
+    // If the reduced type is smaller than the native shuffle size, extend it,
+    // perform the shuffles, and extract at the end.
+    auto extendedVecTy = VectorType::get(
+        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
+    Value extendedInput = op.getValue();
+    if (vecBitwidth < shuffleBitwidth) {
+      auto zero = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(extendedVecTy));
+      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
+    }
+
+    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+    auto shuffleVecType = VectorType::get(1, shuffleIntType);
+
+    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
+      auto asIntVec =
+          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
+      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
+    };
+    auto unpackFn = [loc, &rewriter, shuffleVecType,
+                     extendedVecTy](Value packedVal) -> Value {
+      auto asIntVec =
+          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
+      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
+    };
+
+    Value res =
+        createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
+                                       subgroupSize, packFn, unpackFn);
+
+    if (vecBitwidth < shuffleBitwidth) {
+      res = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
+          /*strides=*/1);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  unsigned shuffleBitwidth = 0;
+};
 } // namespace
 
 void mlir::populateGpuBreakDownSubgrupReducePatterns(
@@ -148,3 +313,10 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns(
                                         maxShuffleBitwidth, benefit);
   patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
 }
+
+void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs(
+    RewritePatternSet &patterns, unsigned subgroupSize,
+    unsigned shuffleBitwidth, PatternBenefit benefit) {
+  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+}
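
To make the butterfly schedule emitted by createSubgroupShuffleReduction above
concrete: at step i, every lane combines its value with the value held by lane
(lane ^ i), and after log2(subgroupSize) steps each lane holds the reduction of
the whole subgroup. A host-side sketch of the same dataflow for an integer add
(illustration only, not part of the patch):

  #include <cassert>
  #include <cstddef>
  #include <vector>

  // One element per lane; gpu.shuffle xor becomes an explicit gather from
  // lane (lane ^ i), and makeArithReduction becomes a plain +=.
  static std::vector<int> butterflyAllReduce(std::vector<int> lanes) {
    size_t subgroupSize = lanes.size();
    assert((subgroupSize & (subgroupSize - 1)) == 0 && "power-of-two subgroup");
    for (size_t i = 1; i < subgroupSize; i <<= 1) {
      std::vector<int> shuffled(subgroupSize);
      for (size_t lane = 0; lane < subgroupSize; ++lane)
        shuffled[lane] = lanes[lane ^ i]; // gpu.shuffle xor
      for (size_t lane = 0; lane < subgroupSize; ++lane)
        lanes[lane] += shuffled[lane];    // vector::makeArithReduction (add)
    }
    return lanes; // every lane now holds the full sum
  }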

diff --git a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
new file mode 100644
index 00000000000000..e91aa18128c7b9
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
@@ -0,0 +1,44 @@
+//===- Utils.cpp - GPU transforms utils -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements GPU dialect transforms utils.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir::gpu {
+
+vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode) {
+  switch (mode) {
+#define MAP_CASE(X)                                                            \
+  case gpu::AllReduceOperation::X:                                             \
+    return vector::CombiningKind::X
+
+    MAP_CASE(ADD);
+    MAP_CASE(MUL);
+    MAP_CASE(MINUI);
+    MAP_CASE(MINSI);
+    MAP_CASE(MINNUMF);
+    MAP_CASE(MAXSI);
+    MAP_CASE(MAXUI);
+    MAP_CASE(MAXNUMF);
+    MAP_CASE(AND);
+    MAP_CASE(OR);
+    MAP_CASE(XOR);
+    MAP_CASE(MINIMUMF);
+    MAP_CASE(MAXIMUMF);
+
+#undef MAP_CASE
+  }
+
+  llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
+}
+
+} // namespace mlir::gpu
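
As a quick illustration of the new gpu::convertReductionKind helper (not part
of the patch), the mapping is name-for-name between the two enums:

  // Illustration only: each gpu::AllReduceOperation maps to the identically
  // named vector::CombiningKind.
  vector::CombiningKind kind =
      gpu::convertReductionKind(gpu::AllReduceOperation::ADD);
  assert(kind == vector::CombiningKind::ADD);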

diff --git a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
index b7146071bf2fd8..f04a01ffe75d3c 100644
--- a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
@@ -1,71 +1,191 @@
-// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt  --allow-unregistered-dialect \
+// RUN:   --test-gpu-subgroup-reduce-lowering %s \
+// RUN:   | FileCheck %s --check-prefix=CHECK-SUB
 
-// CHECK: gpu.module @kernels {
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN:   --test-gpu-subgroup-reduce-lowering="expand-to-shuffles" %s \
+// RUN:   | FileCheck %s --check-prefix=CHECK-SHFL
+
+// CHECK-SUB:  gpu.module @kernels {
+// CHECK-SHFL: gpu.module @kernels {
 gpu.module @kernels {
 
-  // CHECK-LABEL: gpu.func @kernel0(
-  // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+  // CHECK-SUB-LABEL:  gpu.func @kernel0(
+  // CHECK-SUB-SAME:     %[[ARG0:.+]]: vector<5xf16>)
+  //
+  // CHECK-SHFL-LABEL: gpu.func @kernel0(
   gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
-    // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
-    // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
-    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
-    // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
-    // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
-    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
-    // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
-    // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
-    // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
-    // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
-    // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+    // CHECK-SUB: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+    // CHECK-SUB: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK-SUB: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK-SUB: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+    // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+    // CHECK-SUB: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+    // CHECK-SUB: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+    // CHECK-SUB: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+    // CHECK-SUB: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+    // CHECK-SUB: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum0) : (vector<5xf16>) -> ()
 
-
-    // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
-    // CHECK: "test.consume"
+    // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+    // CHECK-SUB: "test.consume"
     %sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum1) : (vector<5xf16>) -> ()
 
-    // CHECK: gpu.return
+    // CHECK-SUB: gpu.return
     gpu.return
   }
 
-  // CHECK-LABEL: gpu.func @kernel1(
-  // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+  // CHECK-SUB-LABEL:  gpu.func @kernel1(
+  // CHECK-SUB-SAME:     %[[ARG0:.+]]: vector<1xf32>)
+  //
+  // CHECK-SHFL-LABEL: gpu.func @kernel1(
   gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
-    // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
-    // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
-    // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+    // CHECK-SUB: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+    // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+    // CHECK-SUB: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+    // CHECK-SUB: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum0) : (vector<1xf32>) -> ()
 
-    // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
-    // CHECK: "test.consume"
+    // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+    // CHECK-SUB: "test.consume"
     %sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum1) : (vector<1xf32>) -> ()
 
-    // CHECK: gpu.return
+    // CHECK-SUB: gpu.return
     gpu.return
   }
 
   // These vectors fit the native shuffle size and should not be broken down.
   //
-  // CHECK-LABEL: gpu.func @kernel2(
-  // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+  // CHECK-SUB-LABEL:  gpu.func @kernel2(
+  // CHECK-SUB-SAME:     %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+  //
+  // CHECK-SHFL-LABEL: gpu.func @kernel2(
   gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
-    // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
-    // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+    // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
+    // CHECK-SUB: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
     "test.consume"(%sum0) : (vector<3xi8>) -> ()
 
-    // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
-    // CHECK: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
+    // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
+    // CHECK-SUB: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
     %sum1 = gpu.subgroup_reduce add %arg1 : (vector<4xi8>) -> (vector<4xi8>)
     "test.consume"(%sum1) : (vector<4xi8>) -> ()
 
-    // CHECK: gpu.return
+    // CHECK-SUB: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel3(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
+  gpu.func @kernel3(%arg0: i32) kernel {
+    // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+    // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+    // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+    // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+    // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+    // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[ARG0]], %[[C1]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[S0]] : i32
+    // CHECK-SHFL: %[[S1:.+]], %{{.+}} = gpu.shuffle xor %[[A0]], %[[C2]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A1:.+]] = arith.addi %[[A0]], %[[S1]] : i32
+    // CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C4]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
+    // CHECK-SHFL: %[[S3:.+]], %{{.+}} = gpu.shuffle xor %[[A2]], %[[C8]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A3:.+]] = arith.addi %[[A2]], %[[S3]] : i32
+    // CHECK-SHFL: %[[S4:.+]], %{{.+}} = gpu.shuffle xor %[[A3]], %[[C16]], %[[C32]] : i32
+    // CHECK-SHFL: %[[A4:.+]] = arith.addi %[[A3]], %[[S4]] : i32
+    // CHECK-SHFL: "test.consume"(%[[A4]]) : (i32) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (i32) -> i32
+    "test.consume"(%sum0) : (i32) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel4(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<2xf16>)
+  gpu.func @kernel4(%arg0: vector<2xf16>) kernel {
+    // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+    // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+    // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+    // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+    // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+    // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+    // CHECK-SHFL: %[[V0:.+]] = vector.bitcast %[[ARG0]] : vector<2xf16> to vector<1xi32>
+    // CHECK-SHFL: %[[I0:.+]] = vector.extract %[[V0]][0] : i32 from vector<1xi32>
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[I0]], %[[C1]], %[[C32]] : i32
+    // CHECK-SHFL: %[[BR0:.+]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+    // CHECK-SHFL: %[[BC0:.+]] = vector.bitcast %[[BR0]] : vector<1xi32> to vector<2xf16>
+    // CHECK-SHFL: %[[ADD0:.+]] = arith.addf %[[ARG0]], %[[BC0]] : vector<2xf16>
+    // CHECK-SHFL: %[[BC1:.+]] = vector.bitcast %[[ADD0]] : vector<2xf16> to vector<1xi32>
+    // CHECK-SHFL: %[[I1:.+]] = vector.extract %[[BC1]][0] : i32 from vector<1xi32>
+    // CHECK-SHFL: gpu.shuffle xor %[[I1]], %[[C2]], %[[C32]] : i32
+    // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+    // CHECK-SHFL: gpu.shuffle xor %{{.+}}, %[[C4]], %[[C32]] : i32
+    // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+    // CHECK-SHFL: gpu.shuffle xor %{{.+}}, %[[C8]], %[[C32]] : i32
+    // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+    // CHECK-SHFL: %[[SL:.+]], %{{.+}} = gpu.shuffle xor %{{.+}}, %[[C16]], %[[C32]] : i32
+    // CHECK-SHFL: %[[BRL:.+]] = vector.broadcast %[[SL]] : i32 to vector<1xi32>
+    // CHECK-SHFL: %[[BCL:.+]] = vector.bitcast %[[BRL]] : vector<1xi32> to vector<2xf16>
+    // CHECK-SHFL: %[[ADDL:.+]] = arith.addf %{{.+}}, %[[BCL]] : vector<2xf16>
+    // CHECK-SHFL: "test.consume"(%[[ADDL]]) : (vector<2xf16>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<2xf16>) -> (vector<2xf16>)
+    "test.consume"(%sum0) : (vector<2xf16>) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel5(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i16)
+  gpu.func @kernel5(%arg0: i16) kernel {
+    // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
+    // CHECK-SHFL: %[[T0:.+]] = arith.trunci %[[S0]] : i32 to i16
+    // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[T0]] : i16
+    // CHECK-SHFL: %[[E1:.+]] = arith.extui %[[A0]] : i16 to i32
+    // CHECK-SHFL: %{{.+}}, %{{.+}} = gpu.shuffle xor %[[E1]], {{.+}} : i32
+    // CHECK-SHFL-COUNT-3: gpu.shuffle xor
+    // CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
+    // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
+    // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (i16) -> i16
+    "test.consume"(%sum0) : (i16) -> ()
+
+    // CHECK-SHFL: gpu.return
+    gpu.return
+  }
+
+  // CHECK-SHFL-LABEL: gpu.func @kernel6(
+  // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<3xi8>)
+  gpu.func @kernel6(%arg0: vector<3xi8>) kernel {
+    // CHECK-SHFL: %[[CZ:.+]] = arith.constant dense<0> : vector<4xi8>
+    // CHECK-SHFL: %[[V0:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CZ]] {offsets = [0], strides = [1]} : vector<3xi8> into vector<4xi8>
+    // CHECK-SHFL: %[[BC0:.+]] = vector.bitcast %[[V0]] : vector<4xi8> to vector<1xi32>
+    // CHECK-SHFL: %[[I0:.+]] = vector.extract %[[BC0]][0] : i32 from vector<1xi32>
+    // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[I0]], {{.+}} : i32
+    // CHECK-SHFL: %[[BR0:.+]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+    // CHECK-SHFL: %[[BC1:.+]] = vector.bitcast %[[BR0]] : vector<1xi32> to vector<4xi8>
+    // CHECK-SHFL: %[[ADD0:.+]] = arith.addi %[[V0]], %[[BC1]] : vector<4xi8>
+    // CHECK-SHFL: %[[BC2:.+]] = vector.bitcast %[[ADD0]] : vector<4xi8> to vector<1xi32>
+    // CHECK-SHFL: %[[I1:.+]] = vector.extract %[[BC2]][0] : i32 from vector<1xi32>
+    // CHECK-SHFL-COUNT-4: gpu.shuffle xor
+    // CHECK-SHFL: %[[ESS:.+]] = vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [3], strides = [1]} : vector<4xi8> to vector<3xi8>
+    // CHECK-SHFL: "test.consume"(%[[ESS]]) : (vector<3xi8>) -> ()
+    %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
+    "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+    // CHECK-SHFL: gpu.return
     gpu.return
   }
 
 }
+

diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 21cc89c0d89b0f..8c13c0e2575c9c 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -47,19 +48,40 @@ struct TestGpuSubgroupReduceLoweringPass
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
       TestGpuSubgroupReduceLoweringPass)
 
+  TestGpuSubgroupReduceLoweringPass() = default;
+  TestGpuSubgroupReduceLoweringPass(
+      const TestGpuSubgroupReduceLoweringPass &pass)
+      : PassWrapper(pass) {}
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<arith::ArithDialect, vector::VectorDialect>();
   }
+
   StringRef getArgument() const final {
     return "test-gpu-subgroup-reduce-lowering";
   }
+
   StringRef getDescription() const final {
     return "Applies gpu.subgroup_reduce lowering patterns.";
   }
+
+  Option<bool> expandToShuffles{
+      *this, "expand-to-shuffles",
+      llvm::cl::desc("Expand subgroup_reduce ops to shuffle ops."),
+      llvm::cl::init(false)};
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
+
+    // Since both pattern sets match on the same ops, set higher benefit to
+    // perform fewer failing matches.
     populateGpuBreakDownSubgrupReducePatterns(patterns,
-                                              /*maxShuffleBitwidth=*/32);
+                                              /*maxShuffleBitwidth=*/32,
+                                              PatternBenefit(2));
+    if (expandToShuffles)
+      populateGpuLowerSubgroupReduceToShufflePattenrs(
+          patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };


        

