[Mlir-commits] [mlir] [MLIR][XeVM] XeVM to LLVM: Update xevm.truncf handling (PR #194491)
Charitha Saumya
llvmlistbot at llvm.org
Thu Apr 30 13:02:26 PDT 2026
================
@@ -1129,51 +1129,165 @@ class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
// Supported source and result types are resticted for now.
auto srcEtype = op.getSrcEtype().getEtype();
auto dstEtype = op.getDstEtype().getEtype();
- if (auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType())) {
- if (vecSrcTy.getNumElements() != 16)
- return rewriter.notifyMatchFailure(
- op, "Only vector src of 16 elements is supported");
- } else {
+ // Currently only 16 input elements are supported as
+ // - Any vector beyond 16 elements not a valid OpenCL vector.
+ // - 2D block load can only load up to 16 16bit elements per lane.
+ // Widest load is 8x16xi32 with 16 lanes, which is 16 16bit
+ // elements per lane.
+ // - mma_mx A and B operands need more than 16 elements per lane
+ //
+ // Conversion is done in batches depending on the dst type.
+ // batch_size =
+ // 16 if dst type == fp8
+ // 8 if dst type == fp4
+ // For num_elem > batch_size
+ // convert batch of batch_size
+ // cast batch to i32 elem type vector
+ // concat batches by shufflevector
+ // For num_elem = batch_size
+ // use API for conversion
+ // Scalar case is not supported until usage case become clear.
+ auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType());
+ if (!vecSrcTy) {
return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
}
- if (auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType())) {
- if (vecDstTy.getNumElements() != 16)
- return rewriter.notifyMatchFailure(
- op, "Only vector dst of 16 elements is supported");
- } else {
+ if (vecSrcTy.getNumElements() != 16)
+ return rewriter.notifyMatchFailure(
+ op, "Only vector src of 16 elements is supported");
+ auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType());
+ if (!vecDstTy)
return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
+ Value src = op.getSrc();
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/LLVM::ModRefInfo::NoModRef,
+ /*argMem=*/LLVM::ModRefInfo::NoModRef,
+ /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+ /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+ /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
+ auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+ funcAttrs.memEffectsAttr = memAttr;
+
+ // Handle the case where dst type is fp4 first.
+ if (dstEtype == TruncfDstElemTypes::E2M1) {
+ // Convert 8 elements at a time.
+ // To convert 8 elements, vector<8xf16>:
+ // Use:
+ // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 0)
+ // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 3)
+ // llvm.or
+ Value cast = LLVM::BitcastOp::create(
+ rewriter, op.getLoc(), VectorType::get(8, rewriter.getI32Type()),
+ src);
+
+ std::string fnName = "__builtin_IB_dnscl_";
+ fnName += (srcEtype == TruncfSrcElemTypes::F16) ? "hf16" : "bf16";
+ auto genDnscl = [&](Value input, Value idx0, Value idx1, Value dstTy,
+ Value mode) -> Value {
+ Value arg1 =
+ LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx0)
+ ->getResult(0);
+ Value arg2 =
+ LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx1)
+ ->getResult(0);
+ SmallVector<Type> argTypes{arg1.getType(), arg2.getType(),
+ dstTy.getType(), mode.getType()};
+ SmallVector<Value> args{arg1, arg2, dstTy, mode};
+ Value dnscl = createDeviceFunctionCall(
+ rewriter, fnName, rewriter.getI32Type(), argTypes,
+ args, {}, funcAttrs, op.getOperation())
+ ->getResult(0);
+ return dnscl;
+ };
+
+ Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 0);
+ Value one = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 1);
+ Value two = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 2);
+ Value three = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 3);
+ Value even = genDnscl(cast, zero, two, one, zero);
+ Value odd = genDnscl(cast, one, three, one, two);
+ Value firstHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
+ Value four = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 4);
+ Value five = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 5);
+ Value six = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 6);
+ Value seven = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+ rewriter.getI32Type(), 7);
----------------
charithaintc wrote:
nit: could be in a loop?
https://github.com/llvm/llvm-project/pull/194491
More information about the Mlir-commits
mailing list