[flang-commits] [flang] f1dfc02 - [fir] Add fir.cmpc conversion
Diana Picus via flang-commits
flang-commits at lists.llvm.org
Tue Nov 16 04:29:05 PST 2021
Author: Diana Picus
Date: 2021-11-16T12:26:27Z
New Revision: f1dfc0275c543db14b1723ee2e33f0910dcc7f24
URL: https://github.com/llvm/llvm-project/commit/f1dfc0275c543db14b1723ee2e33f0910dcc7f24
DIFF: https://github.com/llvm/llvm-project/commit/f1dfc0275c543db14b1723ee2e33f0910dcc7f24.diff
LOG: [fir] Add fir.cmpc conversion
This patch adds the codegen for fir.cmpc. The real and imaginary parts
are extracted and compared separately. For the .EQ. predicate the
results are AND'd, for the .NE. predicate the results are OR'd, and for
other predicates only the result of comparing the real parts is kept.
This patch is part of the upstreaming effort from fir-dev.
Differential Revision: https://reviews.llvm.org/D113976
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Co-authored-by: Jean Perier <jperier at nvidia.com>
Added:
Modified:
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/test/Fir/convert-to-llvm.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 43019f80935e..87827dcf2795 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -487,6 +487,52 @@ static mlir::Type getComplexEleTy(mlir::Type complex) {
return complex.cast<fir::ComplexType>().getElementType();
}
+/// Lower fir.cmpc (complex comparison) to LLVM dialect operations.
+/// Real and imaginary parts are extracted and compared with llvm.fcmp.
+/// Per 10.1 of the Fortran standard, the only comparisons available for
+/// complex operands are .EQ. (oeq, results AND'd) and .NE. (une, OR'd).
+/// For completeness, all other comparisons use the real components only.
+struct CmpcOpConversion : public FIROpConversion<fir::CmpcOp> {
+ using FIROpConversion::FIROpConversion;
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::CmpcOp cmp, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ mlir::ValueRange operands = adaptor.getOperands(); // [0] = lhs, [1] = rhs
+ mlir::MLIRContext *ctxt = cmp.getContext();
+ mlir::Type eleTy = convertType(getComplexEleTy(cmp.lhs().getType()));
+ mlir::Type resTy = convertType(cmp.getType());
+ mlir::Location loc = cmp.getLoc();
+ auto pos0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0)); // real part index
+ SmallVector<mlir::Value, 2> rp{rewriter.create<mlir::LLVM::ExtractValueOp>(
+ loc, eleTy, operands[0], pos0),
+ rewriter.create<mlir::LLVM::ExtractValueOp>(
+ loc, eleTy, operands[1], pos0)};
+ auto rcp =
+ rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, rp, cmp->getAttrs()); // NOTE(review): forwards all of cmp's attributes (incl. the predicate) to llvm.fcmp; assumes they are compatible with fcmp's - confirm
+ auto pos1 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(1)); // imaginary part index
+ SmallVector<mlir::Value, 2> ip{rewriter.create<mlir::LLVM::ExtractValueOp>(
+ loc, eleTy, operands[0], pos1),
+ rewriter.create<mlir::LLVM::ExtractValueOp>(
+ loc, eleTy, operands[1], pos1)};
+ auto icp =
+ rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, ip, cmp->getAttrs()); // same predicate, imaginary parts
+ SmallVector<mlir::Value, 2> cp{rcp, icp}; // operands of the combining and/or
+ switch (cmp.getPredicate()) {
+ case mlir::arith::CmpFPredicate::OEQ: // .EQ.
+ rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmp, resTy, cp);
+ break;
+ case mlir::arith::CmpFPredicate::UNE: // .NE.
+ rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmp, resTy, cp);
+ break;
+ default: // real-part result only; icp (and its extracts) are left unused
+ rewriter.replaceOp(cmp, rcp.getResult());
+ break;
+ }
+ return success();
+ }
+};
+
/// convert value of from-type to value of to-type
struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
using FIROpConversion::FIROpConversion;
@@ -1514,15 +1560,17 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
AllocaOpConversion, BoxAddrOpConversion, BoxDimsOpConversion,
BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion,
BoxIsPtrOpConversion, BoxRankOpConversion, CallOpConversion,
- ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion,
- DTEntryOpConversion, DivcOpConversion, EmboxCharOpConversion,
- ExtractValueOpConversion, HasValueOpConversion, GlobalLenOpConversion,
- GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
- IsPresentOpConversion, LoadOpConversion, NegcOpConversion,
- MulcOpConversion, SelectCaseOpConversion, SelectOpConversion,
- SelectRankOpConversion, SelectTypeOpConversion, StoreOpConversion,
- SubcOpConversion, UnboxCharOpConversion, UndefOpConversion,
- UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+ CmpcOpConversion, ConvertOpConversion, DispatchOpConversion,
+ DispatchTableOpConversion, DTEntryOpConversion, DivcOpConversion,
+ EmboxCharOpConversion, ExtractValueOpConversion, HasValueOpConversion,
+ GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
+ InsertValueOpConversion, IsPresentOpConversion, LoadOpConversion,
+ NegcOpConversion, MulcOpConversion, SelectCaseOpConversion,
+ SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
+ StoreOpConversion, SubcOpConversion, UnboxCharOpConversion,
+ UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
+ typeConverter);
+
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index a804825eb75f..7cfd73c01bf4 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -521,6 +521,57 @@ func @fir_complex_neg(%a: !fir.complex<16>) -> !fir.complex<16> {
// -----
+// Test fir.cmpc conversion: oeq -> AND of part-wise fcmps, une -> OR of them, any other predicate (e.g. ogt) -> real-part fcmp only.
+
+func @compare_complex_eq(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+ %r = fir.cmpc "oeq", %a, %b : !fir.complex<8>
+ return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_eq
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "oeq" [[RA]], [[RB]] : f64
+// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "oeq" [[IA]], [[IB]] : f64
+// CHECK: [[RES:%.*]] = llvm.and [[RESR]], [[RESI]] : i1
+// CHECK: return [[RES]] : i1
+
+func @compare_complex_ne(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+ %r = fir.cmpc "une", %a, %b : !fir.complex<8>
+ return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_ne
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "une" [[RA]], [[RB]] : f64
+// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "une" [[IA]], [[IB]] : f64
+// CHECK: [[RES:%.*]] = llvm.or [[RESR]], [[RESI]] : i1
+// CHECK: return [[RES]] : i1
+
+func @compare_complex_other(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+ %r = fir.cmpc "ogt", %a, %b : !fir.complex<8>
+ return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_other
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK: [[RESR:%.*]] = llvm.fcmp "ogt" [[RA]], [[RB]] : f64
+// CHECK: return [[RESR]] : i1
+
+// -----
+
// Test `fir.convert` operation conversion from Float type.
func @convert_from_float(%arg0 : f32) {
More information about the flang-commits
mailing list