[Mlir-commits] [mlir] [mlir][math] Add vector support for math-to-apfloat (PR #172715)
Maksim Levental
llvmlistbot at llvm.org
Thu Jan 15 18:36:11 PST 2026
================
@@ -151,29 +150,63 @@ struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> {
Location loc = op.getLoc();
rewriter.setInsertionPoint(op);
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
- Value operand = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getA()));
- Value multiplicand = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getB()));
- Value addend = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getC()));
-
- // Call APFloat function.
- Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operand, multiplicand, addend};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
-
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits);
+ IntegerType intWType = rewriter.getIntegerType(floatTy.getWidth());
+ IntegerType int64Type = rewriter.getI64Type();
+
+ auto scalarFMA = [&rewriter, &loc, &floatTy, &fn, &intWType,
+ &int64Type](Value a, Value b, Value c) {
+ Value operand = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, a));
+ Value multiplicand = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, b));
+ Value addend = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, c));
+ // Call APFloat function.
+ Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operand, multiplicand, addend};
+ auto resultOp =
+ func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
+
+ // Truncate result to the original width.
+ auto trunc = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy, trunc);
+ };
+
+ if (VectorType vecTy1 = dyn_cast<VectorType>(op.getA().getType())) {
----------------
makslevental wrote:
done
https://github.com/llvm/llvm-project/pull/172715
More information about the Mlir-commits
mailing list