[Mlir-commits] [mlir] [mlir] GPUToROCDL: Fix crashes with unsupported shuffle datatypes (PR #135504)
llvmlistbot at llvm.org
Sat Apr 12 14:57:24 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
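Calling `getIntOrFloatBitWidth` on a type that is neither an integer nor a float crashes. Because `gpu.shuffle` also accepts vector operands, both the ROCDL lowering and the shuffle rewriter must check `isIntOrFloat()` before querying the bit width, and report a match failure for unsupported types instead of crashing.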
---
Full diff: https://github.com/llvm/llvm-project/pull/135504.diff
4 Files Affected:
- (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+9-7)
- (modified) mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp (+3-2)
- (added) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir (+13)
- (modified) mlir/test/Dialect/GPU/shuffle-rewrite.mlir (+11)
``````````diff
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 4891dab3aa1d0..c6c695b442b4f 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
+ Value initShflValue = adaptor.getValue();
+ Type shflType = initShflValue.getType();
// TODO: Add support for non 32-bit shuffle values.
- if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
- return failure();
+ if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
+ return rewriter.notifyMatchFailure(
+ op, "only 32-bit int/float types are supported");
+
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
@@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
Value dwordAlignedDstLane =
rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
- Value initShflValue = adaptor.getValue();
- if (adaptor.getValue().getType().isF32()) {
+ if (shflType.isF32()) {
initShflValue =
rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
}
Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
loc, int32Type, dwordAlignedDstLane, initShflValue);
- if (adaptor.getValue().getType().isF32()) {
- shflValue = rewriter.create<LLVM::BitcastOp>(
- loc, adaptor.getValue().getType(), shflValue);
+ if (shflType.isF32()) {
+ shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
}
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
return success();
diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
index 4bd4da25f6e52..9f2900214e8b1 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
@@ -40,8 +40,9 @@ struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
auto i64 = rewriter.getI64Type();
// If the type of the value is either i32 or f32, the op is already valid.
- if (valueType.getIntOrFloatBitWidth() == 32)
- return failure();
+ if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
+ return rewriter.notifyMatchFailure(
+ op, "only 64-bit int/float types are supported");
Value lo, hi;
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
new file mode 100644
index 0000000000000..90f2e5f047cd9
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl -verify-diagnostics
+
+gpu.module @test_module {
+ // ROCDL lowering only supports shuffles of 32-bit ints/floats, but it
+ // shouldn't crash on unsupported types.
+ func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
+ %offset = arith.constant 4 : i32
+ %width = arith.constant 64 : i32
+ // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
+ %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
+ return %shfl : vector<4xf16>
+ }
+}
diff --git a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
index 4618258201532..c0ccae05a0572 100644
--- a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
+++ b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
@@ -49,3 +49,14 @@ module {
return
}
}
+
+// -----
+
+// CHECK-LABEL: @gpu_shuffle_unsupported
+func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
+ %offset = arith.constant 4 : i32
+ %width = arith.constant 64 : i32
+ // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : vector<4xf16>
+ %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
+ return %shfl : vector<4xf16>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/135504
More information about the Mlir-commits mailing list