[Mlir-commits] [mlir] [MLIR][XeVM] XeVM to LLVM: Update xevm.truncf handling (PR #194491)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Thu Apr 30 17:05:47 PDT 2026


================
@@ -1129,51 +1129,165 @@ class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
     // Supported source and result types are resticted for now.
     auto srcEtype = op.getSrcEtype().getEtype();
     auto dstEtype = op.getDstEtype().getEtype();
-    if (auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType())) {
-      if (vecSrcTy.getNumElements() != 16)
-        return rewriter.notifyMatchFailure(
-            op, "Only vector src of 16 elements is supported");
-    } else {
+    // Currently only 16 input elements are supported as
+    //  - A vector with more than 16 elements is not a valid OpenCL vector.
+    //  - 2D block load can only load up to 16 16bit elements per lane.
+    //      Widest load is 8x16xi32 with 16 lanes, which is 16 16bit
+    //      elements per lane.
+    //  - mma_mx A and B operands need more than 16 elements per lane
+    //
+    // Conversion is done in batches depending on the dst type.
+    // batch_size =
+    //   16 if dst type == fp8
+    //   8  if dst type == fp4
+    // For num_elem > batch_size
+    //   convert batch of batch_size
+    //   cast batch to i32 elem type vector
+    //   concat batches by shufflevector
+    // For num_elem = batch_size
+    //   use API for conversion
+    // The scalar case is not supported until a use case becomes clear.
+    auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType());
+    if (!vecSrcTy) {
       return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
     }
-    if (auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType())) {
-      if (vecDstTy.getNumElements() != 16)
-        return rewriter.notifyMatchFailure(
-            op, "Only vector dst of 16 elements is supported");
-    } else {
+    if (vecSrcTy.getNumElements() != 16)
+      return rewriter.notifyMatchFailure(
+          op, "Only vector src of 16 elements is supported");
+    auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType());
+    if (!vecDstTy)
       return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
+    Value src = op.getSrc();
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::NoModRef,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
+        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
+        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
+        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
+    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+    funcAttrs.memEffectsAttr = memAttr;
+
+    // Handle the case where dst type is fp4 first.
+    if (dstEtype == TruncfDstElemTypes::E2M1) {
+      // Convert 8 elements at a time.
+      // To convert 8 elements, vector<8xf16>:
+      // Use:
+      // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 0)
+      // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 3)
+      // llvm.or
+      Value cast = LLVM::BitcastOp::create(
+          rewriter, op.getLoc(), VectorType::get(8, rewriter.getI32Type()),
+          src);
+
+      std::string fnName = "__builtin_IB_dnscl_";
+      fnName += (srcEtype == TruncfSrcElemTypes::F16) ? "hf16" : "bf16";
+      auto genDnscl = [&](Value input, Value idx0, Value idx1, Value dstTy,
+                          Value mode) -> Value {
+        Value arg1 =
+            LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx0)
+                ->getResult(0);
+        Value arg2 =
+            LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx1)
+                ->getResult(0);
+        SmallVector<Type> argTypes{arg1.getType(), arg2.getType(),
+                                   dstTy.getType(), mode.getType()};
+        SmallVector<Value> args{arg1, arg2, dstTy, mode};
+        Value dnscl = createDeviceFunctionCall(
+                          rewriter, fnName, rewriter.getI32Type(), argTypes,
+                          args, {}, funcAttrs, op.getOperation())
+                          ->getResult(0);
+        return dnscl;
+      };
+
+      Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+                                            rewriter.getI32Type(), 0);
+      Value one = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+                                           rewriter.getI32Type(), 1);
+      Value two = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+                                           rewriter.getI32Type(), 2);
+      Value three = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+                                             rewriter.getI32Type(), 3);
+      Value even = genDnscl(cast, zero, two, one, zero);
+      Value odd = genDnscl(cast, one, three, one, two);
+      Value firstHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
+      Value four = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+                                            rewriter.getI32Type(), 4);
+      Value five = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+                                            rewriter.getI32Type(), 5);
+      Value six = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+                                           rewriter.getI32Type(), 6);
+      Value seven = LLVM::ConstantOp::create(rewriter, op.getLoc(),
+                                             rewriter.getI32Type(), 7);
+      even = genDnscl(cast, four, six, one, zero);
+      odd = genDnscl(cast, five, seven, one, two);
+      Value secondHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
+      // Create vector<2xi32> from two i32 values and then bitcast to
+      // vector<8xi8> to match the dst type.
+      Value combined = LLVM::UndefOp::create(
+          rewriter, op.getLoc(), VectorType::get(2, rewriter.getI32Type()));
+      combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
+                                               firstHalf, zero)
+                     ->getResult(0);
+      combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
+                                               secondHalf, one)
+                     ->getResult(0);
+      Value result =
+          LLVM::BitcastOp::create(rewriter, op.getLoc(), vecDstTy, combined);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    // Handle the case where dst type is fp8.
+    // BF16 type needs some preprocessing before conversion,
+    // First extended to F32 and then truncated to F16.
+    if (srcEtype == TruncfSrcElemTypes::BF16) {
+      // Step 1: Extend to F32
+      // Use float16 __builtin_IB_bftof_16(short16)
+      src = LLVM::BitcastOp::create(
+          rewriter, op.getLoc(),
+          VectorType::get(vecSrcTy.getShape(), rewriter.getI16Type()), src);
+      std::string fnName = "__builtin_IB_bftof_16";
+      SmallVector<Type> argTypes{src.getType()};
+      SmallVector<Value> args{src};
+      Type resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF32Type());
+      src = createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args,
+                                     {}, funcAttrs, op.getOperation())
+                ->getResult(0);
+      // Step 2: Truncf to F16
+      // Use half16 convert_half16(float16)
+      std::string truncFnName = "convert_half16";
+      SmallVector<Type> truncArgTypes{src.getType()};
+      SmallVector<Value> truncArgs{src};
+      truncFnName = mangle(truncFnName, truncArgTypes);
+      resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF16Type());
+      src =
+          createDeviceFunctionCall(rewriter, truncFnName, resTy, truncArgTypes,
+                                   truncArgs, {}, funcAttrs, op.getOperation())
+              ->getResult(0);
     }
-    if (srcEtype == TruncfSrcElemTypes::F16 &&
-        dstEtype == TruncfDstElemTypes::BF8) {
-      // BF8 is just F16 with lower 8 bits of mantessa discard.
-      //     Signbit Exponent Mantessa
-      // BF8 1       5        2
-      // F16 1       5        10
-      // Xe arch is Little Endian so BF8 is just the second byte of the two
-      // byte representation used for F16
-      auto firstHalf =
-          LLVM::ShuffleVectorOp::create(rewriter, op.getLoc(), op.getSrc(),
-                                        op.getSrc(), {0, 1, 2, 3, 4, 5, 6, 7});
-      auto secondHalf = LLVM::ShuffleVectorOp::create(
-          rewriter, op.getLoc(), op.getSrc(), op.getSrc(),
-          {8, 9, 10, 11, 12, 13, 14, 15});
-      auto firstHalfCasted = LLVM::BitcastOp::create(
-          rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
-          firstHalf);
-      auto secondHalfCasted = LLVM::BitcastOp::create(
-          rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
-          secondHalf);
-      // Gather just the second bytes from every two byte F16 values
-      auto resFirstHalf = LLVM::ShuffleVectorOp::create(
-          rewriter, op.getLoc(), firstHalfCasted, firstHalfCasted,
-          {1, 3, 5, 7, 9, 11, 13, 15});
-      auto resSecondHalf = LLVM::ShuffleVectorOp::create(
-          rewriter, op.getLoc(), secondHalfCasted, secondHalfCasted,
-          {1, 3, 5, 7, 9, 11, 13, 15});
-      auto res = LLVM::ShuffleVectorOp::create(
-          rewriter, op.getLoc(), resFirstHalf, resSecondHalf,
-          {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
-      rewriter.replaceOp(op, res);
+    if (dstEtype == TruncfDstElemTypes::BF8) {
+      // Use char16 __builtin_IB_hftobf8_16(half16)
+      std::string fnName = "__builtin_IB_hftobf8_16";
+      SmallVector<Type> argTypes{src.getType()};
+      SmallVector<Value> args{src};
+      Value result =
+          createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
+                                   {}, funcAttrs, op.getOperation())
+              ->getResult(0);
+
+      rewriter.replaceOp(op, result);
+    } else if (dstEtype == TruncfDstElemTypes::F8) {
----------------
mshahneo wrote:

Nit: maybe add a comment explaining which MLIR type F8 corresponds to?

https://github.com/llvm/llvm-project/pull/194491


More information about the Mlir-commits mailing list