[Mlir-commits] [mlir] [mlir][complex] Emit fma for contracted complex.mul lowering (PR #196248)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 6 23:45:51 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: hanbeom (ParkHanbum)

<details>
<summary>Changes</summary>

When complex.mul carries the fastmath<contract> flag, lower it using explicit fused multiply-add (FMA) operations for the real and imaginary components.

The lowering changes from:

  real = ar * br - ai * bi
  imag = ai * br + ar * bi

expressed as mul/sub/add, to:

  real = fma(ar, br, -(ai * bi))
  imag = fma(ar, bi,  ai * br)

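This rewrite is applied only when contraction is allowed. Without the `contract` flag, complex.mul still lowers to separate multiply, subtract, and add operations (arith.mulf/subf/addf in ComplexToStandard, llvm.fmul/fsub/fadd in ComplexToLLVM).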

Fixes: #196246

---
Full diff: https://github.com/llvm/llvm-project/pull/196248.diff


4 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp (+24-7) 
- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+29-12) 
- (modified) mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir (+5-6) 
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+17-22) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index cdcb3cba55752..327a6678f9aed 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -279,13 +279,30 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
     Value lhsRe = arg.lhs.real();
     Value lhsIm = arg.lhs.imag();
 
-    Value real = LLVM::FSubOp::create(
-        rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
-        LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
-
-    Value imag = LLVM::FAddOp::create(
-        rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
-        LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
+    Value real;
+    Value imag;
+    if (arith::bitEnumContainsAll(complexFMFAttr.getValue(),
+                                  arith::FastMathFlags::contract)) {
+      Value lhsImagTimesRhsImag =
+          LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf);
+      Value negLhsImagTimesRhsImag =
+          LLVM::FNegOp::create(rewriter, loc, lhsImagTimesRhsImag, fmf);
+      real = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsRe,
+                                 negLhsImagTimesRhsImag, fmf);
+
+      Value lhsImagTimesRhsReal =
+          LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf);
+      imag = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsIm,
+                                 lhsImagTimesRhsReal, fmf);
+    } else {
+      real = LLVM::FSubOp::create(
+          rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
+          LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
+
+      imag = LLVM::FAddOp::create(
+          rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
+          LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
+    }
 
     result.setReal(rewriter, loc, real);
     result.setImaginary(rewriter, loc, imag);
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..4dcc9a2c23d77 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -549,18 +549,35 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
     Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
     Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
     Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
-    Value lhsRealTimesRhsReal =
-        arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
-    Value lhsImagTimesRhsImag =
-        arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
-    Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
-                                       lhsImagTimesRhsImag, fmfValue);
-    Value lhsImagTimesRhsReal =
-        arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
-    Value lhsRealTimesRhsImag =
-        arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
-    Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
-                                       lhsRealTimesRhsImag, fmfValue);
+    Value real;
+    Value imag;
+    if (arith::bitEnumContainsAll(fmfValue, arith::FastMathFlags::contract)) {
+      Value lhsImagTimesRhsImag =
+          arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+      Value negLhsImagTimesRhsImag =
+          arith::NegFOp::create(b, lhsImagTimesRhsImag, fmfValue);
+      real = math::FmaOp::create(b, lhsReal, rhsReal,
+                                 negLhsImagTimesRhsImag, fmfValue);
+
+      Value lhsImagTimesRhsReal =
+          arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
+      imag =
+          math::FmaOp::create(b, lhsReal, rhsImag, lhsImagTimesRhsReal,
+                              fmfValue);
+    } else {
+      Value lhsRealTimesRhsReal =
+          arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
+      Value lhsImagTimesRhsImag =
+          arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
+      real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
+                                   lhsImagTimesRhsImag, fmfValue);
+      Value lhsImagTimesRhsReal =
+          arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
+      Value lhsRealTimesRhsImag =
+          arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
+      imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
+                                   lhsRealTimesRhsImag, fmfValue);
+    }
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
     return success();
   }
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index 4d2c12a56eaca..ccbe075d8afc5 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -245,13 +245,12 @@ func.func @complex_div_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
 // CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]]
 // CHECK: %[[RESULT_0:.*]] = llvm.mlir.poison : ![[C_TY]]
 
-// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL_TMP:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[NEG_REAL_TMP:.*]] = llvm.fneg %[[REAL_TMP]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[REAL:.*]] = llvm.intr.fma(%[[LHS_RE]], %[[RHS_RE]], %[[NEG_REAL_TMP]]) {fastmathFlags = #llvm.fastmath<contract, afn>} : (f32, f32, f32) -> f32
 
-// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
-// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG_TMP:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath<contract, afn>} : f32
+// CHECK: %[[IMAG:.*]] = llvm.intr.fma(%[[LHS_RE]], %[[RHS_IM]], %[[IMAG_TMP]]) {fastmathFlags = #llvm.fastmath<contract, afn>} : (f32, f32, f32) -> f32
 
 // CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0]
 // CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1]
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 7a82236b0656e..5f8838f8433b7 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -942,12 +942,11 @@ func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> compl
 // CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
 // CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
 // CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
-// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_TMP:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_REAL_TMP:.*]] = arith.negf %[[REAL_TMP]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL:.*]] = math.fma %[[LHS_REAL]], %[[RHS_REAL]], %[[NEG_REAL_TMP]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_TMP:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = math.fma %[[LHS_REAL]], %[[RHS_IMAG]], %[[IMAG_TMP]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
@@ -964,23 +963,21 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[VAR2:.*]] = complex.im %arg1 : complex<f32>
 // CHECK: %[[VAR4:.*]] = complex.re %arg1 : complex<f32>
 // CHECK: %[[VAR6:.*]] = complex.im %arg1 : complex<f32>
-// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR0]], %[[VAR4]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR10:.*]] = arith.mulf %[[VAR2]], %[[VAR6]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR12:.*]] = arith.subf %[[VAR8]], %[[VAR10]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR10:.*]] = arith.negf %[[VAR10]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR12:.*]] = math.fma %[[VAR0]], %[[VAR4]], %[[NEG_VAR10]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR13:.*]] = arith.mulf %[[VAR2]], %[[VAR4]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR15:.*]] = arith.mulf %[[VAR0]], %[[VAR6]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR17:.*]] = arith.addf %[[VAR13]], %[[VAR15]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR17:.*]] = math.fma %[[VAR0]], %[[VAR6]], %[[VAR13]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR89:.*]] = complex.create %[[VAR12]], %[[VAR17]] : complex<f32>
 // CHECK: %[[VAR90:.*]] = complex.re %arg0 : complex<f32>
 // CHECK: %[[VAR92:.*]] = complex.im %arg0 : complex<f32>
 // CHECK: %[[VAR94:.*]] = complex.re %arg0 : complex<f32>
 // CHECK: %[[VAR96:.*]] = complex.im %arg0 : complex<f32>
-// CHECK: %[[VAR98:.*]] = arith.mulf %[[VAR90]], %[[VAR94]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR100:.*]] = arith.mulf %[[VAR92]], %[[VAR96]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR102:.*]] = arith.subf %[[VAR98]], %[[VAR100]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR100:.*]] = arith.negf %[[VAR100]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR102:.*]] = math.fma %[[VAR90]], %[[VAR94]], %[[NEG_VAR100]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR103:.*]] = arith.mulf %[[VAR92]], %[[VAR94]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR105:.*]] = arith.mulf %[[VAR90]], %[[VAR96]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR107:.*]] = arith.addf %[[VAR103]], %[[VAR105]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR107:.*]] = math.fma %[[VAR90]], %[[VAR96]], %[[VAR103]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR179:.*]] = complex.create %[[VAR102]], %[[VAR107]] : complex<f32>
 // CHECK: %[[VAR180:.*]] = complex.re %[[VAR89]] : complex<f32>
 // CHECK: %[[VAR181:.*]] = complex.re %[[VAR179]] : complex<f32>
@@ -1043,12 +1040,11 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[VAR232:.*]] = complex.im %[[VAR229]] : complex<f32>
 // CHECK: %[[VAR234:.*]] = complex.re %arg0 : complex<f32>
 // CHECK: %[[VAR236:.*]] = complex.im %arg0 : complex<f32>
-// CHECK: %[[VAR238:.*]] = arith.mulf %[[VAR230]], %[[VAR234]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR240:.*]] = arith.mulf %[[VAR232]], %[[VAR236]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR242:.*]] = arith.subf %[[VAR238]], %[[VAR240]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR240:.*]] = arith.negf %[[VAR240]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR242:.*]] = math.fma %[[VAR230]], %[[VAR234]], %[[NEG_VAR240]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR243:.*]] = arith.mulf %[[VAR232]], %[[VAR234]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR245:.*]] = arith.mulf %[[VAR230]], %[[VAR236]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR247:.*]] = arith.addf %[[VAR243]], %[[VAR245]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR247:.*]] = math.fma %[[VAR230]], %[[VAR236]], %[[VAR243]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR319:.*]] = complex.create %[[VAR242]], %[[VAR247]] : complex<f32>
 // CHECK: %[[VAR320:.*]] = complex.re %arg1 : complex<f32>
 // CHECK: %[[VAR321:.*]] = complex.re %[[VAR319]] : complex<f32>
@@ -1174,12 +1170,11 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[VAR444:.*]] = complex.im %[[VAR441]] : complex<f32>
 // CHECK: %[[VAR446:.*]] = complex.re %[[VAR440]] : complex<f32>
 // CHECK: %[[VAR448:.*]] = complex.im %[[VAR440]] : complex<f32>
-// CHECK: %[[VAR450:.*]] = arith.mulf %[[VAR442]], %[[VAR446]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR452:.*]] = arith.mulf %[[VAR444]], %[[VAR448]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR454:.*]] = arith.subf %[[VAR450]], %[[VAR452]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_VAR452:.*]] = arith.negf %[[VAR452]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR454:.*]] = math.fma %[[VAR442]], %[[VAR446]], %[[NEG_VAR452]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR455:.*]] = arith.mulf %[[VAR444]], %[[VAR446]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR457:.*]] = arith.mulf %[[VAR442]], %[[VAR448]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR459:.*]] = arith.addf %[[VAR455]], %[[VAR457]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAR459:.*]] = math.fma %[[VAR442]], %[[VAR448]], %[[VAR455]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR531:.*]] = complex.create %[[VAR454]], %[[VAR459]] : complex<f32>
 // CHECK: return %[[VAR531]] : complex<f32>
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/196248


More information about the Mlir-commits mailing list