[Mlir-commits] [mlir] [MLIR][NVGPU] Add convert.fpext and convert.fptrunc Ops (PR #199700)
Durgadoss R
llvmlistbot at llvm.org
Wed May 27 01:28:09 PDT 2026
================
@@ -1709,6 +1711,669 @@ struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
rewriter);
}
};
+
+//===----------------------------------------------------------------------===//
+// FPTruncOp Lowering
+//===----------------------------------------------------------------------===//
+
+/// Selects which packed conversion the nvgpu.convert.fptrunc lowering emits.
+/// Enumerators follow the scheme <SRC>x2_TO_<DST>x2, i.e. every conversion
+/// operates on a pair of values. The declaration order below is relied upon
+/// for the implicit enumerator values, so append new entries at the end.
+enum class TruncConvOp {
+  // Pair of f32 scalars as the source.
+  F32x2_TO_F16x2,
+  F32x2_TO_BF16x2,
+  F32x2_TO_F8x2,
+  F32x2_TO_F6x2,
+  F32x2_TO_F4x2,
+  // Packed f16x2 as the source.
+  F16x2_TO_F8x2,
+  F16x2_TO_F4x2,
+  // Packed bf16x2 as the source.
+  BF16x2_TO_F8x2,
+  BF16x2_TO_F4x2,
+};
+
+/// Element-type category of the operand being truncated.
+enum class TruncSrcKind {
+  F32,
+  F16,
+  BF16,
+};
+
+/// Coarse category of the truncation result's element type. A single
+/// category may correspond to more than one concrete element type; see
+/// classifyDstType for the exact mapping.
+enum class TruncDstKind {
+  F16,
+  BF16,
+  F8,
+  F6,
+  F4,
+};
+
+/// A single row of the truncation dispatch table: which conversion handles a
+/// given (source kind, destination kind) pair.
+struct TruncTableEntry {
+  /// Category of the source element type.
+  TruncSrcKind src;
+  /// Category of the destination element type.
+  TruncDstKind dst;
+  /// Conversion emitted for this pair.
+  TruncConvOp convOp;
+  /// Source step decrement per conversion: 2 for a pair of f32 values,
+  /// 1 for a packed f16x2/bf16x2 value.
+  int srcStepDecrement;
+};
+
+/// All supported truncations. lookupTruncConvOp scans this table linearly and
+/// uses the first row whose (src, dst) pair matches, so a later duplicate
+/// would never be reached.
+static constexpr TruncTableEntry kTruncTable[] = {
+    // Source: two f32 scalars per conversion (step decrement 2).
+    {TruncSrcKind::F32, TruncDstKind::F16, TruncConvOp::F32x2_TO_F16x2, 2},
+    {TruncSrcKind::F32, TruncDstKind::BF16, TruncConvOp::F32x2_TO_BF16x2, 2},
+    {TruncSrcKind::F32, TruncDstKind::F8, TruncConvOp::F32x2_TO_F8x2, 2},
+    {TruncSrcKind::F32, TruncDstKind::F6, TruncConvOp::F32x2_TO_F6x2, 2},
+    {TruncSrcKind::F32, TruncDstKind::F4, TruncConvOp::F32x2_TO_F4x2, 2},
+
+    // Source: one packed f16x2 value per conversion (step decrement 1).
+    {TruncSrcKind::F16, TruncDstKind::F8, TruncConvOp::F16x2_TO_F8x2, 1},
+    {TruncSrcKind::F16, TruncDstKind::F4, TruncConvOp::F16x2_TO_F4x2, 1},
+
+    // Source: one packed bf16x2 value per conversion (step decrement 1).
+    {TruncSrcKind::BF16, TruncDstKind::F8, TruncConvOp::BF16x2_TO_F8x2, 1},
+    {TruncSrcKind::BF16, TruncDstKind::F4, TruncConvOp::BF16x2_TO_F4x2, 1},
+};
+
+/// Returns true when `t` is one of the 8-bit float types this lowering
+/// accepts as an F8 destination: f8E4M3FN, f8E5M2 or f8E8M0FNU.
+static bool isConvertibleF8Type(Type t) {
+  const bool supported = isa<Float8E4M3FNType>(t) ||
+                         isa<Float8E5M2Type>(t) ||
+                         isa<Float8E8M0FNUType>(t);
+  return supported;
+}
+
+/// Returns the truncation source category of `t` (f32, f16 or bf16), or
+/// std::nullopt for any other type.
+static std::optional<TruncSrcKind> classifySrcType(Type t) {
+  std::optional<TruncSrcKind> kind;
+  if (t.isF32())
+    kind = TruncSrcKind::F32;
+  else if (t.isF16())
+    kind = TruncSrcKind::F16;
+  else if (t.isBF16())
+    kind = TruncSrcKind::BF16;
+  return kind;
+}
+
+/// Returns the truncation destination category of `t`:
+///   - f16 -> F16, bf16 -> BF16,
+///   - f8E4M3FN / f8E5M2 / f8E8M0FNU -> F8,
+///   - an 8-bit integer type -> F6 (NOTE(review): presumably the f6 results
+///     are carried in 8-bit integer containers; confirm against the op spec),
+///   - any 4-bit integer or float type -> F4.
+/// Returns std::nullopt for every other type, including types without an
+/// integer/float bit width (e.g. index) instead of asserting on them.
+static std::optional<TruncDstKind> classifyDstType(Type t) {
+  if (t.isF16())
+    return TruncDstKind::F16;
+  if (t.isBF16())
+    return TruncDstKind::BF16;
+  if (isConvertibleF8Type(t))
+    return TruncDstKind::F8;
+  // getIntOrFloatBitWidth() asserts unless the type is an integer or float,
+  // so reject anything else (index, vectors, ...) up front.
+  if (!t.isIntOrFloat())
+    return std::nullopt;
+  unsigned bitWidth = t.getIntOrFloatBitWidth();
+  if (isa<IntegerType>(t) && bitWidth == 8)
+    return TruncDstKind::F6;
+  if (bitWidth == 4)
+    return TruncDstKind::F4;
+  return std::nullopt;
+}
+
+/// Looks up how to truncate `srcElemType` to `dstElemType`.
+/// Returns the conversion op paired with its srcStepDecrement, or
+/// std::nullopt when either type is unsupported or no table row covers the
+/// (source, destination) combination. The first matching row wins.
+static std::optional<std::pair<TruncConvOp, int>>
+lookupTruncConvOp(Type srcElemType, Type dstElemType) {
+  // Both classifiers are always evaluated, matching the original order.
+  std::optional<TruncSrcKind> wantSrc = classifySrcType(srcElemType);
+  std::optional<TruncDstKind> wantDst = classifyDstType(dstElemType);
+  if (!wantSrc.has_value() || !wantDst.has_value())
+    return std::nullopt;
+
+  for (const TruncTableEntry &row : kTruncTable) {
+    const bool matches = row.src == *wantSrc && row.dst == *wantDst;
+    if (matches)
+      return std::make_pair(row.convOp, row.srcStepDecrement);
+  }
+  return std::nullopt;
+}
+
+/// Emits an llvm.extractelement reading lane `idx` of `srcVec` and returns the
+/// extracted value. The lane index is materialized as an i64
+/// llvm.mlir.constant, created before the extract op.
+static Value extractElement(ImplicitLocOpBuilder &b, Value srcVec, int idx) {
+  Value laneIndex =
+      b.create<LLVM::ConstantOp>(b.getI64Type(), b.getI64IntegerAttr(idx));
+  return b.create<LLVM::ExtractElementOp>(srcVec, laneIndex);
+}
+
+/// Extract a pair of f32 values from an i32 vector at the given base index.
+/// Returns {f32_lo (lower index), f32_hi (higher index)}.
+static std::pair<Value, Value> extractF32Pair(ImplicitLocOpBuilder &b,
+ Value srcI32Vec, int baseIdx) {
+ FloatType f32Ty = b.getF32Type();
+ Value elem0 = extractElement(b, srcI32Vec, baseIdx);
+ Value elem1 = extractElement(b, srcI32Vec, baseIdx + 1);
----------------
durga4github wrote:
Do we assume that `baseIdx + 1` never indexes out of bounds?
(If the callers guarantee this, please state it explicitly in the comment as well.)
https://github.com/llvm/llvm-project/pull/199700
More information about the Mlir-commits
mailing list