[Mlir-commits] [mlir] [mlir][amdgpu] Promote gpu.shuffle to amdgpu.permlane_swap (PR #154933)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 22 05:10:03 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Tim Gymnich (tgymnich)
<details>
<summary>Changes</summary>
- Promote `gpu.shuffle %src xor {15,31} 64` to `amdgpu.permlane_swap %src {16,32}` when the target chipset is gfx950 or newer.
---
Full diff: https://github.com/llvm/llvm-project/pull/154933.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+2-2)
- (modified) mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td (+4-1)
- (modified) mlir/include/mlir/Dialect/GPU/Transforms/Passes.h (+2-1)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+3-3)
- (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+1-1)
- (modified) mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp (+12-1)
- (modified) mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp (+47-2)
- (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+31-5)
- (modified) mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir (+13-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 72aca2938e029..2ccf350a359a8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -670,8 +670,8 @@ def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["re
Example:
```mlir
- %0 = amdgpu.permlane %src 16 : f16
- %1 = amdgpu.permlane %src 32 { fetch_inactive = true, bound_ctrl = true } : f16
+ %0 = amdgpu.permlane_swap %src 16 : f16
+ %1 = amdgpu.permlane_swap %src 32 { fetch_inactive = true, bound_ctrl = true } : f16
```
Operands:
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 87423c639945f..3a8caf8aa42e2 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -331,7 +331,10 @@ def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
Collects patterns that are tryin to promote `gpu.shuffle`s to specialized
AMDGPU intrinsics.
}];
- let assemblyFormat = "attr-dict";
+ let arguments = (ins OptionalAttr<StrAttr>:$chipset);
+ let assemblyFormat = [{
+ (`chipset` `=` $chipset^)? attr-dict
+ }];
}
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index b4fd55e6c22e6..d5c253d6c9c08 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -114,7 +114,8 @@ void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns);
/// Tries to promote `gpu.shuffle`s to specialized AMDGPU intrinsics.
-void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns);
+void populateGpuPromoteShuffleToAMDGPUPatterns(
+ RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset);
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b44d647cf7632..203790ed95153 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1893,7 +1893,7 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
Location loc = op.getLoc();
Type i32 = rewriter.getI32Type();
Value src = adaptor.getSrc();
- unsigned row_length = op.getRowLength();
+ unsigned rowLength = op.getRowLength();
bool fi = op.getFetchInactive();
bool boundctrl = op.getBoundCtrl();
@@ -1906,10 +1906,10 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
Type i32pair = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), {v.getType(), v.getType()});
- if (row_length == 16)
+ if (rowLength == 16)
res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
boundctrl);
- else if (row_length == 32)
+ else if (rowLength == 32)
res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
boundctrl);
else
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e6fbcf98950a4..e39fcd579b890 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -327,7 +327,7 @@ struct LowerGpuOpsToROCDLOpsPass final
{
RewritePatternSet patterns(ctx);
populateGpuRewritePatterns(patterns);
- populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
+ populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
(void)applyPatternsGreedily(m, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 21cb2f6d0a8e0..c766539f9d91a 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
@@ -43,6 +44,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
+#include <optional>
#include <type_traits>
using namespace mlir;
@@ -170,7 +172,16 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
+ std::optional<StringRef> chipsetName = getChipset();
+ std::optional<amdgpu::Chipset> maybeChipset;
+ if (chipsetName) {
+ FailureOr<amdgpu::Chipset> parsedChipset =
+ amdgpu::Chipset::parse(*chipsetName);
+ assert(llvm::succeeded(parsedChipset) && "expected valid chipset");
+ maybeChipset = parsedChipset;
+ }
+
+ populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
index 18c69f5f30e5d..c106d7bfcbf90 100644
--- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -11,16 +11,21 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"
+#include <optional>
using namespace mlir;
namespace {
+
+constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
+
/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
/// and offset must be a constant integer in the range [0, 31].
struct PromoteShuffleToSwizzlePattern
@@ -56,9 +61,49 @@ struct PromoteShuffleToSwizzlePattern
return success();
}
};
+
+/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64
+/// and offset must be a constant integer in the set {15, 31}.
+struct PromoteShuffleToPermlanePattern
+ : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getMode() != gpu::ShuffleMode::XOR)
+ return rewriter.notifyMatchFailure(op,
+ "only xor shuffle mode is supported");
+
+ if (!isConstantIntValue(op.getWidth(), 64))
+ return rewriter.notifyMatchFailure(op,
+ "only 64 width shuffle is supported");
+
+ std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
+ if (!offset)
+ return rewriter.notifyMatchFailure(op,
+ "offset must be a constant integer");
+
+ int64_t offsetValue = *offset;
+ if (offsetValue != 15 && offsetValue != 31)
+ return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
+
+ Location loc = op.getLoc();
+ Value res =
+ amdgpu::PermlaneSwapOp::create(rewriter, loc, op.getResult(0).getType(),
+ op.getValue(), offsetValue + 1);
+ Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
+ rewriter.replaceOp(op, {res, valid});
+ return success();
+ }
+};
+
} // namespace
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
- RewritePatternSet &patterns) {
- patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
+ RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
+ patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
+ /*benefit*/ 1);
+ if (maybeChipset && *maybeChipset >= kGfx950)
+ patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
+ /*benefit*/ 2);
}
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index fa4a9749f6a9b..1456589fea440 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950 allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950 index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
// CHECK-LABEL: @test_module
// CHECK-SAME: llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"
@@ -734,14 +734,40 @@ gpu.module @test_module {
func.return %shfl, %shfli, %shflu, %shfld : f32, f32, f32, f32
}
+ // CHECK-LABEL: func @gpu_shuffle_promote()
+ func.func @gpu_shuffle_promote() -> (f32, f32, f32) {
+ // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ %arg0 = arith.constant 1.0 : f32
+ %arg1 = arith.constant 4 : i32
+ %arg2 = arith.constant 15 : i32
+ %arg3 = arith.constant 31 : i32
+ %arg4 = arith.constant 64 : i32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
+ %shfl1, %pred1 = gpu.shuffle xor %arg0, %arg1, %arg4 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
+ // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ %shfl2, %pred2 = gpu.shuffle xor %arg0, %arg2, %arg4 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
+ // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ %shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
+ func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
+ }
+
// CHECK-LABEL: func @gpu_shuffle_vec
// CHECK-SAME: (%[[ARG:.*]]: vector<4xf16>, %{{.*}}: i32, %{{.*}}: i32)
func.func @gpu_shuffle_vec(%arg0: vector<4xf16>, %arg1: i32, %arg2: i32) -> vector<4xf16> {
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG]] : vector<4xf16> to vector<2xi32>
// CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[ELEM0:.*]] = llvm.extractelement %13[%[[IDX0]] : i32] : vector<2xi32>
+ // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[IDX0]] : i32] : vector<2xi32>
// CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ELEM1:.*]] = llvm.extractelement %13[%[[IDX1]] : i32] : vector<2xi32>
+ // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[IDX1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM0]] : (i32, i32) -> i32
// CHECK: %[[PERM1:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM1]] : (i32, i32) -> i32
// CHECK: %[[V0:.*]] = llvm.mlir.poison : vector<2xi32>
diff --git a/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir b/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir
index 4293b430f71f7..75a3828536bc6 100644
--- a/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir
+++ b/mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir
@@ -4,7 +4,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
- transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu
+ transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu chipset = "gfx950"
} : !transform.any_op
transform.yield
}
@@ -21,3 +21,15 @@ func.func @gpu_shuffle_swizzle(%arg0: i32) -> (i32, i1) {
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
func.return %shfl, %pred : i32, i1
}
+
+ // CHECK-LABEL: func @gpu_shuffle_permlane_swap
+ // CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @gpu_shuffle_permlane_swap(%arg0: i32) -> (i32, i1) {
+ // CHECK: %[[TRUE:.*]] = arith.constant true
+ // CHECK: %[[RES:.*]] = amdgpu.permlane_swap %[[ARG]] 32 : i32
+ // CHECK: return %[[RES]], %[[TRUE]] : i32, i1
+ %width = arith.constant 64 : i32
+ %offset = arith.constant 31 : i32
+ %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
+ func.return %shfl, %pred : i32, i1
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/154933
More information about the Mlir-commits mailing list