[Mlir-commits] [mlir] fb8b2b8 - [mlir] Add conversion from Complex to Standard dialect for NotEqualOp.
Adrian Kuegel
llvmlistbot at llvm.org
Fri May 21 01:47:09 PDT 2021
Author: Adrian Kuegel
Date: 2021-05-21T10:46:50+02:00
New Revision: fb8b2b86d3d1ba6e26a5d9296e2b235eb36d10b8
URL: https://github.com/llvm/llvm-project/commit/fb8b2b86d3d1ba6e26a5d9296e2b235eb36d10b8
DIFF: https://github.com/llvm/llvm-project/commit/fb8b2b86d3d1ba6e26a5d9296e2b235eb36d10b8.diff
LOG: [mlir] Add conversion from Complex to Standard dialect for NotEqualOp.
Differential Revision: https://reviews.llvm.org/D102902
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 2aa97a681d527..5b1765407e690 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include <memory>
+#include <type_traits>
#include "../PassDetail.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
@@ -43,16 +44,31 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
}
};
-struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
- using OpConversionPattern<complex::EqualOp>::OpConversionPattern;
+/// Lowers a complex comparison op by comparing the real and imaginary parts
+/// separately with predicate `p` and merging the two i1 results: AndOp for
+/// complex::EqualOp, OrOp for complex::NotEqualOp.
+template <typename ComparisonOp, CmpFPredicate p>
+struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
+ using OpConversionPattern<ComparisonOp>::OpConversionPattern;
+ using ResultCombiner =
+ std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
+ AndOp, OrOp>;
+ // ResultCombiner is chosen from the op type alone; reject any other op
+ // rather than silently combining its component results with OrOp.
+ static_assert(std::is_same<ComparisonOp, complex::EqualOp>::value ||
+ std::is_same<ComparisonOp, complex::NotEqualOp>::value,
+ "ComparisonOpConversion only supports complex::EqualOp and "
+ "complex::NotEqualOp");
LogicalResult
- matchAndRewrite(complex::EqualOp op, ArrayRef<Value> operands,
+ matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- complex::EqualOp::Adaptor transformed(operands);
+ typename ComparisonOp::Adaptor transformed(operands);
auto loc = op.getLoc();
- auto type =
- transformed.lhs().getType().cast<ComplexType>().getElementType();
+ auto type = transformed.lhs()
+ .getType()
+ .template cast<ComplexType>()
+ .getElementType();
Value realLhs =
rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
@@ -62,12 +78,11 @@ struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
Value imagRhs =
rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
- Value realEqual =
- rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, realLhs, realRhs);
- Value imagEqual =
- rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, imagLhs, imagRhs);
+ Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
+ Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
- rewriter.replaceOpWithNewOp<AndOp>(op, realEqual, imagEqual);
+ rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
+ imagComparison);
return success();
}
};
@@ -75,7 +90,12 @@ struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
void mlir::populateComplexToStandardConversionPatterns(
RewritePatternSet &patterns) {
- patterns.add<AbsOpConversion, EqualOpConversion>(patterns.getContext());
+ // UNE is the exact complement of OEQ, so lowering neq as UNE/OrOp yields
+ // the negation of eq's OEQ/AndOp lowering, NaN components included.
+ patterns.add<AbsOpConversion,
+ ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
+ ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>>(
+ patterns.getContext());
}
namespace {
@@ -94,7 +114,7 @@ void ConvertComplexToStandardPass::runOnFunction() {
ConversionTarget target(getContext());
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
complex::ComplexDialect>();
- target.addIllegalOp<complex::AbsOp, complex::EqualOp>();
+ target.addIllegalOp<complex::AbsOp, complex::EqualOp, complex::NotEqualOp>();
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 1a43ea7df8612..d09f1a0cb708d 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -28,3 +28,18 @@ func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK-DAG: %[[IMAG_EQUAL:.*]] = cmpf oeq, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
// CHECK: %[[EQUAL:.*]] = and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
// CHECK: return %[[EQUAL]] : i1
+
+// CHECK-LABEL: func @complex_neq
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK-DAG: %[[REAL_NOT_EQUAL:.*]] = cmpf une, %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = cmpf une, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[NOT_EQUAL:.*]] = or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
+// CHECK: return %[[NOT_EQUAL]] : i1
+func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
+ %not_equal = complex.neq %lhs, %rhs : complex<f32>
+ return %not_equal : i1
+}
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index a453fcac2b2a1..6fa090674e6fc 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -28,3 +28,19 @@ func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK-DAG: %[[IMAG_EQUAL:.*]] = llvm.fcmp "oeq" %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
// CHECK: %[[EQUAL:.*]] = llvm.and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
// CHECK: llvm.return %[[EQUAL]] : i1
+
+// CHECK-LABEL: llvm.func @complex_neq
+// Both arguments must lower to the same struct type, so RHS reuses C_TY.
+// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*]], %[[RHS:.*]]: ![[C_TY]])
+func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
+ %neq = complex.neq %lhs, %rhs: complex<f32>
+ return %neq : i1
+}
+// CHECK: %[[REAL_LHS:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[IMAG_LHS:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[REAL_RHS:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[IMAG_RHS:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK-DAG: %[[REAL_NOT_EQUAL:.*]] = llvm.fcmp "une" %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = llvm.fcmp "une" %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[NOT_EQUAL:.*]] = llvm.or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
+// CHECK: llvm.return %[[NOT_EQUAL]] : i1
More information about the Mlir-commits
mailing list