[Mlir-commits] [mlir] eee71ed - [mlir][complex] Support Fastmath flag for complex.mulf (#74554)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 8 16:29:30 PST 2024


Author: Kai Sasaki
Date: 2024-01-09T09:29:27+09:00
New Revision: eee71ed3f7d0abe40f7c54166421421362a8ac46

URL: https://github.com/llvm/llvm-project/commit/eee71ed3f7d0abe40f7c54166421421362a8ac46
DIFF: https://github.com/llvm/llvm-project/commit/eee71ed3f7d0abe40f7c54166421421362a8ac46.diff

LOG: [mlir][complex] Support Fastmath flag for complex.mulf (#74554)

Support the fast math flags in the conversion of the `complex.mul` op to the
standard dialect by propagating them to the generated `arith` and `math` ops.

See:
https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index bf753c7062f366..4c9dad9e2c1731 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -569,29 +569,39 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto type = cast<ComplexType>(adaptor.getLhs().getType());
     auto elementType = cast<FloatType>(type.getElementType());
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    auto fmfValue = fmf.getValue();
 
     Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
-    Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal);
+    Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
     Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
-    Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag);
+    Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
     Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
-    Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal);
+    Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
     Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
-    Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag);
-
-    Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
-    Value lhsRealTimesRhsRealAbs = b.create<math::AbsFOp>(lhsRealTimesRhsReal);
-    Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
-    Value lhsImagTimesRhsImagAbs = b.create<math::AbsFOp>(lhsImagTimesRhsImag);
-    Value real =
-        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
-
-    Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
-    Value lhsImagTimesRhsRealAbs = b.create<math::AbsFOp>(lhsImagTimesRhsReal);
-    Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
-    Value lhsRealTimesRhsImagAbs = b.create<math::AbsFOp>(lhsRealTimesRhsImag);
-    Value imag =
-        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
+    Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
+
+    Value lhsRealTimesRhsReal =
+        b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
+    Value lhsRealTimesRhsRealAbs =
+        b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
+    Value lhsImagTimesRhsImag =
+        b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
+    Value lhsImagTimesRhsImagAbs =
+        b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
+    Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
+                                         lhsImagTimesRhsImag, fmfValue);
+
+    Value lhsImagTimesRhsReal =
+        b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
+    Value lhsImagTimesRhsRealAbs =
+        b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
+    Value lhsRealTimesRhsImag =
+        b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
+    Value lhsRealTimesRhsImagAbs =
+        b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
+    Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
+                                         lhsRealTimesRhsImag, fmfValue);
 
     // Handle cases where the "naive" calculation results in NaN values.
     Value realIsNan =
@@ -717,20 +727,20 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
     recalc = b.create<arith::AndIOp>(isNan, recalc);
 
     // Recalculate real part.
-    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
-    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
-    Value newReal =
-        b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
+    lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
+    lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
+    Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
+                                            lhsImagTimesRhsImag, fmfValue);
     real = b.create<arith::SelectOp>(
-        recalc, b.create<arith::MulFOp>(inf, newReal), real);
+        recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
 
     // Recalculate imag part.
-    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
-    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
-    Value newImag =
-        b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
+    lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
+    lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
+    Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
+                                            lhsRealTimesRhsImag, fmfValue);
     imag = b.create<arith::SelectOp>(
-        recalc, b.create<arith::MulFOp>(inf, newImag), imag);
+        recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
     return success();

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 3af28150fd5c3f..8fa29ea43854a4 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -845,3 +845,123 @@ func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_mul_with_fmf
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func.func @complex_mul_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %mul = complex.mul %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+  return %mul : complex<f32>
+}
+// CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_REAL_ABS:.*]] = math.absf %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_IMAG_ABS:.*]] = math.absf %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG]], %[[RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = math.absf %[[LHS_IMAG_TIMES_RHS_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL]], %[[RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_ABS:.*]] = math.absf %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+
+// Handle cases where the "naive" calculation results in NaN values.
+// CHECK: %[[REAL_IS_NAN:.*]] = arith.cmpf uno, %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.andi %[[REAL_IS_NAN]], %[[IMAG_IS_NAN]] : i1
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+
+// Case 1. LHS_REAL or LHS_IMAG are infinite.
+// CHECK: %[[LHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IS_INF:.*]] = arith.ori %[[LHS_REAL_IS_INF]], %[[LHS_IMAG_IS_INF]] : i1
+// CHECK:  %[[RHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_REAL]], %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[RHS_IMAG]], %[[RHS_IMAG]] : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_REAL_IS_INF_FLOAT]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_REAL1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = math.copysign %[[LHS_IMAG_IS_INF_FLOAT]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_REAL1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32
+// CHECK: %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[LHS_IS_INF]], %[[RHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG]] : f32
+// CHECK: %[[RHS_IMAG1:.*]] = arith.select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32
+
+// Case 2. RHS_REAL or RHS_IMAG are infinite.
+// CHECK: %[[RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IS_INF:.*]] = arith.ori %[[RHS_REAL_IS_INF]], %[[RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_REAL_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_REAL1]], %[[LHS_REAL1]] : f32
+// CHECK: %[[LHS_IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[LHS_IMAG1]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = arith.select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_REAL_IS_INF_FLOAT]], %[[RHS_REAL1]] : f32
+// CHECK: %[[RHS_REAL2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32
+// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = arith.select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = math.copysign %[[RHS_IMAG_IS_INF_FLOAT]], %[[RHS_IMAG1]] : f32
+// CHECK: %[[RHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32
+// CHECK: %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL1]] : f32
+// CHECK: %[[LHS_REAL2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32
+// CHECK: %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[RHS_IS_INF]], %[[LHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[LHS_IMAG2:.*]] = arith.select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[RECALC:.*]] = arith.ori %[[LHS_IS_INF]], %[[RHS_IS_INF]] : i1
+
+// Case 3. One of the pairwise products of left hand side with right hand side
+// is infinite.
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE:.*]] = arith.ori %[[LHS_REAL_TIMES_RHS_REAL_IS_INF]], %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_REAL_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE1:.*]] = arith.ori %[[IS_SPECIAL_CASE]], %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF:.*]] = arith.cmpf oeq, %[[LHS_IMAG_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE2:.*]] = arith.ori %[[IS_SPECIAL_CASE1]], %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF]] : i1
+// CHECK: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[NOT_RECALC:.*]] = arith.xori %[[RECALC]], %[[TRUE]] : i1
+// CHECK: %[[IS_SPECIAL_CASE3:.*]] = arith.andi %[[IS_SPECIAL_CASE2]], %[[NOT_RECALC]] : i1
+// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_REAL2]] : f32
+// CHECK: %[[LHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[LHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[LHS_IMAG2]] : f32
+// CHECK: %[[LHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_REAL2]] : f32
+// CHECK: %[[RHS_REAL3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN:.*]] = arith.andi %[[IS_SPECIAL_CASE3]], %[[RHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = math.copysign %[[ZERO]], %[[RHS_IMAG2]] : f32
+// CHECK: %[[RHS_IMAG3:.*]] = arith.select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32
+// CHECK: %[[RECALC2:.*]] = arith.ori %[[RECALC]], %[[IS_SPECIAL_CASE3]] : i1
+// CHECK: %[[RECALC3:.*]] = arith.andi %[[IS_NAN]], %[[RECALC2]] : i1
+
+// Recalculate real part.
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_REAL3]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_IMAG3]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEW_REAL:.*]] = arith.subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEW_REAL_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[FINAL_REAL:.*]] = arith.select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32
+
+// Recalculate imag part.
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = arith.mulf %[[LHS_IMAG3]], %[[RHS_REAL3]] fastmath<nnan,contract> : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = arith.mulf %[[LHS_REAL3]], %[[RHS_IMAG3]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEW_IMAG:.*]] = arith.addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEW_IMAG_TIMES_INF:.*]] = arith.mulf %[[INF]], %[[NEW_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[FINAL_IMAG:.*]] = arith.select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32
+
+// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
\ No newline at end of file


        


More information about the Mlir-commits mailing list