[flang-commits] [flang] f1dfc02 - [fir] Add fir.cmpc conversion

Diana Picus via flang-commits flang-commits at lists.llvm.org
Tue Nov 16 04:29:05 PST 2021


Author: Diana Picus
Date: 2021-11-16T12:26:27Z
New Revision: f1dfc0275c543db14b1723ee2e33f0910dcc7f24

URL: https://github.com/llvm/llvm-project/commit/f1dfc0275c543db14b1723ee2e33f0910dcc7f24
DIFF: https://github.com/llvm/llvm-project/commit/f1dfc0275c543db14b1723ee2e33f0910dcc7f24.diff

LOG: [fir] Add fir.cmpc conversion

This patch adds the codegen for fir.cmpc. The real and imaginary parts
are extracted and compared separately. For the .EQ. predicate the two
results are AND'd; for the .NE. predicate they are OR'd; for any other
predicate only the result of comparing the real parts is kept.

This patch is part of the upstreaming effort from fir-dev.

Differential Revision: https://reviews.llvm.org/D113976

Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Co-authored-by: Jean Perier <jperier at nvidia.com>

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/test/Fir/convert-to-llvm.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 43019f80935e..87827dcf2795 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -487,6 +487,52 @@ static mlir::Type getComplexEleTy(mlir::Type complex) {
   return complex.cast<fir::ComplexType>().getElementType();
 }
 
+/// Compare complex values: lower fir.cmpc to per-component llvm.fcmp ops.
+///
+/// Per F2018 10.1, the only complex comparisons are .EQ. (oeq) and .NE. (une);
+/// .EQ. ANDs the real/imaginary fcmp results and .NE. ORs them.
+/// For completeness, all other predicates compare the real components only.
+struct CmpcOpConversion : public FIROpConversion<fir::CmpcOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CmpcOp cmp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::ValueRange operands = adaptor.getOperands(); // type-converted lhs, rhs
+    mlir::MLIRContext *ctxt = cmp.getContext();
+    mlir::Type eleTy = convertType(getComplexEleTy(cmp.lhs().getType())); // LLVM type of one component
+    mlir::Type resTy = convertType(cmp.getType()); // converted result type (i1 in the tests)
+    mlir::Location loc = cmp.getLoc();
+    auto pos0 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(0)); // real part
+    SmallVector<mlir::Value, 2> rp{rewriter.create<mlir::LLVM::ExtractValueOp>(
+                                       loc, eleTy, operands[0], pos0),
+                                   rewriter.create<mlir::LLVM::ExtractValueOp>(
+                                       loc, eleTy, operands[1], pos0)};
+    auto rcp =
+        rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, rp, cmp->getAttrs()); // forwards every fir.cmpc attr, incl. the predicate. NOTE(review): assumes the arith::CmpFPredicate attr is accepted as the llvm.fcmp predicate (matching encodings) -- confirm
+    auto pos1 = mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(1)); // imaginary part
+    SmallVector<mlir::Value, 2> ip{rewriter.create<mlir::LLVM::ExtractValueOp>(
+                                       loc, eleTy, operands[0], pos1),
+                                   rewriter.create<mlir::LLVM::ExtractValueOp>(
+                                       loc, eleTy, operands[1], pos1)};
+    auto icp =
+        rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, ip, cmp->getAttrs()); // same predicate as the real-part compare
+    SmallVector<mlir::Value, 2> cp{rcp, icp}; // operands for the and/or below
+    switch (cmp.getPredicate()) { // combine the per-component results
+    case mlir::arith::CmpFPredicate::OEQ: // .EQ.
+      rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmp, resTy, cp);
+      break;
+    case mlir::arith::CmpFPredicate::UNE: // .NE.
+      rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmp, resTy, cp);
+      break;
+    default: // other predicates: real part only. NOTE(review): the imaginary extracts and icp above stay unused on this path; presumably erased by a later DCE -- confirm
+      rewriter.replaceOp(cmp, rcp.getResult());
+      break;
+    }
+    return success();
+  }
+};
+
 /// convert value of from-type to value of to-type
 struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
   using FIROpConversion::FIROpConversion;
@@ -1514,15 +1560,17 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
         AllocaOpConversion, BoxAddrOpConversion, BoxDimsOpConversion,
         BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion,
         BoxIsPtrOpConversion, BoxRankOpConversion, CallOpConversion,
-        ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion,
-        DTEntryOpConversion, DivcOpConversion, EmboxCharOpConversion,
-        ExtractValueOpConversion, HasValueOpConversion, GlobalLenOpConversion,
-        GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
-        IsPresentOpConversion, LoadOpConversion, NegcOpConversion,
-        MulcOpConversion, SelectCaseOpConversion, SelectOpConversion,
-        SelectRankOpConversion, SelectTypeOpConversion, StoreOpConversion,
-        SubcOpConversion, UnboxCharOpConversion, UndefOpConversion,
-        UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+        CmpcOpConversion, ConvertOpConversion, DispatchOpConversion,
+        DispatchTableOpConversion, DTEntryOpConversion, DivcOpConversion,
+        EmboxCharOpConversion, ExtractValueOpConversion, HasValueOpConversion,
+        GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
+        InsertValueOpConversion, IsPresentOpConversion, LoadOpConversion,
+        NegcOpConversion, MulcOpConversion, SelectCaseOpConversion,
+        SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
+        StoreOpConversion, SubcOpConversion, UnboxCharOpConversion,
+        UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
+        typeConverter);
+
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);

diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index a804825eb75f..7cfd73c01bf4 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -521,6 +521,57 @@ func @fir_complex_neg(%a: !fir.complex<16>) -> !fir.complex<16> {
 
 // -----
 
+// Test fir.cmpc conversion: .EQ. (oeq) ANDs and .NE. (une) ORs the per-part fcmps; any other predicate compares only the real parts.
+
+func @compare_complex_eq(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+  %r = fir.cmpc "oeq", %a, %b : !fir.complex<8> // .EQ.: expect llvm.and of the real and imaginary fcmps
+  return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_eq
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "oeq" [[RA]], [[RB]] : f64
+// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "oeq" [[IA]], [[IB]] : f64
+// CHECK: [[RES:%.*]] = llvm.and [[RESR]], [[RESI]] : i1
+// CHECK: return [[RES]] : i1
+
+func @compare_complex_ne(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+  %r = fir.cmpc "une", %a, %b : !fir.complex<8> // .NE.: expect llvm.or of the real and imaginary fcmps
+  return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_ne
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "une" [[RA]], [[RB]] : f64
+// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "une" [[IA]], [[IB]] : f64
+// CHECK: [[RES:%.*]] = llvm.or [[RESR]], [[RESI]] : i1
+// CHECK: return [[RES]] : i1
+
+func @compare_complex_other(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+  %r = fir.cmpc "ogt", %a, %b : !fir.complex<8> // neither .EQ. nor .NE.: result is the real-part fcmp alone
+  return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_other
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK: [[RESR:%.*]] = llvm.fcmp "ogt" [[RA]], [[RB]] : f64
+// CHECK: return [[RESR]] : i1
+
+// -----
+
 // Test `fir.convert` operation conversion from Float type.
 
 func @convert_from_float(%arg0 : f32) {


        


More information about the flang-commits mailing list