[flang-commits] [flang] a7bb8e2 - [Flang] Change fir.divc to perform library call rather than generate inline operations.

Sacha Ballantyne via flang-commits flang-commits at lists.llvm.org
Tue Apr 4 09:09:27 PDT 2023


Author: Sacha Ballantyne
Date: 2023-04-04T16:09:21Z
New Revision: a7bb8e273f433cceeb547e87d04114178573496a

URL: https://github.com/llvm/llvm-project/commit/a7bb8e273f433cceeb547e87d04114178573496a
DIFF: https://github.com/llvm/llvm-project/commit/a7bb8e273f433cceeb547e87d04114178573496a.diff

LOG: [Flang] Change fir.divc to perform library call rather than generate inline operations.

Currently `fir.divc` is always lowered to an inline sequence of LLVM operations implementing the textbook complex-division formula. This causes problems for extreme values, where the intermediate calculations overflow. That trade-off would be acceptable at -Ofast, but the inline sequence is currently used at every optimization level.

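This patch changes `fir.divc` to lower to a call to the matching library routine (`__divsc3`, `__divdc3`, `__divxc3` or `__divtc3`) instead. KIND=2 (half) and KIND=3 (bfloat16) are still expanded inline, because there is no appropriate library routine for those types.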

Reviewed By: vzakhari

Differential Revision: https://reviews.llvm.org/D145808

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/test/Fir/convert-to-llvm.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ef08a6cb1171..a9efba470863 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -41,6 +41,7 @@
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <mlir/IR/ValueRange.h>
 
 namespace fir {
 #define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -3512,42 +3513,94 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
   }
 };
 
-/// Inlined complex division
+/// Replaces \p op with a call to the complex division runtime routine
+/// \p funcName (e.g. __divdc3), passing \p args. The callee is declared as
+/// taking \p argType and returning \p returnType, and is inserted at the
+/// start of the enclosing module the first time it is needed.
+///
+/// NOTE(review): returning the LLVM struct \p returnType may not match the C
+/// ABI for complex results on every target (e.g. x86-64 SysV returns
+/// `float _Complex` packed in XMM0, whereas an LLVM `{float, float}` return
+/// uses XMM0 and XMM1). Confirm the convention for each target and kind.
+static mlir::LogicalResult getDivc3(fir::DivcOp op,
+                                    mlir::ConversionPatternRewriter &rewriter,
+                                    llvm::StringRef funcName,
+                                    mlir::Type returnType,
+                                    llvm::ArrayRef<mlir::Type> argType,
+                                    llvm::ArrayRef<mlir::Value> args) {
+  auto module = op->getParentOfType<mlir::ModuleOp>();
+  auto divideFunc = module.lookupSymbol<mlir::LLVM::LLVMFuncOp>(funcName);
+  if (!divideFunc) {
+    mlir::OpBuilder moduleBuilder(module.getBodyRegion());
+    divideFunc = moduleBuilder.create<mlir::LLVM::LLVMFuncOp>(
+        rewriter.getUnknownLoc(), funcName,
+        mlir::LLVM::LLVMFunctionType::get(returnType, argType,
+                                          /*isVarArg=*/false));
+  }
+  auto call = rewriter.create<mlir::LLVM::CallOp>(
+      op.getLoc(), returnType, mlir::SymbolRefAttr::get(divideFunc), args);
+  rewriter.replaceOp(op, call->getResults());
+  return mlir::success();
+}
+
+/// Complex division. Complex types with f32, f64, f80 or f128 elements
+/// (KIND=4, 8, 10, 16) are lowered to a call to the matching __div?c3 runtime
+/// routine, avoiding the intermediate overflow the inline formula suffers for
+/// extreme values. Other element types, including f16 and bf16 (KIND=2, 3),
+/// have no such routine and are expanded inline.
 struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
   using FIROpConversion::FIROpConversion;
 
   mlir::LogicalResult
   matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    // TODO: Can we use a call to __divdc3 instead?
-    // Just generate inline code for now.
-    // given: (x + iy) / (x' + iy')
-    // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = divc.getLoc();
     mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
     mlir::Type ty = convertType(divc.getType());
     auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
     auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
+
+    // Use the runtime routine matching the element type when there is one.
+    // There is none for f16 or bf16 (KIND=2/3), so those, and any other
+    // element type, keep the inline expansion below.
+    llvm::StringRef funcName;
+    if (eleTy.isF32())
+      funcName = "__divsc3";
+    else if (eleTy.isF64())
+      funcName = "__divdc3";
+    else if (eleTy.isF80())
+      funcName = "__divxc3";
+    else if (eleTy.isF128())
+      funcName = "__divtc3";
+    if (!funcName.empty()) {
+      llvm::SmallVector<mlir::Type> argTy = {eleTy, eleTy, eleTy, eleTy};
+      llvm::SmallVector<mlir::Value> args = {x0, y0, x1, y1};
+      return getDivc3(divc, rewriter, funcName, ty, argTy, args);
+    }
+
+    // Inline expansion:
+    // given: (x + iy) / (x' + iy')
+    // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
     auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
     auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
     auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
     auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
     auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
     auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
     auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
     auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
     auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
     auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
     auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
     rewriter.replaceOp(divc, r0.getResult());
     return mlir::success();
   }
 };
 

diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 75f6a6c659d4..6eac945cc548 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -586,22 +586,92 @@ func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]]  : f128
-// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]]  : f128
-// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]]  : f128
-// CHECK:         %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]]  : f128
-// CHECK:         %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]]  : f128
-// CHECK:         %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]]  : f128
-// CHECK:         %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]]  : f128
-// CHECK:         %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]]  : f128
-// CHECK:         %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]]  : f128
-// CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
-// CHECK:         %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
-// CHECK:         %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
+// CHECK:         %[[CALL:.*]] = llvm.call @__divtc3(%[[X0]], %[[Y0]], %[[X1]], %[[Y1]]) : (f128, f128, f128, f128) -> !llvm.struct<(f128, f128)>
 // CHECK:         llvm.return %{{.*}} : !llvm.struct<(f128, f128)>
 
+// -----
+
+// Test FIR complex division lowers to a __divsc3 call for KIND=4
+
+func.func @fir_complex_div(%a: !fir.complex<4>, %b: !fir.complex<4>) -> !fir.complex<4> {
+  %c = fir.divc %a, %b : !fir.complex<4>
+  return %c : !fir.complex<4>
+}
+
+// CHECK-LABEL: llvm.func @fir_complex_div(
+// CHECK-SAME:                             %[[ARG0:.*]]: !llvm.struct<(f32, f32)>,
+// CHECK-SAME:                             %[[ARG1:.*]]: !llvm.struct<(f32, f32)>) -> !llvm.struct<(f32, f32)> {
+// CHECK:         %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(f32, f32)>
+// CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f32, f32)>
+// CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f32, f32)>
+// CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f32, f32)>
+// CHECK:         %[[CALL:.*]] = llvm.call @__divsc3(%[[X0]], %[[Y0]], %[[X1]], %[[Y1]]) : (f32, f32, f32, f32) -> !llvm.struct<(f32, f32)>
+// CHECK:         llvm.return %[[CALL]] : !llvm.struct<(f32, f32)>
+
+// -----
+
+// Test FIR complex division inlines for KIND=3
+
+func.func @fir_complex_div(%a: !fir.complex<3>, %b: !fir.complex<3>) -> !fir.complex<3> {
+  %c = fir.divc %a, %b : !fir.complex<3>
+  return %c : !fir.complex<3>
+}
+
+// CHECK-LABEL: llvm.func @fir_complex_div(
+// CHECK-SAME:                             %[[ARG0:.*]]: !llvm.struct<(bf16, bf16)>,
+// CHECK-SAME:                             %[[ARG1:.*]]: !llvm.struct<(bf16, bf16)>) -> !llvm.struct<(bf16, bf16)> {
+// CHECK:         %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(bf16, bf16)>
+// CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(bf16, bf16)>
+// CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(bf16, bf16)>
+// CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(bf16, bf16)>
+// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]]  : bf16
+// CHECK:         %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]]  : bf16
+// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]]  : bf16
+// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]]  : bf16
+// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]]  : bf16
+// CHECK:         %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]]  : bf16
+// CHECK:         %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]]  : bf16
+// CHECK:         %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]]  : bf16
+// CHECK:         %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]]  : bf16
+// CHECK:         %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]]  : bf16
+// CHECK:         %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]]  : bf16
+// CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(bf16, bf16)>
+// CHECK:         %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(bf16, bf16)>
+// CHECK:         %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(bf16, bf16)>
+// CHECK:         llvm.return %{{.*}} : !llvm.struct<(bf16, bf16)>
+
+// -----
+
+// Test FIR complex division inlines for KIND=2
+
+func.func @fir_complex_div(%a: !fir.complex<2>, %b: !fir.complex<2>) -> !fir.complex<2> {
+  %c = fir.divc %a, %b : !fir.complex<2>
+  return %c : !fir.complex<2>
+}
+
+// CHECK-LABEL: llvm.func @fir_complex_div(
+// CHECK-SAME:                             %[[ARG0:.*]]: !llvm.struct<(f16, f16)>,
+// CHECK-SAME:                             %[[ARG1:.*]]: !llvm.struct<(f16, f16)>) -> !llvm.struct<(f16, f16)> {
+// CHECK:         %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(f16, f16)>
+// CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f16, f16)>
+// CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f16, f16)>
+// CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f16, f16)>
+// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]]  : f16
+// CHECK:         %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]]  : f16
+// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]]  : f16
+// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]]  : f16
+// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]]  : f16
+// CHECK:         %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]]  : f16
+// CHECK:         %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]]  : f16
+// CHECK:         %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]]  : f16
+// CHECK:         %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]]  : f16
+// CHECK:         %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]]  : f16
+// CHECK:         %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]]  : f16
+// CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f16, f16)>
+// CHECK:         %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f16, f16)>
+// CHECK:         %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f16, f16)>
+// CHECK:         llvm.return %{{.*}} : !llvm.struct<(f16, f16)>
+
 // -----
 
 // Test FIR complex negation conversion


        


More information about the flang-commits mailing list