[Mlir-commits] [mlir] 330a232 - [mlir][gpu] Add i64 & f64 support to gpu.shuffle

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 25 14:41:39 PDT 2023


Author: Fabian Mora
Date: 2023-05-25T21:40:25Z
New Revision: 330a232ae76139c3970df5ccaf1b51640cbd4d66

URL: https://github.com/llvm/llvm-project/commit/330a232ae76139c3970df5ccaf1b51640cbd4d66
DIFF: https://github.com/llvm/llvm-project/commit/330a232ae76139c3970df5ccaf1b51640cbd4d66.diff

LOG: [mlir][gpu] Add i64 & f64 support to gpu.shuffle

This patch adds support for i64 and f64 values in `gpu.shuffle` by rewriting each 64-bit shuffle into two 32-bit shuffles.
The motivation is that both CUDA and HIP support shuffling 64-bit values.
The implementation is based on the LLVM IR that clang emits for 64-bit shuffles at `-O3`.
The rewrite is exposed as `populateGpuShufflePatterns` and is included in `populateGpuRewritePatterns`.

Reviewed By: makslevental

Differential Revision: https://reviews.llvm.org/D148974

Added: 
    mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
    mlir/test/Dialect/GPU/shuffle-rewrite.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/test/Dialect/GPU/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 5160b6886817..fdcbf4d139bc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -936,14 +936,17 @@ def GPU_ShuffleMode : I32EnumAttr<"ShuffleMode",
 def GPU_ShuffleModeAttr : EnumAttr<GPU_Dialect, GPU_ShuffleMode,
                                    "shuffle_mode">;
 
-def I32OrF32 : TypeConstraint<Or<[I32.predicate, F32.predicate]>,
-                                 "i32 or f32">;
+def I32I64F32OrF64 : TypeConstraint<Or<[I32.predicate,
+                                        I64.predicate,
+                                        F32.predicate,
+                                        F64.predicate]>,
+                                       "i32, i64, f32 or f64">;
 
 def GPU_ShuffleOp : GPU_Op<
     "shuffle", [Pure, AllTypesMatch<["value", "shuffleResult"]>]>,
-    Arguments<(ins I32OrF32:$value, I32:$offset, I32:$width,
+    Arguments<(ins I32I64F32OrF64:$value, I32:$offset, I32:$width,
                GPU_ShuffleModeAttr:$mode)>,
-    Results<(outs I32OrF32:$shuffleResult, I1:$valid)> {
+    Results<(outs I32I64F32OrF64:$shuffleResult, I1:$valid)> {
   let summary = "Shuffles values within a subgroup.";
   let description = [{
     The "shuffle" op moves values to a different invocation within the same

diff  --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 5e2ff6d646ce..89a45a4e4993 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -56,6 +56,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createGpuMapParallelLoopsPass();
 /// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
 void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
+void populateGpuShufflePatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
 void populateGpuAllReducePatterns(RewritePatternSet &patterns);
 
@@ -63,6 +66,7 @@ void populateGpuAllReducePatterns(RewritePatternSet &patterns);
 inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
   populateGpuGlobalIdPatterns(patterns);
+  populateGpuShufflePatterns(patterns);
 }
 
 namespace gpu {

diff  --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 2211e15a5d4b..31790490828f 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -50,6 +50,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ParallelLoopMapper.cpp
+  Transforms/ShuffleRewriter.cpp
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp

diff  --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
new file mode 100644
index 000000000000..4bd4da25f6e5
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
@@ -0,0 +1,103 @@
+//===- ShuffleRewriter.cpp - Implementation of shuffle rewriting  ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements in-dialect rewriting of the shuffle op for types i64 and
+// f64, rewriting 64-bit shuffles into two 32-bit shuffles. The shift and
+// truncation sequence mirrors the LLVM IR that clang emits for 64-bit shuffles
+// at `-O3`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrites a 64-bit `gpu.shuffle` (i64 or f64) into two 32-bit shuffles.
+///
+/// The value is split into its low and high 32-bit halves, each half is
+/// shuffled with the original offset, width and mode, and the results are
+/// recombined as `(zext(hi) << 32) | zext(lo)`. f64 values are bitcast to i64
+/// before splitting and back to f64 afterwards. The validity flag of the
+/// rewritten shuffle is the AND of the two half-shuffle validity flags.
+struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
+  using OpRewritePattern<gpu::ShuffleOp>::OpRewritePattern;
+
+  void initialize() {
+    // The rewrite creates new gpu.shuffle ops that this pattern could match
+    // again. Recursion is bounded because those shuffles are 32-bit, which the
+    // width check in matchAndRewrite rejects.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    // Every op created here carries the location of the shuffle it replaces,
+    // so diagnostics and debug info point back at that shuffle.
+    Location loc = op.getLoc();
+    Value value = op.getValue();
+    Type valueType = value.getType();
+
+    // Only 64-bit values are split; i32/f32 shuffles are already in their
+    // final form. Matching exactly 64 bits (rather than "not 32") keeps the
+    // pattern from emitting invalid IR if the op ever accepts other widths.
+    if (valueType.getIntOrFloatBitWidth() != 64)
+      return rewriter.notifyMatchFailure(op, "expected a 64-bit value");
+
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+    bool isFloat = isa<FloatType>(valueType);
+
+    // Float values are bitcast to i64 so their bits can be split.
+    if (isFloat)
+      value = rewriter.create<arith::BitcastOp>(loc, i64, value);
+
+    // lo = trunc(value), hi = trunc(value >> 32).
+    Value lo = rewriter.create<arith::TruncIOp>(loc, i32, value);
+    Value c32 = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(i64, 32));
+    Value hi = rewriter.create<arith::ShRUIOp>(loc, value, c32);
+    hi = rewriter.create<arith::TruncIOp>(loc, i32, hi);
+
+    // Shuffle each half with the original offset, width and mode.
+    auto loShuffle = rewriter.create<gpu::ShuffleOp>(
+        loc, lo, op.getOffset(), op.getWidth(), op.getMode());
+    auto hiShuffle = rewriter.create<gpu::ShuffleOp>(
+        loc, hi, op.getOffset(), op.getWidth(), op.getMode());
+
+    // Reassemble the shuffled bits as (zext(hi) << 32) | zext(lo).
+    Value loBits = rewriter.create<arith::ExtUIOp>(
+        loc, i64, loShuffle.getShuffleResult());
+    Value hiBits = rewriter.create<arith::ExtUIOp>(
+        loc, i64, hiShuffle.getShuffleResult());
+    hiBits = rewriter.create<arith::ShLIOp>(loc, hiBits, c32);
+    Value result = rewriter.create<arith::OrIOp>(loc, hiBits, loBits);
+
+    // Undo the initial bitcast for float values.
+    if (isFloat)
+      result = rewriter.create<arith::BitcastOp>(loc, valueType, result);
+
+    // The 64-bit shuffle is valid only if both half-shuffles are valid.
+    Value valid = rewriter.create<arith::AndIOp>(loc, loShuffle.getValid(),
+                                                 hiShuffle.getValid());
+
+    rewriter.replaceOp(op, {result, valid});
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the pattern that splits 64-bit gpu.shuffle ops to `patterns`.
+void mlir::populateGpuShufflePatterns(RewritePatternSet &patterns) {
+  patterns.add<GpuShuffleRewriter>(patterns.getContext());
+}

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 4a52455ad0b3..e280cd65811d 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -318,7 +318,7 @@ func.func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
 // -----
 
 func.func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
-  // expected-error at +1 {{operand #0 must be i32 or f32}}
+  // expected-error at +1 {{operand #0 must be i32, i64, f32 or f64}}
   %shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : index
   return
 }

diff  --git a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
new file mode 100644
index 000000000000..461825820153
--- /dev/null
+++ b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
+
+module {
+  // CHECK-LABEL: func.func @shuffleF64
+  // CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: f64, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<f64, 1>) {
+  func.func @shuffleF64(%sz : index, %value: f64, %offset: i32, %width: i32, %mem: memref<f64, 1>) {
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+      // CHECK: %[[C32:.*]] = arith.constant 32 : i64
+      // CHECK: %[[INTVAL:.*]] = arith.bitcast %[[VALUE]] : f64 to i64
+      // CHECK-NEXT: %[[LO:.*]] = arith.trunci %[[INTVAL]] : i64 to i32
+      // CHECK-NEXT: %[[HI64:.*]] = arith.shrui %[[INTVAL]], %[[C32]] : i64
+      // CHECK-NEXT: %[[HI:.*]] = arith.trunci %[[HI64]] : i64 to i32
+      // CHECK-NEXT: %[[SH1:.*]], %[[V1:.*]] = gpu.shuffle  xor %[[LO]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: %[[SH2:.*]], %[[V2:.*]] = gpu.shuffle  xor %[[HI]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: %[[LOSH:.*]] = arith.extui %[[SH1]] : i32 to i64
+      // CHECK-NEXT: %[[HISHTMP:.*]] = arith.extui %[[SH2]] : i32 to i64
+      // CHECK-NEXT: %[[HISH:.*]] = arith.shli %[[HISHTMP]], %[[C32]] : i64
+      // CHECK-NEXT: %[[SHFLINT:.*]] = arith.ori %[[HISH]], %[[LOSH]] : i64
+      // CHECK-NEXT: %[[SHFL:.*]] = arith.bitcast %[[SHFLINT]] : i64 to f64
+      // CHECK: memref.store %[[SHFL]], %[[MEM]][] : memref<f64, 1>
+      %shfl, %pred = gpu.shuffle xor %value, %offset, %width : f64
+      memref.store %shfl, %mem[]  : memref<f64, 1>
+      gpu.terminator
+    }
+    return
+  }
+}
+
+// -----
+
+module {
+  // CHECK-LABEL: func.func @shuffleI64
+  // CHECK-SAME: (%[[SZ:.*]]: index, %[[VALUE:.*]]: i64, %[[OFF:.*]]: i32, %[[WIDTH:.*]]: i32, %[[MEM:.*]]: memref<i64, 1>) {
+  func.func @shuffleI64(%sz : index, %value: i64, %offset: i32, %width: i32, %mem: memref<i64, 1>) {
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+      // CHECK: %[[C32:.*]] = arith.constant 32 : i64
+      // CHECK: %[[LO:.*]] = arith.trunci %[[VALUE]] : i64 to i32
+      // CHECK-NEXT: %[[HI64:.*]] = arith.shrui %[[VALUE]], %[[C32]] : i64
+      // CHECK-NEXT: %[[HI:.*]] = arith.trunci %[[HI64]] : i64 to i32
+      // CHECK-NEXT: %[[SH1:.*]], %[[V1:.*]] = gpu.shuffle  xor %[[LO]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: %[[SH2:.*]], %[[V2:.*]] = gpu.shuffle  xor %[[HI]], %[[OFF]], %[[WIDTH]] : i32
+      // CHECK-NEXT: %[[LOSH:.*]] = arith.extui %[[SH1]] : i32 to i64
+      // CHECK-NEXT: %[[HISHTMP:.*]] = arith.extui %[[SH2]] : i32 to i64
+      // CHECK-NEXT: %[[HISH:.*]] = arith.shli %[[HISHTMP]], %[[C32]] : i64
+      // CHECK-NEXT: %[[SHFLINT:.*]] = arith.ori %[[HISH]], %[[LOSH]] : i64
+      // CHECK: memref.store %[[SHFLINT]], %[[MEM]][] : memref<i64, 1>
+      %shfl, %pred = gpu.shuffle xor %value, %offset, %width : i64
+      memref.store %shfl, %mem[]  : memref<i64, 1>
+      gpu.terminator
+    }
+    return
+  }
+}


        


More information about the Mlir-commits mailing list