[Mlir-commits] [mlir] [mlir][complex] Emit fma for contracted complex.mul lowering (PR #196248)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 6 23:45:51 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: hanbeom (ParkHanbum)
<details>
<summary>Changes</summary>
When complex.mul carries the fastmath<contract> flag, lower it using explicit fused multiply-add (FMA) operations for the real and imaginary components.
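For lhs = ar + ai*i and rhs = br + bi*i, the lowering changes from:
real = ar * br - ai * bi
imag = ai * br + ar * bi
expressed as separate mul/sub/add operations, to:
real = fma(ar, br, -(ai * bi))
imag = fma(ar, bi, ai * br)
where fma(x, y, z) computes x * y + z with a single rounding step.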
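This rewrite is applied only when contraction is allowed. A complex.mul without the contract flag still lowers to separate multiply/subtract/add operations (llvm.fmul/fsub/fadd in ComplexToLLVM, arith.mulf/subf/addf in ComplexToStandard).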
Fixes: #196246
---
Full diff: https://github.com/llvm/llvm-project/pull/196248.diff
4 Files Affected:
- (modified) mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp (+24-7)
- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+29-12)
- (modified) mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir (+5-6)
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+17-22)
``````````diff
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index cdcb3cba55752..327a6678f9aed 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -279,13 +279,30 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
Value lhsRe = arg.lhs.real();
Value lhsIm = arg.lhs.imag();
- Value real = LLVM::FSubOp::create(
- rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
- LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
-
- Value imag = LLVM::FAddOp::create(
- rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
- LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
+ Value real;
+ Value imag;
+ if (arith::bitEnumContainsAll(complexFMFAttr.getValue(),
+ arith::FastMathFlags::contract)) {
+ Value lhsImagTimesRhsImag =
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf);
+ Value negLhsImagTimesRhsImag =
+ LLVM::FNegOp::create(rewriter, loc, lhsImagTimesRhsImag, fmf);
+ real = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsRe,
+ negLhsImagTimesRhsImag, fmf);
+
+ Value lhsImagTimesRhsReal =
+ LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf);
+ imag = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsIm,
+ lhsImagTimesRhsReal, fmf);
+ } else {
+ real = LLVM::FSubOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
+
+ imag = LLVM::FAddOp::create(
+ rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+ LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
+ }
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..4dcc9a2c23d77 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -549,18 +549,35 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
- Value lhsRealTimesRhsReal =
- arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
- Value lhsImagTimesRhsImag =
- arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
- Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
- lhsImagTimesRhsImag, fmfValue);
- Value lhsImagTimesRhsReal =
- arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
- Value lhsRealTimesRhsImag =
- arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
- Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
- lhsRealTimesRhsImag, fmfValue);
+ Value real;
+ Value imag;
+ if (arith::bitEnumContainsAll(fmfValue, arith::FastMathFlags::contract)) {
+ Value lhsImagTimesRhsImag =
+ arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+ Value negLhsImagTimesRhsImag =
+ arith::NegFOp::create(b, lhsImagTimesRhsImag, fmfValue);
+ real = math::FmaOp::create(b, lhsReal, rhsReal,
+ negLhsImagTimesRhsImag, fmfValue);
+
+ Value lhsImagTimesRhsReal =
+ arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
+ imag =
+ math::FmaOp::create(b, lhsReal, rhsImag, lhsImagTimesRhsReal,
+ fmfValue);
+ } else {
+ Value lhsRealTimesRhsReal =
+ arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
+ Value lhsImagTimesRhsImag =
+ arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+ real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
+ lhsImagTimesRhsImag, fmfValue);
+ Value lhsImagTimesRhsReal =
+ arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
+ Value lhsRealTimesRhsImag =
+ arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
+ imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
+ lhsRealTimesRhsImag, fmfValue);
+ }
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
return success();
}
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index 4d2c12a56eaca..ccbe075d8afc5 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -245,13 +245,12 @@ func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
// CHECK: %[[RESULT_0:.*]] = llvm.mlir.poison : ![[C_TY]]
-// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL_TMP:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[NEG_REAL_TMP:.*]] = llvm.fneg %[[REAL_TMP]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL:.*]] = llvm.intr.fma(%[[LHS_RE]], %[[RHS_RE]], %[[NEG_REAL_TMP]]) {fastmathFlags = #llvm.fastmath<contract, afn>} : (f32, f32, f32) -> f32
-// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG_TMP:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG:.*]] = llvm.intr.fma(%[[LHS_RE]], %[[RHS_IM]], %[[IMAG_TMP]]) {fastmathFlags = #llvm.fastmath<contract, afn>} : (f32, f32, f32) -> f32
// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 7a82236b0656e..5f8838f8433b7 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -942,12 +942,11 @@ func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_TMP:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_REAL_TMP:.*]] = arith.negf %[[REAL_TMP]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL:.*]] = math.fma %[[LHS_REAL]], %[[RHS_REAL]], %[[NEG_REAL_TMP]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_TMP:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = math.fma %[[LHS_REAL]], %[[RHS_IMAG]], %[[IMAG_TMP]] fastmath<nnan,contract> : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
@@ -964,23 +963,21 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[VAR2:.*]] = complex.im %arg1 : complex<f32>
// CHECK: %[[VAR4:.*]] = complex.re %arg1 : complex<f32>
// CHECK: %[[VAR6:.*]] = complex.im %arg1 : complex<f32>
-// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR0]], %[[VAR4]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR10:.*]] = arith.mulf %[[VAR2]], %[[VAR6]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR12:.*]] = arith.subf %[[VAR8]], %[[VAR10]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR10:.*]] = arith.negf %[[VAR10]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR12:.*]] = math.fma %[[VAR0]], %[[VAR4]], %[[NEG_VAR10]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR13:.*]] = arith.mulf %[[VAR2]], %[[VAR4]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR15:.*]] = arith.mulf %[[VAR0]], %[[VAR6]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR17:.*]] = arith.addf %[[VAR13]], %[[VAR15]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR17:.*]] = math.fma %[[VAR0]], %[[VAR6]], %[[VAR13]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR89:.*]] = complex.create %[[VAR12]], %[[VAR17]] : complex<f32>
// CHECK: %[[VAR90:.*]] = complex.re %arg0 : complex<f32>
// CHECK: %[[VAR92:.*]] = complex.im %arg0 : complex<f32>
// CHECK: %[[VAR94:.*]] = complex.re %arg0 : complex<f32>
// CHECK: %[[VAR96:.*]] = complex.im %arg0 : complex<f32>
-// CHECK: %[[VAR98:.*]] = arith.mulf %[[VAR90]], %[[VAR94]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR100:.*]] = arith.mulf %[[VAR92]], %[[VAR96]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR102:.*]] = arith.subf %[[VAR98]], %[[VAR100]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR100:.*]] = arith.negf %[[VAR100]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR102:.*]] = math.fma %[[VAR90]], %[[VAR94]], %[[NEG_VAR100]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR103:.*]] = arith.mulf %[[VAR92]], %[[VAR94]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR105:.*]] = arith.mulf %[[VAR90]], %[[VAR96]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR107:.*]] = arith.addf %[[VAR103]], %[[VAR105]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR107:.*]] = math.fma %[[VAR90]], %[[VAR96]], %[[VAR103]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR179:.*]] = complex.create %[[VAR102]], %[[VAR107]] : complex<f32>
// CHECK: %[[VAR180:.*]] = complex.re %[[VAR89]] : complex<f32>
// CHECK: %[[VAR181:.*]] = complex.re %[[VAR179]] : complex<f32>
@@ -1043,12 +1040,11 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[VAR232:.*]] = complex.im %[[VAR229]] : complex<f32>
// CHECK: %[[VAR234:.*]] = complex.re %arg0 : complex<f32>
// CHECK: %[[VAR236:.*]] = complex.im %arg0 : complex<f32>
-// CHECK: %[[VAR238:.*]] = arith.mulf %[[VAR230]], %[[VAR234]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR240:.*]] = arith.mulf %[[VAR232]], %[[VAR236]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR242:.*]] = arith.subf %[[VAR238]], %[[VAR240]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR240:.*]] = arith.negf %[[VAR240]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR242:.*]] = math.fma %[[VAR230]], %[[VAR234]], %[[NEG_VAR240]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR243:.*]] = arith.mulf %[[VAR232]], %[[VAR234]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR245:.*]] = arith.mulf %[[VAR230]], %[[VAR236]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR247:.*]] = arith.addf %[[VAR243]], %[[VAR245]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR247:.*]] = math.fma %[[VAR230]], %[[VAR236]], %[[VAR243]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR319:.*]] = complex.create %[[VAR242]], %[[VAR247]] : complex<f32>
// CHECK: %[[VAR320:.*]] = complex.re %arg1 : complex<f32>
// CHECK: %[[VAR321:.*]] = complex.re %[[VAR319]] : complex<f32>
@@ -1174,12 +1170,11 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[VAR444:.*]] = complex.im %[[VAR441]] : complex<f32>
// CHECK: %[[VAR446:.*]] = complex.re %[[VAR440]] : complex<f32>
// CHECK: %[[VAR448:.*]] = complex.im %[[VAR440]] : complex<f32>
-// CHECK: %[[VAR450:.*]] = arith.mulf %[[VAR442]], %[[VAR446]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR452:.*]] = arith.mulf %[[VAR444]], %[[VAR448]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR454:.*]] = arith.subf %[[VAR450]], %[[VAR452]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR452:.*]] = arith.negf %[[VAR452]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR454:.*]] = math.fma %[[VAR442]], %[[VAR446]], %[[NEG_VAR452]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR455:.*]] = arith.mulf %[[VAR444]], %[[VAR446]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR457:.*]] = arith.mulf %[[VAR442]], %[[VAR448]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR459:.*]] = arith.addf %[[VAR455]], %[[VAR457]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR459:.*]] = math.fma %[[VAR442]], %[[VAR448]], %[[VAR455]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR531:.*]] = complex.create %[[VAR454]], %[[VAR459]] : complex<f32>
// CHECK: return %[[VAR531]] : complex<f32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/196248
More information about the Mlir-commits mailing list