[Mlir-commits] [mlir] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion (PR #119675)

Pietro Ghiglio llvmlistbot at llvm.org
Thu Dec 12 00:49:17 PST 2024


https://github.com/PietroGhg created https://github.com/llvm/llvm-project/pull/119675

This PR adds support for the `bf16` and `i1` data types when converting `gpu::shuffle` to the `LLVMSPV` dialect: `bf16` values are bitcast to/from `i16`, and `i1` values are zero-extended to `i8` before the builtin call and truncated back to `i1` afterwards.

>From 9e3af90821eeec58d715a0559ceaa61bee83b999 Mon Sep 17 00:00:00 2001
From: PietroGhg <pietro.ghiglio at codeplay.com>
Date: Mon, 9 Dec 2024 11:24:28 +0000
Subject: [PATCH] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV
 conversion

---
 .../Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp  | 83 +++++++++++++++++--
 .../GPUToLLVMSPV/gpu-to-llvm-spv.mlir         | 19 ++++-
 2 files changed, 91 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
         .Default([](auto) { return std::nullopt; });
   }
 
-  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
-    StringRef baseName = getBaseName(op.getMode());
-    std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
-    if (!typeMangling)
-      return std::nullopt;
-    return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
-                         typeMangling.value());
+  /// Mangled SPIR-V builtin name for a `mode` shuffle of `type`, or nullopt.
+  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+                                                Type type) {
+    StringRef name = getBaseName(mode);
+    if (std::optional<StringRef> mangling = getTypeMangling(type))
+      return llvm::formatv("_Z{0}{1}{2}", name.size(), name, *mangling);
+    return std::nullopt;
   }
 
+  /// Overload taking the shuffle mode and value type from `op`.
+  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+    return getFuncName(op.getMode(), op.getType(0));
+  }
+
   /// Get the subgroup size from the target or return a default.
   static std::optional<int> getSubgroupSize(Operation *op) {
     auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
            val == getSubgroupSize(op);
   }
 
+  /// Returns the integer type a shuffled value of type `oldTy` is converted
+  /// to around the SPIR-V builtin call: i16 for bf16 (bitcast) and i8 for i1
+  /// (zero-extended, then truncated back). Returns a null Type for any other
+  /// type, in which case the caller passes the value to the builtin as-is.
+  /// The ops themselves are created by doBitcastOrExt and doBitcastOrTrunc.
+  static Type getBitCastOrExtTy(Type oldTy,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Type>(oldTy)
+        .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+        .Case<IntegerType>([&](auto intTy) -> Type {
+          if (intTy.getWidth() == 1)
+            return rewriter.getIntegerType(8);
+          return Type{};
+        })
+        .Default([](auto) { return Type{}; });
+  }
+
+  /// Converts `oldVal` to `newTy` before it is passed to the SPIR-V builtin
+  /// (the caller passes the type chosen by getBitCastOrExtTy). A bf16 value
+  /// is bitcast and an i1 value is zero-extended; for any other type of
+  /// `oldVal` no op is created and a null Value is returned.
+  static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+                              ConversionPatternRewriter &rewriter) {
+    Type srcTy = oldVal.getType();
+    if (isa<BFloat16Type>(srcTy))
+      return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+    if (srcTy.isInteger(1))
+      return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+    return Value{};
+  }
+
+  /// Converts the builtin's result `oldVal` back to the shuffle's original
+  /// type `newTy`, chosen by `newTy` alone: bitcast for bf16, truncation for
+  /// i1. For any other `newTy` no op is created and a null Value is returned.
+  static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+                                ConversionPatternRewriter &rewriter) {
+    const bool isBF16 = isa<BFloat16Type>(newTy);
+    const bool isI1 = newTy.isInteger(1);
+    if (!isBF16 && !isI1)
+      return Value{};
+    if (isBF16)
+      return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+    return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+  }
+
   LogicalResult
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
       return rewriter.notifyMatchFailure(
           op, "shuffle width and subgroup size mismatch");
 
-    std::optional<std::string> funcName = getFuncName(op);
+    Location loc = op->getLoc();
+    Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+    std::optional<std::string> funcName;
+    Value inValue;
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastorext");
+      funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+      inValue = newVal;
+    } else {
+      funcName = getFuncName(op);
+      inValue = adaptor.getValue();
+    }
     if (!funcName)
       return rewriter.notifyMatchFailure(op, "unsupported value type");
 
     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     assert(moduleOp && "Expecting module");
-    Type valueType = adaptor.getValue().getType();
+    Type valueType = inValue.getType();
     Type offsetType = adaptor.getOffset().getType();
     Type resultType = valueType;
     LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
         moduleOp, funcName.value(), {valueType, offsetType}, resultType,
         /*isMemNone=*/false, /*isConvergent=*/true);
 
-    Location loc = op->getLoc();
-    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+    std::array<Value, 2> args{inValue, adaptor.getOffset()};
     Value result =
         createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastortrunc");
+      result = newVal;
+    }
+
     Value trueVal =
         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
     rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
   // CHECK-SAME:              (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
   // CHECK-SAME:               %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
   // CHECK-SAME:               %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
-  // CHECK-SAME:               %[[F64_VAL:.*]]: f64,  %[[OFFSET:.*]]: i32)
+  // CHECK-SAME:               %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+  // CHECK-SAME:               %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
   llvm.func @gpu_shuffles(%i8_val: i8,
                           %i16_val: i16,
                           %i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
                           %f16_val: f16,
                           %f32_val: f32,
                           %f64_val: f64,
+                          %bf16_val: bf16,
+                          %i1_val: i1,
                           %offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
     %width = arith.constant 16 : i32
     // CHECK:         llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
     // CHECK:         llvm.mlir.constant(true) : i1
     // CHECK:         llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
     // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+    // CHECK:         %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+    // CHECK:         llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+    // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+    // CHECK:         %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%[[I1_ZEXT]], %[[OFFSET]])
+    // CHECK:         llvm.trunc %[[I1_CALL]] : i8 to i1
+    // CHECK:         llvm.mlir.constant(true) : i1
     %shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
     %shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
     %shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
     %shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
     %shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
     %shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+    %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+    %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
     llvm.return
   }
 }
@@ -342,10 +355,10 @@ gpu.module @shuffles_mismatch {
 // Cannot convert due to value type not being supported by the conversion
 
 gpu.module @not_supported_lowering {
-  llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+  llvm.func @gpu_shuffles(%val: f128, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
     %width = arith.constant 32 : i32
     // expected-error at below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
-    %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
+    %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : f128
     llvm.return
   }
 }



More information about the Mlir-commits mailing list