[Mlir-commits] [mlir] a48adc5 - [mlir][math] Promote (b)f16 to f32 when lowering to libm calls

Benjamin Kramer <llvmlistbot at llvm.org>
Mon May 9 03:03:31 PDT 2022


Author: Benjamin Kramer
Date: 2022-05-09T11:59:55+02:00
New Revision: a48adc565864e0ce10becf301de5455308bd7d6c

URL: https://github.com/llvm/llvm-project/commit/a48adc565864e0ce10becf301de5455308bd7d6c
DIFF: https://github.com/llvm/llvm-project/commit/a48adc565864e0ce10becf301de5455308bd7d6c.diff

LOG: [mlir][math] Promote (b)f16 to f32 when lowering to libm calls

libm doesn't provide these functions for the small floating-point
types, so promote (b)f16 operands to f32, call the f32 function, and
truncate the result back to the original type.

Differential Revision: https://reviews.llvm.org/D125093
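
A rough sketch of the intended lowering for a standalone f16 op (the
function and SSA value names here are made up; the exact output is
what the updated CHECK lines in convert-to-libm.mlir verify):

    // Hypothetical caller before running -convert-math-to-libm:
    func.func @atan2_f16(%x: f16) -> f16 {
      %r = math.atan2 %x, %x : f16
      return %r : f16
    }

    // Roughly what the patterns produce: PromoteOpToF32 extends the
    // f16 operands to f32 and truncates the result back, then
    // ScalarOpToLibmCall turns the f32 op into a call to atan2f
    // (declaring @atan2f at module scope).
    func.func @atan2_f16(%x: f16) -> f16 {
      %e0 = arith.extf %x : f16 to f32
      %e1 = arith.extf %x : f16 to f32
      %c = call @atan2f(%e0, %e1) : (f32, f32) -> f32
      %r = arith.truncf %c : f32 to f16
      return %r : f16
    }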

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 9bea594d87df6..6c9d02c273e5b 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -30,6 +30,14 @@ struct VecOpToScalarOp : public OpRewritePattern<Op> {
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
 };
+// Pattern to promote an op of a smaller floating point type to F32.
+template <typename Op>
+struct PromoteOpToF32 : public OpRewritePattern<Op> {
+public:
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
 // Pattern to convert scalar math operations to calls to libm functions.
 // Additionally the libm function signatures are declared.
 template <typename Op>
@@ -82,13 +90,30 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   return success();
 }
 
+template <typename Op>
+LogicalResult
+PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+  auto opType = op.getType();
+  if (!opType.template isa<Float16Type, BFloat16Type>())
+    return failure();
+
+  auto loc = op.getLoc();
+  auto f32 = rewriter.getF32Type();
+  auto extendedOperands = llvm::to_vector(
+      llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
+        return rewriter.create<arith::ExtFOp>(loc, f32, operand);
+      }));
+  auto newOp = rewriter.create<Op>(loc, f32, extendedOperands);
+  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, opType, newOp);
+  return success();
+}
+
 template <typename Op>
 LogicalResult
 ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
                                         PatternRewriter &rewriter) const {
   auto module = SymbolTable::getNearestSymbolTable(op);
   auto type = op.getType();
-  // TODO: Support Float16 by upcasting to Float32
   if (!type.template isa<Float32Type, Float64Type>())
     return failure();
 
@@ -117,6 +142,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
                VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
+  patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
+               PromoteOpToF32<math::TanhOp>>(patterns.getContext(), benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
                                                   "atan2f", "atan2", benefit);
   patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",

diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 57af89badd635..7cdb56e783e74 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -25,13 +25,25 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64)  {
 // CHECK-LABEL: func @atan2_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
-func.func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
-  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
+// CHECK-SAME: %[[HALF:.*]]: f16
+// CHECK-SAME: %[[BFLOAT:.*]]: bf16
+func.func @atan2_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) -> (f32, f64, f16, bf16) {
+  // CHECK: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
   %float_result = math.atan2 %float, %float : f32
-  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
+  // CHECK: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
   %double_result = math.atan2 %double, %double : f64
-  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
-  return %float_result, %double_result : f32, f64
+  // CHECK: %[[HALF_PROMOTED1:.*]] = arith.extf %[[HALF]] : f16 to f32
+  // CHECK: %[[HALF_PROMOTED2:.*]] = arith.extf %[[HALF]] : f16 to f32
+  // CHECK: %[[HALF_CALL:.*]] = call @atan2f(%[[HALF_PROMOTED1]], %[[HALF_PROMOTED2]]) : (f32, f32) -> f32
+  // CHECK: %[[HALF_RESULT:.*]] = arith.truncf %[[HALF_CALL]] : f32 to f16
+  %half_result = math.atan2 %half, %half : f16
+  // CHECK: %[[BFLOAT_PROMOTED1:.*]] = arith.extf %[[BFLOAT]] : bf16 to f32
+  // CHECK: %[[BFLOAT_PROMOTED2:.*]] = arith.extf %[[BFLOAT]] : bf16 to f32
+  // CHECK: %[[BFLOAT_CALL:.*]] = call @atan2f(%[[BFLOAT_PROMOTED1]], %[[BFLOAT_PROMOTED2]]) : (f32, f32) -> f32
+  // CHECK: %[[BFLOAT_RESULT:.*]] = arith.truncf %[[BFLOAT_CALL]] : f32 to bf16
+  %bfloat_result = math.atan2 %bfloat, %bfloat : bf16
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]], %[[HALF_RESULT]], %[[BFLOAT_RESULT]]
+  return %float_result, %double_result, %half_result, %bfloat_result : f32, f64, f16, bf16
 }
 
 // CHECK-LABEL: func @erf_caller


        

