[Mlir-commits] [mlir] [MLIR][NVGPU] Add convert.fpext and convert.fptrunc Ops (PR #199700)

Durgadoss R llvmlistbot at llvm.org
Wed May 27 01:28:09 PDT 2026


================
@@ -1709,6 +1711,669 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
         rewriter);
   }
 };
+
+//===----------------------------------------------------------------------===//
+// FPTruncOp Lowering
+//===----------------------------------------------------------------------===//
+
+/// Conversion op identifier for nvgpu.convert.fptrunc lowering dispatch table.
+/// Each enumerator names one packed source-to-destination conversion. The
+/// (source kind, destination kind) -> enumerator mapping is defined by
+/// kTruncTable; any pair not listed there is unsupported.
+enum class TruncConvOp {
+  F32x2_TO_F16x2,
+  F32x2_TO_BF16x2,
+  F32x2_TO_F8x2,
+  F32x2_TO_F6x2,
+  F32x2_TO_F4x2,
+  F16x2_TO_F8x2,
+  F16x2_TO_F4x2,
+  BF16x2_TO_F8x2,
+  BF16x2_TO_F4x2,
+};
+
+/// Source element kinds accepted by the fptrunc lowering; produced by
+/// classifySrcType from the f32, f16 and bf16 element types.
+enum class TruncSrcKind { F32, F16, BF16 };
+
+/// Destination element kinds accepted by the fptrunc lowering; produced by
+/// classifyDstType. F8 covers Float8E4M3FN, Float8E5M2 and Float8E8M0FNU.
+/// F6 is recognized as an 8-bit integer element type, and F4 as any 4-bit
+/// element type, rather than by a dedicated float type.
+enum class TruncDstKind { F16, BF16, F8, F6, F4 };
+
+/// One row of the fptrunc dispatch table: a supported (src, dst) kind pair and
+/// the conversion op used to lower it.
+struct TruncTableEntry {
+  TruncSrcKind src;
+  TruncDstKind dst;
+  TruncConvOp convOp;
+  // NOTE(review): presumably the number of source vector elements consumed per
+  // conversion step (two scalar f32 values vs. one packed f16x2/bf16x2 unit);
+  // confirm against the loop that consumes it.
+  int srcStepDecrement; // 2 for f32 pairs, 1 for f16x2/bf16x2
+};
+
+/// Supported fptrunc conversions, searched linearly by lookupTruncConvOp.
+/// Any (src, dst) kind pair that is not listed (e.g. f16 -> f16, or
+/// bf16 -> f6) is reported as unsupported by that lookup.
+static constexpr TruncTableEntry kTruncTable[] = {
+    // f32 source
+    {TruncSrcKind::F32, TruncDstKind::F16, TruncConvOp::F32x2_TO_F16x2, 2},
+    {TruncSrcKind::F32, TruncDstKind::BF16, TruncConvOp::F32x2_TO_BF16x2, 2},
+    {TruncSrcKind::F32, TruncDstKind::F8, TruncConvOp::F32x2_TO_F8x2, 2},
+    {TruncSrcKind::F32, TruncDstKind::F6, TruncConvOp::F32x2_TO_F6x2, 2},
+    {TruncSrcKind::F32, TruncDstKind::F4, TruncConvOp::F32x2_TO_F4x2, 2},
+    // f16 source
+    {TruncSrcKind::F16, TruncDstKind::F8, TruncConvOp::F16x2_TO_F8x2, 1},
+    {TruncSrcKind::F16, TruncDstKind::F4, TruncConvOp::F16x2_TO_F4x2, 1},
+    // bf16 source
+    {TruncSrcKind::BF16, TruncDstKind::F8, TruncConvOp::BF16x2_TO_F8x2, 1},
+    {TruncSrcKind::BF16, TruncDstKind::F4, TruncConvOp::BF16x2_TO_F4x2, 1},
+};
+
+/// Returns true for the 8-bit float types that the fptrunc lowering classifies
+/// as the F8 destination kind: Float8E4M3FN, Float8E5M2 and Float8E8M0FNU.
+static bool isConvertibleF8Type(Type t) {
+  return isa<Float8E4M3FNType>(t) || isa<Float8E5M2Type>(t) ||
+         isa<Float8E8M0FNUType>(t);
+}
+
+/// Maps a source element type to its TruncSrcKind.
+/// Returns std::nullopt for anything other than f32, f16 or bf16.
+static std::optional<TruncSrcKind> classifySrcType(Type t) {
+  if (t.isBF16())
+    return TruncSrcKind::BF16;
+  if (t.isF16())
+    return TruncSrcKind::F16;
+  if (!t.isF32())
+    return std::nullopt;
+  return TruncSrcKind::F32;
+}
+
+/// Maps a destination element type to its TruncDstKind, or std::nullopt if the
+/// type is not a supported fptrunc destination:
+///   - f16 / bf16                                  -> F16 / BF16
+///   - Float8E4M3FN / Float8E5M2 / Float8E8M0FNU  -> F8
+///   - 8-bit integer type                          -> F6
+///   - any 4-bit integer or float type             -> F4
+/// Types that are neither integer nor float (e.g. index) are rejected up front,
+/// because Type::getIntOrFloatBitWidth() asserts on them.
+static std::optional<TruncDstKind> classifyDstType(Type t) {
+  if (t.isF16())
+    return TruncDstKind::F16;
+  if (t.isBF16())
+    return TruncDstKind::BF16;
+  if (isConvertibleF8Type(t))
+    return TruncDstKind::F8;
+  // getIntOrFloatBitWidth() is only defined for integer and float types;
+  // return "unsupported" instead of tripping its assertion.
+  if (!t.isIntOrFloat())
+    return std::nullopt;
+  unsigned bitWidth = t.getIntOrFloatBitWidth();
+  // NOTE(review): F6 destinations are represented as 8-bit integers here,
+  // presumably one f6 value per i8 container; confirm against the op spec.
+  if (isa<IntegerType>(t) && bitWidth == 8)
+    return TruncDstKind::F6;
+  if (bitWidth == 4)
+    return TruncDstKind::F4;
+  return std::nullopt;
+}
+
+/// Finds the lowering for a (source, destination) element type pair.
+/// On success returns {convOp, srcStepDecrement} from the matching kTruncTable
+/// row. Returns std::nullopt when either type is unsupported or the pair has
+/// no table entry. Both types are classified before either result is checked.
+static std::optional<std::pair<TruncConvOp, int>>
+lookupTruncConvOp(Type srcElemType, Type dstElemType) {
+  std::optional<TruncSrcKind> wantSrc = classifySrcType(srcElemType);
+  std::optional<TruncDstKind> wantDst = classifyDstType(dstElemType);
+  if (!(wantSrc && wantDst))
+    return std::nullopt;
+  for (const TruncTableEntry &row : kTruncTable)
+    if (row.src == *wantSrc && row.dst == *wantDst)
+      return std::make_pair(row.convOp, row.srcStepDecrement);
+  return std::nullopt;
+}
+
+/// Emits an llvm.extractelement that reads lane `idx` of `srcVec`, using an
+/// i64 llvm.mlir.constant as the position operand. `idx` is not checked
+/// against the vector length here; callers must pass an in-range index.
+static Value extractElement(ImplicitLocOpBuilder &b, Value srcVec, int idx) {
+  Value position =
+      b.create<LLVM::ConstantOp>(b.getI64Type(), b.getI64IntegerAttr(idx));
+  return b.create<LLVM::ExtractElementOp>(srcVec, position);
+}
+
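+/// Extract a pair of f32 values from an i32 vector at the given base index.
+/// Returns {f32_lo (lower index), f32_hi (higher index)}.
+/// Precondition: both baseIdx and baseIdx + 1 must be valid lane indices of
+/// srcI32Vec. The lanes are extracted without any bounds check, so callers are
+/// responsible for guaranteeing this.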
+static std::pair<Value, Value> extractF32Pair(ImplicitLocOpBuilder &b,
+                                              Value srcI32Vec, int baseIdx) {
+  FloatType f32Ty = b.getF32Type();
+  Value elem0 = extractElement(b, srcI32Vec, baseIdx);
+  Value elem1 = extractElement(b, srcI32Vec, baseIdx + 1);
----------------
durga4github wrote:

Do we assume that `baseIdx + 1` never goes out of bounds?
(If the callers check this, let us state it explicitly in the comment too.)

https://github.com/llvm/llvm-project/pull/199700


More information about the Mlir-commits mailing list