[Mlir-commits] [mlir] 34810e1 - [mlir] Add patterns to lower Math operations to LLVM based libm calls.

Tres Popp llvmlistbot at llvm.org
Tue Apr 20 02:39:25 PDT 2021


Author: Tres Popp
Date: 2021-04-20T11:38:55+02:00
New Revision: 34810e1b9c4554976d9d8249b18f48ff083b55fa

URL: https://github.com/llvm/llvm-project/commit/34810e1b9c4554976d9d8249b18f48ff083b55fa
DIFF: https://github.com/llvm/llvm-project/commit/34810e1b9c4554976d9d8249b18f48ff083b55fa.diff

LOG: [mlir] Add patterns to lower Math operations to LLVM based libm calls.

Some Math operations do not have an equivalent in LLVM. In these cases,
allow a low-priority fallback that calls the corresponding libm functions.
This is meant to provide functionality and is not a performant option.

Differential Revision: https://reviews.llvm.org/D100367

Added: 
    mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
    mlir/lib/Conversion/MathToLibm/CMakeLists.txt
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
new file mode 100644
index 0000000000000..9e7aa1a0f52ac
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -0,0 +1,26 @@
+//===- MathToLibm.h - Utils to convert from the math dialect -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
+#define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert from Math to Libm calls.
+/// Scalar f32/f64 `math.atan2`, `math.expm1` and `math.tanh` are replaced by
+/// calls to the corresponding libm functions, and 1-D vector forms of these
+/// ops are first unrolled into per-element scalar ops. Every pattern is added
+/// with `benefit`; a low value lets these patterns act as a fallback behind
+/// other lowerings of the same ops.
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
+/// Create a pass to convert Math operations to libm calls.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 21e604eabecd3..64de7c962beed 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -20,6 +20,7 @@
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6eb5abdefe552..eb940d3414049 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -228,6 +228,19 @@ def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToLibm
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
+  let summary = "Convert Math dialect to libm calls";
+  let description = [{
+    This pass converts supported Math ops to libm calls.
+
+    Currently `math.atan2`, `math.expm1` and `math.tanh` on `f32` and `f64`
+    are replaced by calls to the corresponding libm functions (e.g. `tanhf`
+    for `f32` and `tanh` for `f64`). These functions are declared as private
+    functions in the module when they are not already present. Operations on
+    1-D vectors of these types are first unrolled into one scalar operation
+    per element. This is a functional fallback and is not expected to be
+    performant.
+  }];
+  let constructor = "mlir::createConvertMathToLibmPass()";
+  let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // OpenMPToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 4f6d4a27ecca3..60dbab0a04432 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(GPUToVulkan)
 add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
+add_subdirectory(MathToLibm)
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(SCFToGPU)

diff  --git a/mlir/lib/Conversion/MathToLibm/CMakeLists.txt b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt
new file mode 100644
index 0000000000000..cd43a11d30d54
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLibm/CMakeLists.txt
@@ -0,0 +1,16 @@
+# MathToLibm.cpp creates vector.extractelement/vector.insertelement ops and the
+# pass depends on the vector dialect, so MLIRVector is linked directly instead
+# of relying on a transitive dependency (which breaks shared-library builds).
+add_mlir_conversion_library(MLIRMathToLibm
+  MathToLibm.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLibm
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRMath
+  MLIRStandardOpsTransforms
+  MLIRVector
+  )

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
new file mode 100644
index 0000000000000..8512432681c24
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -0,0 +1,147 @@
+//===-- MathToLibm.cpp - conversion from Math to libm calls ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+/// Pattern that unrolls an elementwise math op on a 1-D vector into one scalar
+/// op per element and reassembles the results into a vector. This is needed
+/// because libm functions only accept scalars.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+
+/// Pattern that replaces a scalar f32/f64 math op with a call to the matching
+/// libm function. The libm function is forward-declared in the enclosing
+/// module when it is not already present.
+template <typename Op>
+struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  /// `floatFunc` names the libm function used for f32 (e.g. "tanhf") and
+  /// `doubleFunc` the one used for f64 (e.g. "tanh").
+  // The constructor is declared with the injected-class-name: the template-id
+  // form `ScalarOpToLibmCall<Op>(...)` is ill-formed in C++20.
+  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
+                     StringRef doubleFunc, PatternBenefit benefit)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc) {}
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+
+private:
+  std::string floatFunc, doubleFunc;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+  // Only 1-D vector results are handled here; scalars (and anything else) are
+  // left for other patterns.
+  auto vectorTy = op.getType().template dyn_cast<VectorType>();
+  if (!vectorTy || !vectorTy.hasRank())
+    return failure();
+  ArrayRef<int64_t> dims = vectorTy.getShape();
+  // TODO: support multidimensional vectors
+  if (dims.size() != 1)
+    return failure();
+
+  Location loc = op.getLoc();
+  Type elemTy = vectorTy.getElementType();
+  // Start from an all-zero vector and overwrite one element per iteration
+  // with the result of the corresponding scalar op.
+  Value acc = rewriter.create<ConstantOp>(
+      loc, DenseElementsAttr::get(vectorTy, FloatAttr::get(elemTy, 0.0)));
+  for (int64_t lane = 0, e = dims.front(); lane < e; ++lane) {
+    SmallVector<Value> laneOperands;
+    for (Value vecOperand : op->getOperands())
+      laneOperands.push_back(
+          rewriter.create<vector::ExtractElementOp>(loc, vecOperand, lane));
+    Value laneResult = rewriter.create<Op>(loc, elemTy, laneOperands);
+    acc = rewriter.create<vector::InsertElementOp>(loc, laneResult, acc, lane);
+  }
+  rewriter.replaceOp(op, {acc});
+  return success();
+}
+
+template <typename Op>
+LogicalResult
+ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
+                                        PatternRewriter &rewriter) const {
+  // The libm declaration is inserted into the enclosing module, so there has
+  // to be one.
+  ModuleOp module = op->template getParentOfType<ModuleOp>();
+  if (!module)
+    return rewriter.notifyMatchFailure(op, "no enclosing module");
+
+  Type type = op.getType();
+  // TODO: Support Float16 by upcasting to Float32
+  if (!type.isa<Float32Type, Float64Type>())
+    return failure();
+
+  // The name members outlive this call, so a StringRef avoids a string copy.
+  StringRef name =
+      type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
+  // The libm function takes exactly the op's operand types and returns the
+  // op's result types.
+  FunctionType funcType = FunctionType::get(
+      rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
+
+  // Reuse an existing declaration only if it is a FuncOp with the expected
+  // signature. Any other symbol with this name would otherwise lead to a
+  // duplicate symbol or a call with the wrong signature, so bail out instead.
+  Operation *existing = module.lookupSymbol(name);
+  auto opFunc = dyn_cast_or_null<FuncOp>(existing);
+  if (existing && (!opFunc || opFunc.getType() != funcType))
+    return rewriter.notifyMatchFailure(
+        op, "existing symbol clashes with libm function");
+
+  // Forward declare function if it hasn't already been
+  if (!opFunc) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    opFunc =
+        rewriter.create<FuncOp>(rewriter.getUnknownLoc(), name, funcType);
+    opFunc.setPrivate();
+  }
+
+  rewriter.replaceOpWithNewOp<CallOp>(op, opFunc, op->getOperands());
+  return success();
+}
+
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+                                                PatternBenefit benefit) {
+  MLIRContext *ctx = patterns.getContext();
+
+  // Vector forms are unrolled into scalar ops, which the call patterns below
+  // then rewrite.
+  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+               VecOpToScalarOp<math::TanhOp>>(ctx, benefit);
+
+  // Scalar forms: the first name is the f32 libm function, the second the f64
+  // one.
+  patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(ctx, "atan2f", "atan2",
+                                                  benefit);
+  patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(ctx, "expm1f", "expm1",
+                                                  benefit);
+  patterns.add<ScalarOpToLibmCall<math::TanhOp>>(ctx, "tanhf", "tanh",
+                                                 benefit);
+}
+
+namespace {
+/// Module pass that applies the Math-to-libm patterns via dialect conversion.
+class ConvertMathToLibmPass
+    : public ConvertMathToLibmBase<ConvertMathToLibmPass> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToLibmPass::runOnOperation() {
+  auto module = getOperation();
+
+  RewritePatternSet patterns(&getContext());
+  populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, StandardOpsDialect,
+                         vector::VectorDialect>();
+  // Only the ops that have a libm lowering are illegal. Marking the whole math
+  // dialect illegal would make the pass fail on any other math op (e.g.
+  // math.exp) that it has no pattern for; those are left untouched instead.
+  target.addIllegalOp<math::Atan2Op, math::ExpM1Op, math::TanhOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+/// Builds a fresh instance of the Math-to-libm module pass.
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToLibmPass() {
+  auto pass = std::make_unique<ConvertMathToLibmPass>();
+  return pass;
+}

diff  --git a/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir b/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir
new file mode 100644
index 0000000000000..7c8d8e7136bb6
--- /dev/null
+++ b/mlir/test/Conversion/MathToLLVM/convert-to-libm.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -convert-math-to-libm -canonicalize | FileCheck %s
+
+// The pass declares every libm function it calls as a private function at the
+// start of the module; the f32 variants carry an `f` suffix. Declaration order
+// is not significant, so the matches below are order-independent.
+// CHECK-DAG: @expm1(f64) -> f64
+// CHECK-DAG: @expm1f(f32) -> f32
+// CHECK-DAG: @atan2(f64, f64) -> f64
+// CHECK-DAG: @atan2f(f32, f32) -> f32
+// CHECK-DAG: @tanh(f64) -> f64
+// CHECK-DAG: @tanhf(f32) -> f32
+
+// Scalar math.tanh on f32 and f64 becomes a call to tanhf and tanh respectively.
+// CHECK-LABEL: func @tanh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @tanh_caller(%float: f32, %double: f64) -> (f32, f64)  {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @tanhf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.tanh %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @tanh(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.tanh %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+
+// math.atan2 is binary: both operands are forwarded to the atan2f/atan2 call.
+// CHECK-LABEL: func @atan2_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32
+  %float_result = math.atan2 %float, %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64
+  %double_result = math.atan2 %double, %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+// Scalar math.expm1 on f32 and f64 becomes a call to expm1f and expm1.
+// CHECK-LABEL: func @expm1_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func @expm1_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @expm1f(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.expm1 %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @expm1(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.expm1 %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+// A 1-D vector op is unrolled: each element is extracted, lowered to a scalar
+// libm call, and inserted into a zero-initialized result vector. The constants
+// are expected first (hoisted by -canonicalize).
+func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+  %float_result = math.expm1 %float : vector<2xf32>
+  %double_result = math.expm1 %double : vector<2xf64>
+  return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
+// CHECK-LABEL:   func @expm1_vec_caller(
+// CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+// CHECK:           %[[CVF:.*]] = constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:           %[[CVD:.*]] = constant dense<0.000000e+00> : vector<2xf64>
+// CHECK:           %[[C0:.*]] = constant 0 : i32
+// CHECK:           %[[C1:.*]] = constant 1 : i32
+// CHECK:           %[[IN0_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C0]] : i32] : vector<2xf32>
+// CHECK:           %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_8:.*]] = vector.insertelement %[[OUT0_F32]], %[[CVF]]{{\[}}%[[C0]] : i32] : vector<2xf32>
+// CHECK:           %[[IN1_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C1]] : i32] : vector<2xf32>
+// CHECK:           %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_11:.*]] = vector.insertelement %[[OUT1_F32]], %[[VAL_8]]{{\[}}%[[C1]] : i32] : vector<2xf32>
+// CHECK:           %[[IN0_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C0]] : i32] : vector<2xf64>
+// CHECK:           %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64
+// CHECK:           %[[VAL_14:.*]] = vector.insertelement %[[OUT0_F64]], %[[CVD]]{{\[}}%[[C0]] : i32] : vector<2xf64>
+// CHECK:           %[[IN1_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK:           %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64
+// CHECK:           %[[VAL_17:.*]] = vector.insertelement %[[OUT1_F64]], %[[VAL_14]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+// CHECK:         }
+


        


More information about the Mlir-commits mailing list