[Mlir-commits] [mlir] [MLIR][XeVM] XeVM to LLVM: Update xevm.truncf handling (PR #194491)
Md Abdullah Shahneous Bari
llvmlistbot at llvm.org
Thu Apr 30 17:05:47 PDT 2026
================
@@ -1129,51 +1129,165 @@ class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
// Supported source and result types are restricted for now.
auto srcEtype = op.getSrcEtype().getEtype();
auto dstEtype = op.getDstEtype().getEtype();
- if (auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType())) {
- if (vecSrcTy.getNumElements() != 16)
- return rewriter.notifyMatchFailure(
- op, "Only vector src of 16 elements is supported");
- } else {
+ // Currently only 16 input elements are supported as
+ // - Any vector beyond 16 elements is not a valid OpenCL vector.
+ // - 2D block load can only load up to 16 16bit elements per lane.
+ // Widest load is 8x16xi32 with 16 lanes, which is 16 16bit
+ // elements per lane.
+ // - mma_mx A and B operands need more than 16 elements per lane
+ //
+ // Conversion is done in batches depending on the dst type.
+ // batch_size =
+ // 16 if dst type == fp8
+ // 8 if dst type == fp4
+ // For num_elem > batch_size
+ // convert batch of batch_size
+ // cast batch to i32 elem type vector
+ // concat batches by shufflevector
+ // For num_elem = batch_size
+ // use API for conversion
+ // Scalar case is not supported until its use case becomes clear.
+ auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType());
+ if (!vecSrcTy) {
return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
}
- if (auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType())) {
- if (vecDstTy.getNumElements() != 16)
- return rewriter.notifyMatchFailure(
- op, "Only vector dst of 16 elements is supported");
- } else {
+ if (vecSrcTy.getNumElements() != 16)
+ return rewriter.notifyMatchFailure(
+ op, "Only vector src of 16 elements is supported");
+ auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType());
+ if (!vecDstTy)
return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
+ Value src = op.getSrc();
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/LLVM::ModRefInfo::NoModRef,
+ /*argMem=*/LLVM::ModRefInfo::NoModRef,
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
+ auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+ funcAttrs.memEffectsAttr = memAttr;
+
+ // Handle the case where dst type is fp4 first.
+ if (dstEtype == TruncfDstElemTypes::E2M1) {
+ // Convert 8 elements at a time.
+ // To convert 8 elements, vector<8xf16>:
+ // Use:
+ // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 0)
+ // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 3)
+ // llvm.or
+ Value cast = LLVM::BitcastOp::create(
+ rewriter, op.getLoc(), VectorType::get(8, rewriter.getI32Type()),
+ src);
+
+ std::string fnName = "__builtin_IB_dnscl_";
+ fnName += (srcEtype == TruncfSrcElemTypes::F16) ? "hf16" : "bf16";
+ auto genDnscl = [&](Value input, Value idx0, Value idx1, Value dstTy,
+ Value mode) -> Value {
+ Value arg1 =
+ LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx0)
+ ->getResult(0);
+ Value arg2 =
+ LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx1)
+ ->getResult(0);
+ SmallVector<Type> argTypes{arg1.getType(), arg2.getType(),
+ dstTy.getType(), mode.getType()};
+ SmallVector<Value> args{arg1, arg2, dstTy, mode};
+ Value dnscl = createDeviceFunctionCall(
+ rewriter, fnName, rewriter.getI32Type(), argTypes,
+ args, {}, funcAttrs, op.getOperation())
+ ->getResult(0);
+ return dnscl;
+ };
+
+ Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 0);
+ Value one = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 1);
+ Value two = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 2);
+ Value three = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 3);
+ Value even = genDnscl(cast, zero, two, one, zero);
+ Value odd = genDnscl(cast, one, three, one, two);
+ Value firstHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
+ Value four = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 4);
+ Value five = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 5);
+ Value six = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 6);
+ Value seven = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 7);
+ even = genDnscl(cast, four, six, one, zero);
+ odd = genDnscl(cast, five, seven, one, two);
+ Value secondHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
+ // Create vector<2xi32> from two i32 values and then bitcast to
+ // vector<8xi8> to match the dst type.
+ Value combined = LLVM::UndefOp::create(
+ rewriter, op.getLoc(), VectorType::get(2, rewriter.getI32Type()));
+ combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
+ firstHalf, zero)
+ ->getResult(0);
+ combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
+ secondHalf, one)
+ ->getResult(0);
+ Value result =
+ LLVM::BitcastOp::create(rewriter, op.getLoc(), vecDstTy, combined);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+
+ // Handle the case where dst type is fp8.
+ // BF16 type needs some preprocessing before conversion,
+ // First extended to F32 and then truncated to F16.
+ if (srcEtype == TruncfSrcElemTypes::BF16) {
+ // Step 1: Extend to F32
+ // Use float16 __builtin_IB_bftof_16(short16)
+ src = LLVM::BitcastOp::create(
+ rewriter, op.getLoc(),
+ VectorType::get(vecSrcTy.getShape(), rewriter.getI16Type()), src);
+ std::string fnName = "__builtin_IB_bftof_16";
+ SmallVector<Type> argTypes{src.getType()};
+ SmallVector<Value> args{src};
+ Type resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF32Type());
+ src = createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args,
+ {}, funcAttrs, op.getOperation())
+ ->getResult(0);
+ // Step 2: Truncf to F16
+ // Use half16 convert_half16(float16)
+ std::string truncFnName = "convert_half16";
+ SmallVector<Type> truncArgTypes{src.getType()};
+ SmallVector<Value> truncArgs{src};
+ truncFnName = mangle(truncFnName, truncArgTypes);
+ resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF16Type());
+ src =
+ createDeviceFunctionCall(rewriter, truncFnName, resTy, truncArgTypes,
+ truncArgs, {}, funcAttrs, op.getOperation())
+ ->getResult(0);
}
- if (srcEtype == TruncfSrcElemTypes::F16 &&
- dstEtype == TruncfDstElemTypes::BF8) {
- // BF8 is just F16 with the lower 8 bits of the mantissa discarded.
- // Signbit Exponent Mantessa
- // BF8 1 5 2
- // F16 1 5 10
- // Xe arch is Little Endian so BF8 is just the second byte of the two
- // byte representation used for F16
- auto firstHalf =
- LLVM::ShuffleVectorOp::create(rewriter, op.getLoc(), op.getSrc(),
- op.getSrc(), {0, 1, 2, 3, 4, 5, 6, 7});
- auto secondHalf = LLVM::ShuffleVectorOp::create(
- rewriter, op.getLoc(), op.getSrc(), op.getSrc(),
- {8, 9, 10, 11, 12, 13, 14, 15});
- auto firstHalfCasted = LLVM::BitcastOp::create(
- rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
- firstHalf);
- auto secondHalfCasted = LLVM::BitcastOp::create(
- rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
- secondHalf);
- // Gather just the second bytes from every two byte F16 values
- auto resFirstHalf = LLVM::ShuffleVectorOp::create(
- rewriter, op.getLoc(), firstHalfCasted, firstHalfCasted,
- {1, 3, 5, 7, 9, 11, 13, 15});
- auto resSecondHalf = LLVM::ShuffleVectorOp::create(
- rewriter, op.getLoc(), secondHalfCasted, secondHalfCasted,
- {1, 3, 5, 7, 9, 11, 13, 15});
- auto res = LLVM::ShuffleVectorOp::create(
- rewriter, op.getLoc(), resFirstHalf, resSecondHalf,
- {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
- rewriter.replaceOp(op, res);
+ if (dstEtype == TruncfDstElemTypes::BF8) {
+ // Use char16 __builtin_IB_hftobf8_16(half16)
+ std::string fnName = "__builtin_IB_hftobf8_16";
+ SmallVector<Type> argTypes{src.getType()};
+ SmallVector<Value> args{src};
+ Value result =
+ createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
+ {}, funcAttrs, op.getOperation())
+ ->getResult(0);
+
+ rewriter.replaceOp(op, result);
+ } else if (dstEtype == TruncfDstElemTypes::F8) {
----------------
mshahneo wrote:
Nit: Maybe add a comment explaining what F8 corresponds to as an MLIR type?
https://github.com/llvm/llvm-project/pull/194491
More information about the Mlir-commits
mailing list