[Mlir-commits] [mlir] ac00cb0 - [mlir] Add conversion from complex to standard dialect for EqualOp.

Adrian Kuegel llvmlistbot at llvm.org
Thu May 20 05:26:21 PDT 2021


Author: Adrian Kuegel
Date: 2021-05-20T14:25:56+02:00
New Revision: ac00cb0d2ad58914dd1cf52087ed29cd9834601a

URL: https://github.com/llvm/llvm-project/commit/ac00cb0d2ad58914dd1cf52087ed29cd9834601a
DIFF: https://github.com/llvm/llvm-project/commit/ac00cb0d2ad58914dd1cf52087ed29cd9834601a.diff

LOG: [mlir] Add conversion from complex to standard dialect for EqualOp.

This adds the straightforward conversion for EqualOp
(two complex numbers are equal if both their real parts and their imaginary parts are equal).

Differential Revision: https://reviews.llvm.org/D102840

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
    mlir/test/Conversion/ComplexToStandard/full-conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 15fa25441db62..2aa97a681d527 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -42,11 +42,40 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
     return success();
   }
 };
+
+/// Converts `complex.eq` into standard ops: the real parts and the imaginary
+/// parts are compared separately with `cmpf oeq` and the results are `and`ed.
+struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
+  using OpConversionPattern<complex::EqualOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::EqualOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::EqualOp::Adaptor adaptor(operands);
+    auto loc = op.getLoc();
+    Value lhs = adaptor.lhs();
+    Value rhs = adaptor.rhs();
+    // Scalar component type, e.g. f32 for complex<f32>.
+    auto elemType = lhs.getType().cast<ComplexType>().getElementType();
+
+    // Split both operands into real/imaginary parts (lhs first, then rhs).
+    Value lhsRe = rewriter.create<complex::ReOp>(loc, elemType, lhs);
+    Value lhsIm = rewriter.create<complex::ImOp>(loc, elemType, lhs);
+    Value rhsRe = rewriter.create<complex::ReOp>(loc, elemType, rhs);
+    Value rhsIm = rewriter.create<complex::ImOp>(loc, elemType, rhs);
+    // OEQ is an ordered predicate: a NaN component makes the compare false.
+    Value reEq = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsRe, rhsRe);
+    Value imEq = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, lhsIm, rhsIm);
+
+    rewriter.replaceOpWithNewOp<AndOp>(op, reEq, imEq);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<AbsOpConversion>(patterns.getContext());
+  patterns.add<EqualOpConversion, AbsOpConversion>(patterns.getContext());
 }
 
 namespace {
@@ -65,7 +94,7 @@ void ConvertComplexToStandardPass::runOnFunction() {
   ConversionTarget target(getContext());
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
-  target.addIllegalOp<complex::AbsOp>();
+  target.addIllegalOp<complex::AbsOp, complex::EqualOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 788d42557883b..1a43ea7df8612 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -14,3 +14,17 @@ func @complex_abs(%arg: complex<f32>) -> f32 {
 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
 // CHECK: return %[[NORM]] : f32
 
+// CHECK-LABEL: func @complex_eq
+// CHECK-SAME: %[[A:.*]]: complex<f32>, %[[B:.*]]: complex<f32>
+func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
+  %result = complex.eq %lhs, %rhs : complex<f32>
+  return %result : i1
+}
+// CHECK: %[[A_RE:.*]] = complex.re %[[A]] : complex<f32>
+// CHECK: %[[A_IM:.*]] = complex.im %[[A]] : complex<f32>
+// CHECK: %[[B_RE:.*]] = complex.re %[[B]] : complex<f32>
+// CHECK: %[[B_IM:.*]] = complex.im %[[B]] : complex<f32>
+// CHECK-DAG: %[[RE_EQ:.*]] = cmpf oeq, %[[A_RE]], %[[B_RE]] : f32
+// CHECK-DAG: %[[IM_EQ:.*]] = cmpf oeq, %[[A_IM]], %[[B_IM]] : f32
+// CHECK: %[[BOTH_EQ:.*]] = and %[[RE_EQ]], %[[IM_EQ]] : i1
+// CHECK: return %[[BOTH_EQ]] : i1

diff  --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 2fd46b4d02264..a453fcac2b2a1 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -14,3 +14,17 @@ func @complex_abs(%arg: complex<f32>) -> f32 {
 // CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32
 // CHECK: llvm.return %[[NORM]] : f32
 
+// CHECK-LABEL: llvm.func @complex_eq
+// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*]], %[[RHS:.*]]: ![[C_TY]])
+func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
+  %eq = complex.eq %lhs, %rhs: complex<f32>
+  return %eq : i1
+}
+// CHECK: %[[REAL_LHS:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[IMAG_LHS:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[REAL_RHS:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[IMAG_RHS:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK-DAG: %[[REAL_EQUAL:.*]] = llvm.fcmp "oeq" %[[REAL_LHS]], %[[REAL_RHS]]  : f32
+// CHECK-DAG: %[[IMAG_EQUAL:.*]] = llvm.fcmp "oeq" %[[IMAG_LHS]], %[[IMAG_RHS]]  : f32
+// CHECK: %[[EQUAL:.*]] = llvm.and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
+// CHECK: llvm.return %[[EQUAL]] : i1


        


More information about the Mlir-commits mailing list