[Mlir-commits] [mlir] [MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion (PR #119675)

Pietro Ghiglio llvmlistbot at llvm.org
Tue Jan 7 08:01:33 PST 2025


================
@@ -286,30 +291,94 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
            val == getSubgroupSize(op);
   }
 
+  /// Returns true when the shuffle's first result type is bf16 or i1, i.e. one
+  /// of the element types this pattern converts to an integer type (bf16 ->
+  /// i16, i1 -> i8) before emitting the shuffle call.
+  static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+    Type shuffledTy = op.getType(0);
+    if (isa<BFloat16Type>(shuffledTy))
+      return true;
+    return shuffledTy.isInteger(1);
+  }
+
+  /// Maps the type of a shuffled value to the integer type it is converted to
+  /// before the shuffle: bf16 -> i16 (same bit width, reinterpreted) and
+  /// i1 -> i8 (widened). Any other type yields a null Type, which callers
+  /// treat as "no conversion needed".
+  static Type getBitCastOrExtTy(Type oldTy,
+                                ConversionPatternRewriter &rewriter) {
+    if (isa<BFloat16Type>(oldTy))
+      return rewriter.getIntegerType(16);
+    if (oldTy.isInteger(1))
+      return rewriter.getIntegerType(8);
+    return Type{};
+  }
+
+  /// Emits the op converting `oldVal` to `newTy`, chosen by the *source* type:
+  /// a bf16 value is bitcast (llvm.bitcast), an i1 value is zero-extended
+  /// (llvm.zext). Returns a null Value for any other source type.
+  static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+                              ConversionPatternRewriter &rewriter) {
+    Type srcTy = oldVal.getType();
+    if (isa<BFloat16Type>(srcTy))
+      return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+    if (srcTy.isInteger(1))
+      return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+    return Value{};
+  }
+
+  /// Emits the op converting `oldVal` back to `newTy`, chosen by the
+  /// *destination* type: bf16 is recovered with llvm.bitcast, i1 with
+  /// llvm.trunc. This undoes what doBitcastOrExt produced. Returns a null
+  /// Value for any other destination type.
+  static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+                                ConversionPatternRewriter &rewriter) {
+    if (isa<BFloat16Type>(newTy))
+      return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+    if (newTy.isInteger(1))
+      return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+    return Value{};
+  }
+
   LogicalResult
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     if (!hasValidWidth(op))
       return rewriter.notifyMatchFailure(
           op, "shuffle width and subgroup size mismatch");
 
-    std::optional<std::string> funcName = getFuncName(op);
+    Location loc = op->getLoc();
+    Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+    std::optional<std::string> funcName;
+    Value inValue;
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastorext");
+      funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+      inValue = newVal;
+    } else {
+      funcName = getFuncName(op);
+      inValue = adaptor.getValue();
+    }
----------------
PietroGhg wrote:

Yes, thank you

https://github.com/llvm/llvm-project/pull/119675


More information about the Mlir-commits mailing list