[flang-commits] [flang] 165879e - [fir] Add fir.convert op conversion from FIR to LLVM IR
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Tue Nov 9 05:55:36 PST 2021
Author: Valentin Clement
Date: 2021-11-09T14:55:24+01:00
New Revision: 165879ec31ed5cc6e4e1a2524c86fc80b81ebbda
URL: https://github.com/llvm/llvm-project/commit/165879ec31ed5cc6e4e1a2524c86fc80b81ebbda
DIFF: https://github.com/llvm/llvm-project/commit/165879ec31ed5cc6e4e1a2524c86fc80b81ebbda.diff
LOG: [fir] Add fir.convert op conversion from FIR to LLVM IR
Add conversion pattern for the `fir.convert` operation.
This patch is part of the upstreaming effort from fir-dev branch.
Reviewed By: rovka, awarzynski
Differential Revision: https://reviews.llvm.org/D113469
Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Added:
Modified:
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/lib/Optimizer/CodeGen/TypeConverter.h
flang/test/Fir/convert-to-llvm.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index db95017f0ba3..c81f385f21df 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -96,6 +96,121 @@ struct CallOpConversion : public FIROpConversion<fir::CallOp> {
}
};
+/// Return the element type of a complex type. \p complex is expected to be
+/// either a builtin `mlir::ComplexType` or a `fir::ComplexType`; any other
+/// type is rejected by `cast` (an assertion in debug builds).
+static mlir::Type getComplexEleTy(mlir::Type complex) {
+  auto builtinComplex = complex.dyn_cast<mlir::ComplexType>();
+  return builtinComplex
+             ? builtinComplex.getElementType()
+             : complex.cast<fir::ComplexType>().getElementType();
+}
+
+/// Lower `fir.convert` to the LLVM dialect operation(s) converting a value of
+/// the (converted) source type into a value of the (converted) result type.
+///
+/// Supported conversions, once FIR types are mapped to LLVM dialect types:
+///   - identical types: the operand is forwarded and nothing is emitted;
+///   - complex -> complex: both parts are converted with fpext/fptrunc;
+///   - float -> float: fpext/fptrunc; float -> integer: fptosi;
+///   - integer -> integer: trunc when narrowing; when widening, zext from i1
+///     and sext otherwise;
+///   - integer -> float: sitofp; integer -> pointer: inttoptr;
+///   - pointer -> integer: ptrtoint; pointer -> pointer: bitcast.
+/// Any other combination, as well as a conversion between two distinct
+/// floating-point types of the same bitwidth, emits an error and fails.
+struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
+  using FIROpConversion::FIROpConversion;
+
+  static bool isFloatingPointTy(mlir::Type ty) {
+    return ty.isa<mlir::FloatType>();
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto fromTy = convertType(convert.value().getType());
+    auto toTy = convertType(convert.res().getType());
+    mlir::Value op0 = adaptor.getOperands()[0];
+    if (fromTy == toTy) {
+      rewriter.replaceOp(convert, op0);
+      return mlir::success();
+    }
+    auto loc = convert.getLoc();
+
+    // fptrunc/fpext cannot convert between two distinct floating-point types
+    // of the same bitwidth (e.g. f16 and bf16), so such a request is reported
+    // as an error and the pattern fails instead of producing a null value.
+    // TODO: Converting between two floating-point representations with the
+    // same bitwidth is not allowed for now.
+    auto checkFpBitwidths = [&](unsigned fromBits,
+                                unsigned toBits) -> mlir::LogicalResult {
+      if (fromBits != toBits)
+        return mlir::success();
+      return mlir::emitError(
+          loc, "cannot implicitly convert between two floating-point "
+               "representations of the same bitwidth");
+    };
+    // Emit fptrunc or fpext. The bitwidths must have been validated with
+    // checkFpBitwidths first, so a valid value is always returned.
+    auto convertFpToFp = [&](mlir::Value val, unsigned fromBits,
+                             unsigned toBits, mlir::Type dstTy) -> mlir::Value {
+      assert(fromBits != toBits && "bitwidths must be checked first");
+      if (fromBits > toBits)
+        return rewriter.create<mlir::LLVM::FPTruncOp>(loc, dstTy, val);
+      return rewriter.create<mlir::LLVM::FPExtOp>(loc, dstTy, val);
+    };
+
+    // Complex to complex conversion.
+    if (fir::isa_complex(convert.value().getType()) &&
+        fir::isa_complex(convert.res().getType())) {
+      // Special case: handle the conversion of a complex such that both the
+      // real and imaginary parts are converted together.
+      auto ty = convertType(getComplexEleTy(convert.value().getType()));
+      auto nt = convertType(getComplexEleTy(convert.res().getType()));
+      auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+      auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt);
+      // Validate before creating any operation: a single diagnostic is
+      // reported and no null value can reach the insertvalue operations.
+      if (mlir::failed(checkFpBitwidths(fromBits, toBits)))
+        return mlir::failure();
+      auto zero = mlir::ArrayAttr::get(convert.getContext(),
+                                       rewriter.getI32IntegerAttr(0));
+      auto one = mlir::ArrayAttr::get(convert.getContext(),
+                                      rewriter.getI32IntegerAttr(1));
+      auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, zero);
+      auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, one);
+      auto rc = convertFpToFp(rp, fromBits, toBits, nt);
+      auto ic = convertFpToFp(ip, fromBits, toBits, nt);
+      auto un = rewriter.create<mlir::LLVM::UndefOp>(loc, toTy);
+      auto withReal =
+          rewriter.create<mlir::LLVM::InsertValueOp>(loc, toTy, un, rc, zero);
+      rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
+          convert, toTy, withReal, ic, one);
+      return mlir::success();
+    }
+
+    if (isFloatingPointTy(fromTy)) {
+      // Floating point to floating point conversion.
+      if (isFloatingPointTy(toTy)) {
+        auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
+        auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
+        if (mlir::failed(checkFpBitwidths(fromBits, toBits)))
+          return mlir::failure();
+        mlir::Value v = convertFpToFp(op0, fromBits, toBits, toTy);
+        rewriter.replaceOp(convert, v);
+        return mlir::success();
+      }
+      // Floating point to integer conversion.
+      if (toTy.isa<mlir::IntegerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    } else if (fromTy.isa<mlir::IntegerType>()) {
+      // Integer to integer conversion.
+      if (toTy.isa<mlir::IntegerType>()) {
+        auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
+        auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
+        // Same-width integers are expected to map to the same LLVM type and
+        // thus to have been forwarded by the fromTy == toTy check above.
+        assert(fromBits != toBits);
+        if (fromBits > toBits) {
+          rewriter.replaceOpWithNewOp<mlir::LLVM::TruncOp>(convert, toTy, op0);
+          return mlir::success();
+        }
+        // An i1 holds a boolean: zero-extend it so that true becomes 1; sign
+        // extension would turn it into -1. Wider integers are sign-extended.
+        if (fromBits == 1) {
+          rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, op0);
+          return mlir::success();
+        }
+        rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Integer to floating point conversion.
+      if (isFloatingPointTy(toTy)) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Integer to pointer conversion.
+      if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    } else if (fromTy.isa<mlir::LLVM::LLVMPointerType>()) {
+      // Pointer to integer conversion.
+      if (toTy.isa<mlir::IntegerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Pointer to pointer conversion.
+      if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    }
+    return mlir::emitError(loc) << "cannot convert " << fromTy << " to " << toTy;
+  }
+};
+
/// Lower `fir.has_value` operation to `llvm.return` operation.
struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
using FIROpConversion::FIROpConversion;
@@ -489,12 +604,6 @@ struct InsertOnRangeOpConversion
}
};
-static mlir::Type getComplexEleTy(mlir::Type complex) {
- if (auto cc = complex.dyn_cast<mlir::ComplexType>())
- return cc.getElementType();
- return complex.cast<fir::ComplexType>().getElementType();
-}
-
//
// Primitive operations on Complex types
//
@@ -679,13 +788,14 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
auto *context = getModule().getContext();
fir::LLVMTypeConverter typeConverter{getModule()};
mlir::OwningRewritePatternList pattern(context);
- pattern.insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion,
- DivcOpConversion, ExtractValueOpConversion,
- HasValueOpConversion, GlobalOpConversion,
- InsertOnRangeOpConversion, InsertValueOpConversion,
- NegcOpConversion, MulcOpConversion, SelectOpConversion,
- SelectRankOpConversion, SubcOpConversion, UndefOpConversion,
- UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+ pattern
+ .insert<AddcOpConversion, AddrOfOpConversion, CallOpConversion,
+ ConvertOpConversion, DivcOpConversion, ExtractValueOpConversion,
+ HasValueOpConversion, GlobalOpConversion,
+ InsertOnRangeOpConversion, InsertValueOpConversion,
+ NegcOpConversion, MulcOpConversion, SelectOpConversion,
+ SelectRankOpConversion, SubcOpConversion, UndefOpConversion,
+ UnreachableOpConversion, ZeroOpConversion>(typeConverter);
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.h b/flang/lib/Optimizer/CodeGen/TypeConverter.h
index f4d252dfc7ee..3f8dda99e97c 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.h
@@ -148,24 +148,6 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
/*isPacked=*/false));
}
- // Use the target specifics to figure out how to map complex to LLVM IR. The
- // use of complex values in function signatures is handled before conversion
- // to LLVM IR dialect here.
- //
- // fir.complex<T> | std.complex<T> --> llvm<"{t,t}">
- template <typename C>
- mlir::Type convertComplexType(C cmplx) {
- LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n');
- auto eleTy = cmplx.getElementType();
- return convertType(specifics->complexMemoryType(eleTy));
- }
-
- // convert a front-end kind value to either a std or LLVM IR dialect type
- // fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
- mlir::Type convertRealType(fir::KindTy kind) {
- return fromRealTypeID(kindMapping.getRealTypeID(kind), kind);
- }
-
template <typename A>
mlir::Type convertPointerLike(A &ty) {
mlir::Type eleTy = ty.getEleTy();
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index cafac1de4e13..b33a2294b3e3 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -514,3 +514,121 @@ func @fir_complex_neg(%a: !fir.complex<16>) -> !fir.complex<16> {
// CHECK: %{{.*}} = llvm.insertvalue %[[NEGX]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[NEGY]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)>
// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)>
+
+// -----
+
+// Test `fir.convert` operation conversion from Float type.
+// A wider float type is produced with fpext, a narrower one with fptrunc and
+// an integer type with fptosi. The f32 -> f32 conversion is forwarded without
+// emitting any operation.
+
+func @convert_from_float(%arg0 : f32) {
+ %0 = fir.convert %arg0 : (f32) -> f16
+ %1 = fir.convert %arg0 : (f32) -> f32
+ %2 = fir.convert %arg0 : (f32) -> f64
+ %3 = fir.convert %arg0 : (f32) -> f80
+ %4 = fir.convert %arg0 : (f32) -> f128
+ %5 = fir.convert %arg0 : (f32) -> i1
+ %6 = fir.convert %arg0 : (f32) -> i8
+ %7 = fir.convert %arg0 : (f32) -> i16
+ %8 = fir.convert %arg0 : (f32) -> i32
+ %9 = fir.convert %arg0 : (f32) -> i64
+ return
+}
+
+// CHECK-LABEL: convert_from_float(
+// CHECK-SAME: %[[ARG0:.*]]: f32
+// CHECK: %{{.*}} = llvm.fptrunc %[[ARG0]] : f32 to f16
+// CHECK-NOT: f32 to f32
+// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f64
+// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f80
+// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f128
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i1
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i8
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i16
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i32
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i64
+
+// -----
+
+// Test `fir.convert` operation conversion from Integer type.
+// i32 converts to float types with sitofp, to narrower integers with trunc,
+// to i64 with sext and, through that i64, to a pointer with inttoptr. The
+// i32 -> i32 conversion is forwarded without emitting any operation.
+
+func @convert_from_int(%arg0 : i32) {
+  %0 = fir.convert %arg0 : (i32) -> f16
+  %1 = fir.convert %arg0 : (i32) -> f32
+  %2 = fir.convert %arg0 : (i32) -> f64
+  %3 = fir.convert %arg0 : (i32) -> f80
+  %4 = fir.convert %arg0 : (i32) -> f128
+  %5 = fir.convert %arg0 : (i32) -> i1
+  %6 = fir.convert %arg0 : (i32) -> i8
+  %7 = fir.convert %arg0 : (i32) -> i16
+  %8 = fir.convert %arg0 : (i32) -> i32
+  %9 = fir.convert %arg0 : (i32) -> i64
+  %ptr = fir.convert %9 : (i64) -> !fir.ref<i64>
+  return
+}
+
+// CHECK-LABEL: convert_from_int(
+// CHECK-SAME: %[[ARG0:.*]]: i32
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f16
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f32
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f64
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f80
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f128
+// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i1
+// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i8
+// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i16
+// CHECK-NOT: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i32
+// CHECK: %[[SEXT:.*]] = llvm.sext %[[ARG0]] : i32 to i64
+// CHECK: %{{.*}} = llvm.inttoptr %[[SEXT]] : i64 to !llvm.ptr<i64>
+
+// -----
+
+// Test `fir.convert` operation conversion from !fir.ref<> type.
+// !fir.ref<i32> lowers to !llvm.ptr<i32>: converting it to another reference
+// type emits a bitcast, converting it to an integer emits a ptrtoint.
+
+func @convert_from_ref(%arg0 : !fir.ref<i32>) {
+ %0 = fir.convert %arg0 : (!fir.ref<i32>) -> !fir.ref<i8>
+ %1 = fir.convert %arg0 : (!fir.ref<i32>) -> i32
+ return
+}
+
+// CHECK-LABEL: convert_from_ref(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i32>
+// CHECK: %{{.*}} = llvm.bitcast %[[ARG0]] : !llvm.ptr<i32> to !llvm.ptr<i8>
+// CHECK: %{{.*}} = llvm.ptrtoint %[[ARG0]] : !llvm.ptr<i32> to i32
+
+// -----
+
+// Test `fir.convert` operation conversion between fir.complex types.
+// Widening case: the real and imaginary f32 parts of a !fir.complex<4> are
+// extracted, extended with fpext and reassembled into the f64 struct of a
+// !fir.complex<8>.
+
+func @convert_complex4(%arg0 : !fir.complex<4>) -> !fir.complex<8> {
+ %0 = fir.convert %arg0 : (!fir.complex<4>) -> !fir.complex<8>
+ return %0 : !fir.complex<8>
+}
+
+// CHECK-LABEL: func @convert_complex4(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f32, f32)>) -> !llvm.struct<(f64, f64)>
+// CHECK: %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f32, f32)>
+// CHECK: %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f32, f32)>
+// CHECK: %[[CONVERTX:.*]] = llvm.fpext %[[X]] : f32 to f64
+// CHECK: %[[CONVERTY:.*]] = llvm.fpext %[[Y]] : f32 to f64
+// CHECK: %[[STRUCT0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK: %[[STRUCT1:.*]] = llvm.insertvalue %[[CONVERTX]], %[[STRUCT0]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK: %[[STRUCT2:.*]] = llvm.insertvalue %[[CONVERTY]], %[[STRUCT1]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK: llvm.return %[[STRUCT2]] : !llvm.struct<(f64, f64)>
+
+// Test `fir.convert` operation conversion between fir.complex types.
+// Narrowing case: the real and imaginary f128 parts of a !fir.complex<16> are
+// extracted, truncated with fptrunc and reassembled into the f16 struct of a
+// !fir.complex<2>.
+
+func @convert_complex16(%arg0 : !fir.complex<16>) -> !fir.complex<2> {
+ %0 = fir.convert %arg0 : (!fir.complex<16>) -> !fir.complex<2>
+ return %0 : !fir.complex<2>
+}
+
+// CHECK-LABEL: func @convert_complex16(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f16, f16)>
+// CHECK: %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)>
+// CHECK: %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)>
+// CHECK: %[[CONVERTX:.*]] = llvm.fptrunc %[[X]] : f128 to f16
+// CHECK: %[[CONVERTY:.*]] = llvm.fptrunc %[[Y]] : f128 to f16
+// CHECK: %[[STRUCT0:.*]] = llvm.mlir.undef : !llvm.struct<(f16, f16)>
+// CHECK: %[[STRUCT1:.*]] = llvm.insertvalue %[[CONVERTX]], %[[STRUCT0]][0 : i32] : !llvm.struct<(f16, f16)>
+// CHECK: %[[STRUCT2:.*]] = llvm.insertvalue %[[CONVERTY]], %[[STRUCT1]][1 : i32] : !llvm.struct<(f16, f16)>
+// CHECK: llvm.return %[[STRUCT2]] : !llvm.struct<(f16, f16)>
More information about the flang-commits
mailing list