[Mlir-commits] [mlir] be5b666 - [mlir][complex] Support fastmath in the binary op conversion. (#65702)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 10 23:57:28 PDT 2023


Author: Kai Sasaki
Date: 2023-09-11T15:57:24+09:00
New Revision: be5b66670d869f5f9401feb71c70f2ca519ce9e0

URL: https://github.com/llvm/llvm-project/commit/be5b66670d869f5f9401feb71c70f2ca519ce9e0
DIFF: https://github.com/llvm/llvm-project/commit/be5b66670d869f5f9401feb71c70f2ca519ce9e0.diff

LOG: [mlir][complex] Support fastmath in the binary op conversion. (#65702)

Complex dialect arithmetic operations are now able to carry fastmath
flags. This PR makes the complex-to-standard conversion propagate the
fastmath flags of a binary complex op to the arith dialect ops it generates.

See:

https://discourse.llvm.org/t/rfc-fastmath-flags-support-in-complex-dialect/71981

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 2bcec4ea10f92c5..28e490da330f3c3 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -137,15 +137,16 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
     auto type = cast<ComplexType>(adaptor.getLhs().getType());
     auto elementType = cast<FloatType>(type.getElementType());
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
     Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
     Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
-    Value resultReal =
-        b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
+    Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
+                                                  fmf.getValue());
     Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
     Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
-    Value resultImag =
-        b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
+    Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
+                                                  fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index bc2ea0dd7a5847a..9b2eef82541952a 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -723,3 +723,37 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
 // CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
 // CHECK: return %[[NORM]] : f32
+
+// -----
+
+// CHECK-LABEL: func @complex_add_with_fmf
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func.func @complex_add_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %add = complex.add %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+  return %add : complex<f32>
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_REAL:.*]] = arith.addf %[[REAL_LHS]], %[[REAL_RHS]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = arith.addf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sub_with_fmf
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func.func @complex_sub_with_fmf(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %sub = complex.sub %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+  return %sub : complex<f32>
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_REAL:.*]] = arith.subf %[[REAL_LHS]], %[[REAL_RHS]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>


        


More information about the Mlir-commits mailing list