[Mlir-commits] [mlir] fb8b2b8 - [mlir] Add conversion from Complex to Standard dialect for NotEqualOp.

Adrian Kuegel llvmlistbot at llvm.org
Fri May 21 01:47:09 PDT 2021


Author: Adrian Kuegel
Date: 2021-05-21T10:46:50+02:00
New Revision: fb8b2b86d3d1ba6e26a5d9296e2b235eb36d10b8

URL: https://github.com/llvm/llvm-project/commit/fb8b2b86d3d1ba6e26a5d9296e2b235eb36d10b8
DIFF: https://github.com/llvm/llvm-project/commit/fb8b2b86d3d1ba6e26a5d9296e2b235eb36d10b8.diff

LOG: [mlir] Add conversion from Complex to Standard dialect for NotEqualOp.

Differential Revision: https://reviews.llvm.org/D102902

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
    mlir/test/Conversion/ComplexToStandard/full-conversion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 2aa97a681d527..5b1765407e690 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
 
 #include <memory>
+#include <type_traits>
 
 #include "../PassDetail.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
@@ -43,16 +44,22 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   }
 };
 
-struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
-  using OpConversionPattern<complex::EqualOp>::OpConversionPattern;
+template <typename ComparisonOp, CmpFPredicate p>
+struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
+  using OpConversionPattern<ComparisonOp>::OpConversionPattern;
+  using ResultCombiner =
+      std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
+                         AndOp, OrOp>;
 
   LogicalResult
-  matchAndRewrite(complex::EqualOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::EqualOp::Adaptor transformed(operands);
+    typename ComparisonOp::Adaptor transformed(operands);
     auto loc = op.getLoc();
-    auto type =
-        transformed.lhs().getType().cast<ComplexType>().getElementType();
+    auto type = transformed.lhs()
+                    .getType()
+                    .template cast<ComplexType>()
+                    .getElementType();
 
     Value realLhs =
         rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
@@ -62,12 +69,11 @@ struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
         rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
     Value imagRhs =
         rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
-    Value realEqual =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, realLhs, realRhs);
-    Value imagEqual =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, imagLhs, imagRhs);
+    Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
+    Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
 
-    rewriter.replaceOpWithNewOp<AndOp>(op, realEqual, imagEqual);
+    rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
+                                                imagComparison);
     return success();
   }
 };
@@ -75,7 +81,10 @@ struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
 
 void mlir::populateComplexToStandardConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<AbsOpConversion, EqualOpConversion>(patterns.getContext());
+  patterns.add<AbsOpConversion,
+               ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
+               ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>>(
+      patterns.getContext());
 }
 
 namespace {
@@ -94,7 +103,7 @@ void ConvertComplexToStandardPass::runOnFunction() {
   ConversionTarget target(getContext());
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
-  target.addIllegalOp<complex::AbsOp, complex::EqualOp>();
+  target.addIllegalOp<complex::AbsOp, complex::EqualOp, complex::NotEqualOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 1a43ea7df8612..d09f1a0cb708d 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -28,3 +28,18 @@ func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK-DAG: %[[IMAG_EQUAL:.*]] = cmpf oeq, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
 // CHECK: %[[EQUAL:.*]] = and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
 // CHECK: return %[[EQUAL]] : i1
+
+// CHECK-LABEL: func @complex_neq
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
+func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
+  %neq = complex.neq %lhs, %rhs: complex<f32>
+  return %neq : i1
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK-DAG: %[[REAL_NOT_EQUAL:.*]] = cmpf une, %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = cmpf une, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[NOT_EQUAL:.*]] = or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
+// CHECK: return %[[NOT_EQUAL]] : i1

diff  --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index a453fcac2b2a1..6fa090674e6fc 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -28,3 +28,18 @@ func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK-DAG: %[[IMAG_EQUAL:.*]] = llvm.fcmp "oeq" %[[IMAG_LHS]], %[[IMAG_RHS]]  : f32
 // CHECK: %[[EQUAL:.*]] = llvm.and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
 // CHECK: llvm.return %[[EQUAL]] : i1
+
+// CHECK-LABEL: llvm.func @complex_neq
+// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*]], %[[RHS:.*]]: ![[C_TY:.*]])
+func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
+  %neq = complex.neq %lhs, %rhs: complex<f32>
+  return %neq : i1
+}
+// CHECK: %[[REAL_LHS:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[IMAG_LHS:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[REAL_RHS:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[IMAG_RHS:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK-DAG: %[[REAL_NOT_EQUAL:.*]] = llvm.fcmp "une" %[[REAL_LHS]], %[[REAL_RHS]]  : f32
+// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = llvm.fcmp "une" %[[IMAG_LHS]], %[[IMAG_RHS]]  : f32
+// CHECK: %[[NOT_EQUAL:.*]] = llvm.or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
+// CHECK: llvm.return %[[NOT_EQUAL]] : i1


        


More information about the Mlir-commits mailing list