[Mlir-commits] [mlir] fb978f0 - [mlir][Complex]: Add lowerings for AddOp and SubOp from Complex dialect to Standard.

Adrian Kuegel llvmlistbot at llvm.org
Fri Jul 23 03:44:10 PDT 2021


Author: Adrian Kuegel
Date: 2021-07-23T12:43:45+02:00
New Revision: fb978f092c9c1eff56906c65123944140c89f9cd

URL: https://github.com/llvm/llvm-project/commit/fb978f092c9c1eff56906c65123944140c89f9cd
DIFF: https://github.com/llvm/llvm-project/commit/fb978f092c9c1eff56906c65123944140c89f9cd.diff

LOG: [mlir][Complex]: Add lowerings for AddOp and SubOp from Complex dialect to
Standard.

Differential Revision: https://reviews.llvm.org/D106429

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4d3d52213e55..f651eedd77f1 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -79,6 +79,35 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
   }
 };
 
+// Lowers a component-wise binary complex op (e.g. complex::AddOp/SubOp) by
+// applying BinaryStandardOp to the real parts and to the imaginary parts, then
+// recombining them with complex::CreateOp. Element type must be a FloatType.
+template <typename BinaryComplexOp, typename BinaryStandardOp>
+struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
+  using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BinaryComplexOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    typename BinaryComplexOp::Adaptor transformed(operands);
+    auto type = transformed.lhs().getType().template cast<ComplexType>();
+    auto elementType = type.getElementType().template cast<FloatType>();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    // Real part first, then imaginary part; the lit tests check this order.
+    Value realLhs = b.create<complex::ReOp>(elementType, transformed.lhs());
+    Value realRhs = b.create<complex::ReOp>(elementType, transformed.rhs());
+    Value resultReal =
+        b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
+    Value imagLhs = b.create<complex::ImOp>(elementType, transformed.lhs());
+    Value imagRhs = b.create<complex::ImOp>(elementType, transformed.rhs());
+    Value resultImag =
+        b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success();
+  }
+};
+
 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
   using OpConversionPattern<complex::DivOp>::OpConversionPattern;
 
@@ -554,6 +583,8 @@ void mlir::populateComplexToStandardConversionPatterns(
       AbsOpConversion,
       ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
       ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
+      BinaryComplexOpConversion<complex::AddOp, AddFOp>,
+      BinaryComplexOpConversion<complex::SubOp, SubFOp>,
       DivOpConversion,
       ExpOpConversion,
       LogOpConversion,
@@ -578,12 +609,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
   populateComplexToStandardConversionPatterns(patterns);
 
   ConversionTarget target(getContext());
-  target.addLegalDialect<StandardOpsDialect, math::MathDialect,
-                         complex::ComplexDialect>();
-  target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::LogOp, complex::Log1pOp,
-                      complex::MulOp, complex::NegOp, complex::NotEqualOp,
-                      complex::SignOp>();
+  target.addLegalDialect<StandardOpsDialect, math::MathDialect>();
+  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 765d79c0bb8c..9d9593aed73c 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -14,6 +14,21 @@ func @complex_abs(%arg: complex<f32>) -> f32 {
 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
 // CHECK: return %[[NORM]] : f32
 
+// CHECK-LABEL: func @complex_add
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %add = complex.add %lhs, %rhs: complex<f32>  // expected to lower to one addf per component
+  return %add : complex<f32>
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_REAL:.*]] = addf %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = addf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_div
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -366,3 +381,18 @@ func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
 // CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
+
+// CHECK-LABEL: func @complex_sub
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %sub = complex.sub %lhs, %rhs: complex<f32>  // expected to lower to one subf per component
+  return %sub : complex<f32>
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_REAL:.*]] = subf %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>


        


More information about the Mlir-commits mailing list