[Mlir-commits] [mlir] [mlir][gpu] Pattern to promote `gpu.shuffle` to specialized AMDGPU ops (PR #137109)

Ivan Butygin llvmlistbot at llvm.org
Wed Apr 23 18:59:55 PDT 2025


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/137109

Only swizzle promotion is implemented for now; support for DPP ops may be added later.

>From 8cc34740eaef0e16303fd28645e33c254c4255b2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 24 Apr 2025 03:51:08 +0200
Subject: [PATCH] [mlir][gpu] Patterns to promote `gpu.shuffle` to specialized
 AMDGPU ops

Only swizzle promotion is implemented for now; support for DPP ops may be added later.
---
 .../GPU/TransformOps/GPUTransformOps.td       | 27 +++++---
 .../mlir/Dialect/GPU/Transforms/Passes.h      |  3 +
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      |  2 -
 mlir/lib/Dialect/GPU/CMakeLists.txt           |  8 ++-
 .../GPU/TransformOps/GPUTransformOps.cpp      | 11 +++-
 .../GPU/Transforms/PromoteShuffleToAMDGPU.cpp | 64 +++++++++++++++++++
 .../Dialect/GPU/promote-shuffle-amdgpu.mlir   | 23 +++++++
 7 files changed, 123 insertions(+), 15 deletions(-)
 create mode 100644 mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
 create mode 100644 mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 15d6e0a069e3e..36b579485fc04 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -132,24 +132,24 @@ def MapNestedForallToThreads :
      TransformEachOpTrait,
      TransformOpInterface]> {
   let description = [{
-      Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to 
+      Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to
       distributed `gpu.thread_id` attribute.
 
       The operation searches for `scf.forall` ops nested under `target` and maps
-      each such op to GPU threads. 
-      
+      each such op to GPU threads.
+
       `scf.forall` induction variables are rewritten to `gpu.thread_id` according
       to the `mapping` attribute.
 
       Different types of mappings attributes are supported:
         - the block_dims is a list of integers that specifies the number of
           threads in each dimension. This is a mandatory attribute that is used
-          to constrain the number of threads in each dimension. If an 
+          to constrain the number of threads in each dimension. If an
           `scf.forall` op is mapped to fewer threads, predication occurs.
         - the warp_dims is a list of integers that specifies the number of
           warps in each dimension. This is an optional attribute that is used
           to constrain the number of warps in each dimension. When present, this
-          attribute must be specified in a way that is compatible with the 
+          attribute must be specified in a way that is compatible with the
           block_dims attribute. If an `scf.forall` op is mapped to fewer warps,
           predication occurs.
 
@@ -164,7 +164,7 @@ def MapNestedForallToThreads :
       inserted after each scf.forall op. At this time, this is an all or nothing
       choice. This will need to be tightened in the future.
 
-      The operation alters the block size of the given gpu_launch using the 
+      The operation alters the block size of the given gpu_launch using the
       mandatory block_dims argument.
 
       #### Return modes:
@@ -268,7 +268,7 @@ def MapForallToBlocks :
     Only scf.forall distributed to **at most 3 dimensions** are
     currently supported.
 
-    The operation alters the block size of the given gpu_launch using the 
+    The operation alters the block size of the given gpu_launch using the
     grid_dims argument.
 
     #### Return modes:
@@ -300,7 +300,7 @@ def MapForallToBlocks :
     `:` functional-type($target, $result)
   }];
   let hasVerifier = 1;
-  
+
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
@@ -310,4 +310,15 @@ def MapForallToBlocks :
   }];
 }
 
+def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.gpu.gpu_shuffle_to_amdgpu",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that try to promote `gpu.shuffle` ops to specialized
+    AMDGPU dialect ops (currently only `amdgpu.swizzle_bitmode`).
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
+
 #endif // GPU_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 5cc65082a7e56..77959a1e9e357 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -94,6 +94,9 @@ void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
 /// Erase barriers that do not enforce conflicting memory side effects.
 void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns);
 
+/// Adds patterns that try to promote `gpu.shuffle` ops to AMDGPU dialect ops.
+void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns);
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e6dd6f135884e..4758e99ea2c78 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -150,8 +150,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
         rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
     Value dstLane;
     // TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
-    // TODO: Use ds_swizzle for XOR when step/offsets are constants for better
-    // perf.
     switch (op.getMode()) {
     case gpu::ShuffleMode::DOWN:
       dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 013311ec027da..003ae2496b392 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -37,9 +37,10 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/ModuleToBinary.cpp
   Transforms/NVVMAttachTarget.cpp
   Transforms/ParallelLoopMapper.cpp
+  Transforms/PromoteShuffleToAMDGPU.cpp
   Transforms/ROCDLAttachTarget.cpp
-  Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
+  Transforms/ShuffleRewriter.cpp
   Transforms/SubgroupReduceLowering.cpp
 
   OBJECT
@@ -52,6 +53,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   MLIRParallelLoopMapperEnumsGen
 
   LINK_LIBS PUBLIC
+  MLIRAMDGPUDialect
   MLIRAffineUtils
   MLIRArithDialect
   MLIRAsyncDialect
@@ -66,11 +68,11 @@ add_mlir_dialect_library(MLIRGPUTransforms
   MLIRMemRefDialect
   MLIRNVVMTarget
   MLIRPass
+  MLIRROCDLTarget
   MLIRSCFDialect
-  MLIRSideEffectInterfaces
   MLIRSPIRVTarget
+  MLIRSideEffectInterfaces
   MLIRSupport
-  MLIRROCDLTarget
   MLIRTransformUtils
   MLIRVectorDialect
   )
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 3970539db6675..6446235c06fb2 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -136,6 +137,11 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
   populateGpuRewritePatterns(patterns);
 }
 
+void ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  mlir::populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyUnrollVectorsSubgroupMmaOp
 //===----------------------------------------------------------------------===//
@@ -914,9 +920,10 @@ class GPUTransformDialectExtension
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)
 
   GPUTransformDialectExtension() {
-    declareGeneratedDialect<scf::SCFDialect>();
-    declareGeneratedDialect<arith::ArithDialect>();
     declareGeneratedDialect<GPUDialect>();
+    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<scf::SCFDialect>();
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
new file mode 100644
index 0000000000000..171e64346f155
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -0,0 +1,64 @@
+//===- PromoteShuffleToAMDGPU.cpp - Promote shuffle to AMDGPU -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to try to promote `gpu.shuffle`s to specialized
+// AMDGPU intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+/// Promotes an XOR `gpu.shuffle` with constant width 64 and a constant
+/// offset in [0, 31] to `amdgpu.swizzle_bitmode` (andMask=31, xorMask=offset).
+struct PromoteShuffleToSwizzlePattern
+    : public OpRewritePattern<gpu::ShuffleOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::ShuffleOp shuffle,
+                                PatternRewriter &rewriter) const override {
+    auto reject = [&](const char *reason) {
+      return rewriter.notifyMatchFailure(shuffle, reason);
+    };
+
+    if (shuffle.getMode() != gpu::ShuffleMode::XOR)
+      return reject("only xor shuffle mode is supported");
+    if (!isConstantIntValue(shuffle.getWidth(), 64))
+      return reject("only 64 width shuffle is supported");
+
+    std::optional<int64_t> xorMask = getConstantIntValue(shuffle.getOffset());
+    if (!xorMask)
+      return reject("offset must be a constant integer");
+    // Only offsets that fit the 5-bit swizzle xor mask field are accepted.
+    if (!(*xorMask >= 0 && *xorMask < 32))
+      return reject("offset must be in the range [0, 31]");
+
+    Location loc = shuffle.getLoc();
+    Value swizzled = rewriter.create<amdgpu::SwizzleBitModeOp>(
+        loc, shuffle.getResult(0).getType(), shuffle.getValue(),
+        /*andMask=*/31, /*orMask=*/0, /*xorMask=*/ *xorMask);
+    // Lane ids are always < 64 on AMDGPU, so `valid` is constant true.
+    Value allValid = rewriter.create<arith::ConstantIntOp>(loc, 1, /*width*/ 1);
+    rewriter.replaceOp(shuffle, {swizzled, allValid});
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<PromoteShuffleToSwizzlePattern>(context);
+}
diff --git a/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir b/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir
new file mode 100644
index 0000000000000..4293b430f71f7
--- /dev/null
+++ b/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @gpu_shuffle_swizzle
+//  CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @gpu_shuffle_swizzle(%arg0: i32) -> (i32, i1) {
+  // CHECK:  %[[TRUE:.*]] = arith.constant true
+  // CHECK:  %[[RES:.*]] = amdgpu.swizzle_bitmode %[[ARG]] 31 0 23 : i32
+  // CHECK:  return %[[RES]], %[[TRUE]] : i32, i1
+  %width = arith.constant 64 : i32
+  %offset = arith.constant 23 : i32
+  %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
+  func.return %shfl, %pred : i32, i1
+}
+
+// Negative cases: none of these shuffles may be promoted.
+
+// CHECK-LABEL: func @gpu_shuffle_no_swizzle_not_xor
+//   CHECK-NOT:   amdgpu.swizzle_bitmode
+//       CHECK:   gpu.shuffle down
+func.func @gpu_shuffle_no_swizzle_not_xor(%arg0: i32) -> (i32, i1) {
+  %width = arith.constant 64 : i32
+  %offset = arith.constant 23 : i32
+  %shfl, %pred = gpu.shuffle down %arg0, %offset, %width : i32
+  func.return %shfl, %pred : i32, i1
+}
+
+// CHECK-LABEL: func @gpu_shuffle_no_swizzle_width_32
+//   CHECK-NOT:   amdgpu.swizzle_bitmode
+//       CHECK:   gpu.shuffle xor
+func.func @gpu_shuffle_no_swizzle_width_32(%arg0: i32) -> (i32, i1) {
+  %width = arith.constant 32 : i32
+  %offset = arith.constant 23 : i32
+  %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
+  func.return %shfl, %pred : i32, i1
+}
+
+// CHECK-LABEL: func @gpu_shuffle_no_swizzle_dynamic_offset
+//   CHECK-NOT:   amdgpu.swizzle_bitmode
+//       CHECK:   gpu.shuffle xor
+func.func @gpu_shuffle_no_swizzle_dynamic_offset(%arg0: i32, %offset: i32) -> (i32, i1) {
+  %width = arith.constant 64 : i32
+  %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
+  func.return %shfl, %pred : i32, i1
+}
+
+// CHECK-LABEL: func @gpu_shuffle_no_swizzle_offset_32
+//   CHECK-NOT:   amdgpu.swizzle_bitmode
+//       CHECK:   gpu.shuffle xor
+func.func @gpu_shuffle_no_swizzle_offset_32(%arg0: i32) -> (i32, i1) {
+  %width = arith.constant 64 : i32
+  %offset = arith.constant 32 : i32
+  %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
+  func.return %shfl, %pred : i32, i1
+}



More information about the Mlir-commits mailing list