[Mlir-commits] [mlir] d893d12 - [mlir] GPUToROCDL: Fix crashes with unsupported shuffle datatypes (#135504)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 13 11:26:23 PDT 2025


Author: Ivan Butygin
Date: 2025-04-13T20:26:19+02:00
New Revision: d893d129e6ee8b4dead1532cd8420750908acca6

URL: https://github.com/llvm/llvm-project/commit/d893d129e6ee8b4dead1532cd8420750908acca6
DIFF: https://github.com/llvm/llvm-project/commit/d893d129e6ee8b4dead1532cd8420750908acca6.diff

LOG: [mlir] GPUToROCDL: Fix crashes with unsupported shuffle datatypes (#135504)

Calling `getIntOrFloatBitWidth` on non-int/float types (`gpu.shuffle`
also accepts vectors) will crash.

Added: 
    mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir

Modified: 
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
    mlir/test/Dialect/GPU/shuffle-rewrite.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 4891dab3aa1d0..c6c695b442b4f 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -136,9 +136,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    Value initShflValue = adaptor.getValue();
+    Type shflType = initShflValue.getType();
     // TODO: Add support for non 32-bit shuffle values.
-    if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
-      return failure();
+    if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
+      return rewriter.notifyMatchFailure(
+          op, "only 32-bit int/float types are supported");
+
     const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
     Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
 
@@ -175,16 +179,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
     Value dwordAlignedDstLane =
         rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
-    Value initShflValue = adaptor.getValue();
-    if (adaptor.getValue().getType().isF32()) {
+    if (shflType.isF32()) {
       initShflValue =
           rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
     }
     Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
         loc, int32Type, dwordAlignedDstLane, initShflValue);
-    if (adaptor.getValue().getType().isF32()) {
-      shflValue = rewriter.create<LLVM::BitcastOp>(
-          loc, adaptor.getValue().getType(), shflValue);
+    if (shflType.isF32()) {
+      shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
     }
     rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
     return success();

diff  --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
index 4bd4da25f6e52..9f2900214e8b1 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp
@@ -40,8 +40,9 @@ struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> {
     auto i64 = rewriter.getI64Type();
 
     // If the type of the value is either i32 or f32, the op is already valid.
-    if (valueType.getIntOrFloatBitWidth() == 32)
-      return failure();
+    if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 64)
+      return rewriter.notifyMatchFailure(
+          op, "only 64-bit int/float types are supported");
 
     Value lo, hi;
 

diff  --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
new file mode 100644
index 0000000000000..90f2e5f047cd9
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl -verify-diagnostics
+
+gpu.module @test_module {
+  // ROCDL lowering only suport shuffles for 32bit ints/floats, but they
+  // shouldn't crash on unsupported types.
+  func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 64 : i32
+    // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
+    %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
+    return %shfl : vector<4xf16>
+  }
+}

diff  --git a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
index 4618258201532..c0ccae05a0572 100644
--- a/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
+++ b/mlir/test/Dialect/GPU/shuffle-rewrite.mlir
@@ -49,3 +49,14 @@ module {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: @gpu_shuffle_unsupported
+func.func @gpu_shuffle_unsupported(%arg0 : vector<4xf16>) -> vector<4xf16> {
+  %offset = arith.constant 4 : i32
+  %width = arith.constant 64 : i32
+  // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : vector<4xf16>
+  %shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : vector<4xf16>
+  return %shfl : vector<4xf16>
+}


        


More information about the Mlir-commits mailing list