[flang-commits] [flang] a7bb8e2 - [Flang] Change fir.divc to perform library call rather than generate inline operations.
Sacha Ballantyne via flang-commits
flang-commits at lists.llvm.org
Tue Apr 4 09:09:27 PDT 2023
Author: Sacha Ballantyne
Date: 2023-04-04T16:09:21Z
New Revision: a7bb8e273f433cceeb547e87d04114178573496a
URL: https://github.com/llvm/llvm-project/commit/a7bb8e273f433cceeb547e87d04114178573496a
DIFF: https://github.com/llvm/llvm-project/commit/a7bb8e273f433cceeb547e87d04114178573496a.diff
LOG: [Flang] Change fir.divc to perform library call rather than generate inline operations.
Currently `fir.divc` is always lowered to an inline sequence of LLVM operations that performs the complex division; however, this gives wrong results for extreme values, because the intermediate calculations overflow. While this behaviour would be acceptable at -Ofast, the inline lowering is currently used at all optimization levels.
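This patch changes `fir.divc` to lower to a library call (`__divsc3`, `__divdc3`, `__divxc3` or `__divtc3`) instead, except for KIND=2 and KIND=3 (half and bfloat), for which compiler-rt provides no appropriate library routine; those kinds are still expanded inline.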
Reviewed By: vzakhari
Differential Revision: https://reviews.llvm.org/D145808
Added:
Modified:
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/test/Fir/convert-to-llvm.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ef08a6cb1171..a9efba470863 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <mlir/IR/ValueRange.h>
namespace fir {
#define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -3512,42 +3513,87 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
}
};
-/// Inlined complex division
+/// Replace \p op with a call to the runtime complex-division routine
+/// \p funcName, first declaring that routine at the top of the enclosing
+/// module if no LLVM function with that name exists there yet.
+///
+/// \param op         the fir.divc operation being converted.
+/// \param rewriter   rewriter used to create the call and replace \p op.
+/// \param funcName   symbol name of the routine, e.g. "__divdc3".
+/// \param returnType converted (LLVM) type of the complex result.
+/// \param argType    parameter types used when declaring the routine.
+/// \param args       argument values passed to the routine; DivcOpConversion
+///                   passes re(a), im(a), re(b), im(b) in that order.
+/// \return always mlir::success().
+static mlir::LogicalResult getDivc3(fir::DivcOp op,
+                                    mlir::ConversionPatternRewriter &rewriter,
+                                    llvm::StringRef funcName,
+                                    mlir::Type returnType,
+                                    llvm::ArrayRef<mlir::Type> argType,
+                                    llvm::ArrayRef<mlir::Value> args) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  auto divideFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(funcName);
+  if (!divideFunc) {
+    // Declare the routine once per module; later divisions find it via the
+    // lookup above. A separate builder is used so the rewriter's insertion
+    // point (at `op`) is left undisturbed.
+    mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+    divideFunc = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+        rewriter.getUnknownLoc(), funcName,
+        mlir::LLVM::LLVMFunctionType::get(returnType, argType,
+                                          /*isVarArg=*/false));
+  }
+  auto call = rewriter.create<mlir::LLVM::CallOp>(
+      op.getLoc(), returnType, mlir::SymbolRefAttr::get(divideFunc), args);
+  rewriter.replaceOp(op, call->getResults());
+  return mlir::success();
+}
+
+/// Complex division.
+///
+/// KIND=4, 8, 10 and 16 are lowered to a call to __divsc3, __divdc3,
+/// __divxc3 and __divtc3 respectively (via getDivc3). KIND=2 and KIND=3
+/// (half and bfloat) have no such routine in compiler-rt and are still
+/// expanded inline using the formula in matchAndRewrite, whose intermediate
+/// products can overflow for operands of extreme magnitude.
 struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
   using FIROpConversion::FIROpConversion;
   mlir::LogicalResult
   matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    // TODO: Can we use a call to __divdc3 instead?
-    // Just generate inline code for now.
     // given: (x + iy) / (x' + iy')
     // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = divc.getLoc();
     mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
-    mlir::Type ty = convertType(divc.getType());
+    mlir::Type firReturnTy = divc.getType();
+    mlir::Type ty = convertType(firReturnTy);
     auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
     auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
-    auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
-    auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
-    auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
-    auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
-    auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
-    auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
-    auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
-    auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
-    auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
-    auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
-    auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
-    auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
-    auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
-    auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
-    rewriter.replaceOp(divc, r0.getResult());
-    return mlir::success();
+
+    // `cast` rather than an unchecked `dyn_cast`: a null dyn_cast result
+    // would be dereferenced by getFKind(). The asserting cast documents the
+    // assumption that the result type here is always !fir.complex.
+    fir::KindTy kind = firReturnTy.cast<fir::ComplexType>().getFKind();
+    // The runtime routines take the four scalar parts re(a), im(a), re(b),
+    // im(b) and return the complex result.
+    llvm::SmallVector<mlir::Type> argTy = {eleTy, eleTy, eleTy, eleTy};
+    llvm::SmallVector<mlir::Value> args = {x0, y0, x1, y1};
+    switch (kind) {
+    case 4:
+      return getDivc3(divc, rewriter, "__divsc3", ty, argTy, args);
+    case 8:
+      return getDivc3(divc, rewriter, "__divdc3", ty, argTy, args);
+    case 10:
+      return getDivc3(divc, rewriter, "__divxc3", ty, argTy, args);
+    case 16:
+      return getDivc3(divc, rewriter, "__divtc3", ty, argTy, args);
+    case 2:
+    case 3: {
+      // No library function for half or bfloat in compiler-rt, so expand the
+      // formula given at the top of this function inline.
+      auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+      auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
+      auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+      auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+      auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+      auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
+      auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
+      auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
+      auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
+      auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
+      auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
+      auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
+      auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
+      auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
+      rewriter.replaceOp(divc, r0.getResult());
+      return mlir::success();
+    }
+    default:
+      llvm_unreachable("Unsupported complex type");
+    }
   }
 };
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 75f6a6c659d4..6eac945cc548 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -586,22 +586,42 @@ func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128
-// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : f128
-// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128
-// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128
-// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128
-// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : f128
-// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : f128
-// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128
-// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : f128
-// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
-// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
-// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
-// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
-// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
+// CHECK: %[[CALL:.*]] = llvm.call @__divtc3(%[[X0]], %[[Y0]], %[[X1]], %[[Y1]]) : (f128, f128, f128, f128) -> !llvm.struct<(f128, f128)>
// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)>
+// -----
+
+// Test that FIR complex division for KIND=3 (bfloat) is still expanded inline
+// into fmul/fadd/fsub/fdiv operations rather than lowered to a library call.
+
+func.func @fir_complex_div(%a: !fir.complex<3>, %b: !fir.complex<3>) -> !fir.complex<3> {
+ %c = fir.divc %a, %b : !fir.complex<3>
+ return %c : !fir.complex<3>
+}
+
+// CHECK-LABEL: llvm.func @fir_complex_div(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(bf16, bf16)>,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(bf16, bf16)>) -> !llvm.struct<(bf16, bf16)> {
+// CHECK: %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(bf16, bf16)>
+// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(bf16, bf16)>
+// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(bf16, bf16)>
+// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(bf16, bf16)>
+// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : bf16
+// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : bf16
+// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : bf16
+// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : bf16
+// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : bf16
+// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : bf16
+// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : bf16
+// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : bf16
+// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : bf16
+// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : bf16
+// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : bf16
+// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(bf16, bf16)>
+// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(bf16, bf16)>
+// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(bf16, bf16)>
+// CHECK: llvm.return %{{.*}} : !llvm.struct<(bf16, bf16)>
+
+
// -----
// Test FIR complex negation conversion
More information about the flang-commits
mailing list