[flang-commits] [flang] 092cee5 - [fir] Add fir.convert op conversion from FIR to LLVM IR

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Tue Nov 9 06:39:07 PST 2021


Author: Valentin Clement
Date: 2021-11-09T15:38:50+01:00
New Revision: 092cee5f7fe991cc0cfef6a6d52ae150723ce7a0

URL: https://github.com/llvm/llvm-project/commit/092cee5f7fe991cc0cfef6a6d52ae150723ce7a0
DIFF: https://github.com/llvm/llvm-project/commit/092cee5f7fe991cc0cfef6a6d52ae150723ce7a0.diff

LOG: [fir] Add fir.convert op conversion from FIR to LLVM IR

Add conversion pattern for the `fir.convert` operation.

This patch is part of the upstreaming effort from fir-dev branch.

This patch was previously landed in a truncated form that was failing
the Windows buildbot.

Reviewed By: rovka, awarzynski

Differential Revision: https://reviews.llvm.org/D113469

Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/test/Fir/convert-to-llvm.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index db95017f0ba3..c81f385f21df 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -96,6 +96,147 @@ struct CallOpConversion : public FIROpConversion<fir::CallOp> {
   }
 };
 
+/// Return the element type of \p complex, which must be either an
+/// mlir::ComplexType or a fir::ComplexType.
+static mlir::Type getComplexEleTy(mlir::Type complex) {
+  if (auto cc = complex.dyn_cast<mlir::ComplexType>())
+    return cc.getElementType();
+  return complex.cast<fir::ComplexType>().getElementType();
+}
+
+/// Lower `fir.convert` (convert a value of from-type to a value of to-type)
+/// to LLVM dialect cast operation(s).
+///
+/// Handled conversions: complex -> complex (element-wise), float -> float,
+/// float -> integer, integer -> integer, integer -> float,
+/// integer <-> pointer and pointer -> pointer. Any other combination is
+/// reported as an error and the rewrite fails.
+struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
+  using FIROpConversion::FIROpConversion;
+
+  static bool isFloatingPointTy(mlir::Type ty) {
+    return ty.isa<mlir::FloatType>();
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto fromTy = convertType(convert.value().getType());
+    auto toTy = convertType(convert.res().getType());
+    mlir::Value op0 = adaptor.getOperands()[0];
+    if (fromTy == toTy) {
+      rewriter.replaceOp(convert, op0);
+      return mlir::success();
+    }
+    auto loc = convert.getLoc();
+    // Emit an fptrunc/fpext of `val` to `toTy`. When both sides have the same
+    // bitwidth (e.g. bf16 and f16) an error is emitted and a null value is
+    // returned; callers must check for it and fail the rewrite rather than
+    // build IR from a null operand.
+    auto convertFpToFp = [&](mlir::Value val, unsigned fromBits,
+                             unsigned toBits, mlir::Type toTy) -> mlir::Value {
+      if (fromBits == toBits) {
+        // TODO: Converting between two floating-point representations with the
+        // same bitwidth is not allowed for now.
+        mlir::emitError(loc,
+                        "cannot implicitly convert between two floating-point "
+                        "representations of the same bitwidth");
+        return {};
+      }
+      if (fromBits > toBits)
+        return rewriter.create<mlir::LLVM::FPTruncOp>(loc, toTy, val);
+      return rewriter.create<mlir::LLVM::FPExtOp>(loc, toTy, val);
+    };
+    // Complex to complex conversion.
+    if (fir::isa_complex(convert.value().getType()) &&
+        fir::isa_complex(convert.res().getType())) {
+      // Special case: handle the conversion of a complex such that both the
+      // real and imaginary parts are converted together.
+      auto zero = mlir::ArrayAttr::get(convert.getContext(),
+                                       rewriter.getI32IntegerAttr(0));
+      auto one = mlir::ArrayAttr::get(convert.getContext(),
+                                      rewriter.getI32IntegerAttr(1));
+      auto ty = convertType(getComplexEleTy(convert.value().getType()));
+      auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, zero);
+      auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, one);
+      auto nt = convertType(getComplexEleTy(convert.res().getType()));
+      auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+      auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt);
+      auto rc = convertFpToFp(rp, fromBits, toBits, nt);
+      // Both parts share the same widths, so a failure here implies the
+      // imaginary part would fail too; report it once and bail out.
+      if (!rc)
+        return mlir::failure();
+      auto ic = convertFpToFp(ip, fromBits, toBits, nt);
+      auto un = rewriter.create<mlir::LLVM::UndefOp>(loc, toTy);
+      auto i1 =
+          rewriter.create<mlir::LLVM::InsertValueOp>(loc, toTy, un, rc, zero);
+      rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(convert, toTy, i1,
+                                                             ic, one);
+      return mlir::success();
+    }
+    // Floating point to floating point conversion.
+    if (isFloatingPointTy(fromTy)) {
+      if (isFloatingPointTy(toTy)) {
+        auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
+        auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
+        auto v = convertFpToFp(op0, fromBits, toBits, toTy);
+        if (!v)
+          return mlir::failure();
+        rewriter.replaceOp(convert, v);
+        return mlir::success();
+      }
+      // Floating point to integer conversion.
+      if (toTy.isa<mlir::IntegerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    } else if (fromTy.isa<mlir::IntegerType>()) {
+      // Integer to integer conversion.
+      if (toTy.isa<mlir::IntegerType>()) {
+        auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
+        auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
+        // Signless integers of equal width are the same type (early return).
+        assert(fromBits != toBits);
+        if (fromBits > toBits) {
+          rewriter.replaceOpWithNewOp<mlir::LLVM::TruncOp>(convert, toTy, op0);
+          return mlir::success();
+        }
+        // An i1 holds a boolean: zero-extend it so that `true` becomes 1
+        // rather than -1. Wider integers are sign-extended.
+        if (fromTy.isInteger(1)) {
+          rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, op0);
+          return mlir::success();
+        }
+        rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Integer to floating point conversion.
+      if (isFloatingPointTy(toTy)) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Integer to pointer conversion.
+      if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    } else if (fromTy.isa<mlir::LLVM::LLVMPointerType>()) {
+      // Pointer to integer conversion.
+      if (toTy.isa<mlir::IntegerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Pointer to pointer conversion.
+      if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    }
+    return emitError(loc) << "cannot convert " << fromTy << " to " << toTy;
+  }
+};
+
 /// Lower `fir.has_value` operation to `llvm.return` operation.
 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
   using FIROpConversion::FIROpConversion;
@@ -489,12 +604,6 @@ struct InsertOnRangeOpConversion
   }
 };
 
-static mlir::Type getComplexEleTy(mlir::Type complex) {
-  if (auto cc = complex.dyn_cast<mlir::ComplexType>())
-    return cc.getElementType();
-  return complex.cast<fir::ComplexType>().getElementType();
-}
-
 //
 // Primitive operations on Complex types
 //
@@ -679,13 +788,14 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule()};
     mlir::OwningRewritePatternList pattern(context);
-    pattern.insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion,
-                   DivcOpConversion, ExtractValueOpConversion,
-                   HasValueOpConversion, GlobalOpConversion,
-                   InsertOnRangeOpConversion, InsertValueOpConversion,
-                   NegcOpConversion, MulcOpConversion, SelectOpConversion,
-                   SelectRankOpConversion, SubcOpConversion, UndefOpConversion,
-                   UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+    pattern
+        .insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion,
+                ConvertOpConversion, DivcOpConversion, ExtractValueOpConversion,
+                HasValueOpConversion, GlobalOpConversion,
+                InsertOnRangeOpConversion, InsertValueOpConversion,
+                NegcOpConversion, MulcOpConversion, SelectOpConversion,
+                SelectRankOpConversion, SubcOpConversion, UndefOpConversion,
+                UnreachableOpConversion, ZeroOpConversion>(typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);

diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index cafac1de4e13..b33a2294b3e3 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -514,3 +514,120 @@ func @fir_complex_neg(%a: !fir.complex<16>) -> !fir.complex<16> {
 // CHECK:         %{{.*}} = llvm.insertvalue %[[NEGX]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[NEGY]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)>
 // CHECK:         llvm.return %{{.*}} : !llvm.struct<(f128, f128)>
+
+// -----
+
+// Test `fir.convert` operation conversion from Float type.
+
+func @convert_from_float(%arg0 : f32) {
+  %0 = fir.convert %arg0 : (f32) -> f16
+  %1 = fir.convert %arg0 : (f32) -> f32
+  %2 = fir.convert %arg0 : (f32) -> f64
+  %3 = fir.convert %arg0 : (f32) -> f80
+  %4 = fir.convert %arg0 : (f32) -> f128
+  %5 = fir.convert %arg0 : (f32) -> i1
+  %6 = fir.convert %arg0 : (f32) -> i8
+  %7 = fir.convert %arg0 : (f32) -> i16
+  %8 = fir.convert %arg0 : (f32) -> i32
+  %9 = fir.convert %arg0 : (f32) -> i64
+  return
+}
+
+// CHECK-LABEL: convert_from_float(
+// CHECK-SAME:                     %[[ARG0:.*]]: f32
+// CHECK:         %{{.*}} = llvm.fptrunc %[[ARG0]] : f32 to f16
+// CHECK-NOT:     f32 to f32
+// CHECK:         %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f64
+// CHECK:         %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f80
+// CHECK:         %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f128
+// CHECK:         %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i1
+// CHECK:         %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i8
+// CHECK:         %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i16
+// CHECK:         %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i32
+// CHECK:         %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i64
+
+// -----
+
+// Test `fir.convert` operation conversion from Integer type.
+
+func @convert_from_int(%arg0 : i32) {
+  %0 = fir.convert %arg0 : (i32) -> f16
+  %1 = fir.convert %arg0 : (i32) -> f32
+  %2 = fir.convert %arg0 : (i32) -> f64
+  %3 = fir.convert %arg0 : (i32) -> f80
+  %4 = fir.convert %arg0 : (i32) -> f128
+  %5 = fir.convert %arg0 : (i32) -> i1
+  %6 = fir.convert %arg0 : (i32) -> i8
+  %7 = fir.convert %arg0 : (i32) -> i16
+  %8 = fir.convert %arg0 : (i32) -> i32
+  %9 = fir.convert %arg0 : (i32) -> i64
+  %ptr = fir.convert %9 : (i64) -> !fir.ref<i64>
+  return
+}
+
+// CHECK-LABEL: convert_from_int(
+// CHECK-SAME:                   %[[ARG0:.*]]: i32
+// CHECK:         %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f16
+// CHECK:         %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f32
+// CHECK:         %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f64
+// CHECK:         %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f80
+// CHECK:         %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f128
+// CHECK:         %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i1
+// CHECK:         %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i8
+// CHECK:         %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i16
+// CHECK-NOT:     i32 to i32
+// CHECK:         %[[SEXT:.*]] = llvm.sext %[[ARG0]] : i32 to i64
+// CHECK:         %{{.*}} = llvm.inttoptr %[[SEXT]] : i64 to !llvm.ptr<i64>
+
+// -----
+
+// Test `fir.convert` operation conversion from !fir.ref<> type.
+
+func @convert_from_ref(%arg0 : !fir.ref<i32>) {
+  %0 = fir.convert %arg0 : (!fir.ref<i32>) -> !fir.ref<i8>
+  %1 = fir.convert %arg0 : (!fir.ref<i32>) -> i32
+  return
+}
+
+// CHECK-LABEL: convert_from_ref(
+// CHECK-SAME:                   %[[ARG0:.*]]: !llvm.ptr<i32>
+// CHECK:         %{{.*}} = llvm.bitcast %[[ARG0]] : !llvm.ptr<i32> to !llvm.ptr<i8>
+// CHECK:         %{{.*}} = llvm.ptrtoint %[[ARG0]] : !llvm.ptr<i32> to i32
+
+// -----
+
+// Test `fir.convert` operation conversion between fir.complex types (widening).
+
+func @convert_complex4(%arg0 : !fir.complex<4>) -> !fir.complex<8> {
+  %0 = fir.convert %arg0 : (!fir.complex<4>) -> !fir.complex<8>
+  return %0 : !fir.complex<8>
+}
+
+// CHECK-LABEL: func @convert_complex4(
+// CHECK-SAME:                         %[[ARG0:.*]]: !llvm.struct<(f32, f32)>) -> !llvm.struct<(f64, f64)>
+// CHECK:         %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f32, f32)>
+// CHECK:         %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f32, f32)>
+// CHECK:         %[[CONVERTX:.*]] = llvm.fpext %[[X]] : f32 to f64
+// CHECK:         %[[CONVERTY:.*]] = llvm.fpext %[[Y]] : f32 to f64
+// CHECK:         %[[STRUCT0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK:         %[[STRUCT1:.*]] = llvm.insertvalue %[[CONVERTX]], %[[STRUCT0]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK:         %[[STRUCT2:.*]] = llvm.insertvalue %[[CONVERTY]], %[[STRUCT1]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK:         llvm.return %[[STRUCT2]] : !llvm.struct<(f64, f64)>
+
+// Test `fir.convert` operation conversion between fir.complex types (narrowing).
+
+func @convert_complex16(%arg0 : !fir.complex<16>) -> !fir.complex<2> {
+  %0 = fir.convert %arg0 : (!fir.complex<16>) -> !fir.complex<2>
+  return %0 : !fir.complex<2>
+}
+
+// CHECK-LABEL: func @convert_complex16(
+// CHECK-SAME:                          %[[ARG0:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f16, f16)>
+// CHECK:         %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)>
+// CHECK:         %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)>
+// CHECK:         %[[CONVERTX:.*]] = llvm.fptrunc %[[X]] : f128 to f16
+// CHECK:         %[[CONVERTY:.*]] = llvm.fptrunc %[[Y]] : f128 to f16
+// CHECK:         %[[STRUCT0:.*]] = llvm.mlir.undef : !llvm.struct<(f16, f16)>
+// CHECK:         %[[STRUCT1:.*]] = llvm.insertvalue %[[CONVERTX]], %[[STRUCT0]][0 : i32] : !llvm.struct<(f16, f16)>
+// CHECK:         %[[STRUCT2:.*]] = llvm.insertvalue %[[CONVERTY]], %[[STRUCT1]][1 : i32] : !llvm.struct<(f16, f16)>
+// CHECK:         llvm.return %[[STRUCT2]] : !llvm.struct<(f16, f16)>


        


More information about the flang-commits mailing list