[Mlir-commits] [mlir] d893d12 - [mlir] GPUToROCDL: Fix crashes with unsupported shuffle datatypes (#135504)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Apr 13 11:26:23 PDT 2025
Author: Ivan Butygin
Date: 2025-04-13T20:26:19+02:00
New Revision: d893d129e6ee8b4dead1532cd8420750908acca6
URL: https://github.com/llvm/llvm-project/commit/d893d129e6ee8b4dead1532cd8420750908acca6
DIFF: https://github.com/llvm/llvm-project/commit/d893d129e6ee8b4dead1532cd8420750908acca6.diff
LOG: [mlir] GPUToROCDL: Fix crashes with unsupported shuffle datatypes (#135504)
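Calling `getIntOrFloatBitWidth` on a type that is neither an integer nor a float (`gpu.shuffle`
also accepts vectors) crashes; both patterns now check `isIntOrFloat()` first and report a match failure instead.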
Added:
mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
Modified:
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
mlir/test/Dialect/GPU/shuffle-rewrite.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 4891dab3aa1d0..c6c695b442b4f 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
+ Value initShflValue = adaptor.getValue();
+ Type shflType = initShflValue.getType();
// TODO: Add support for non 32-bit shuffle values.
- if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
- return failure();
+ if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
+ return rewriter.notifyMatchFailure(
+ op, "only 32-bit int/float types are supported");
+
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
@@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
Value dwordAlignedDstLane =
rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
- Value initShflValue = adaptor.getValue();
- if (adaptor.getValue().getType().isF32()) {
+ if (shflType.isF32()) {
initShflValue =
rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
}
Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
loc, int32Type, dwordAlignedDstLane, initShflValue);
- if (adaptor.getValue().getType().isF32()) {
- shflValue = rewriter.create<LLVM::BitcastOp>(
- loc, adaptor.getValue().getType(), shflValue);
+ if (shflType.isF32()) {
+ shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
}
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
return success();
diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
index 4bd4da25f6e52..9f2900214e8b1 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
@@ -40,8 +40,9 @@ struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
auto i64 = rewriter.getI64Type();
// If the type of the value is either i32 or f32, the op is already valid.
- if (valueType.getIntOrFloatBitWidth() == 32)
- return failure();
+ if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
+ return rewriter.notifyMatchFailure(
+ op, "only 64-bit int/float types are supported");
Value lo, hi;
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
new file mode 100644
index 0000000000000..90f2e5f047cd9
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl -verify-diagnostics
+
+gpu.module @test_module {
+ // ROCDL lowering only supports shuffles of 32-bit ints/floats, but it
+ // shouldn't crash on unsupported types.
+ func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
+ %offset = arith.constant 4 : i32
+ %width = arith.constant 64 : i32
+ // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
+ %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
+ return %shfl : vector<4xf16>
+ }
+}
diff --git a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
index 4618258201532..c0ccae05a0572 100644
--- a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
+++ b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
@@ -49,3 +49,14 @@ module {
return
}
}
+
+// -----
+
+// CHECK-LABEL: @gpu_shuffle_unsupported
+func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
+ %offset = arith.constant 4 : i32
+ %width = arith.constant 64 : i32
+ // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : vector<4xf16>
+ %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
+ return %shfl : vector<4xf16>
+}
More information about the Mlir-commits
mailing list