[Mlir-commits] [mlir] 9cb1029 - [mlir] Add support for lowering tanh to LLVMIR.

Hanhan Wang llvmlistbot at llvm.org
Thu Jun 18 10:42:26 PDT 2020


Author: Hanhan Wang
Date: 2020-06-18T10:42:13-07:00
New Revision: 9cb10296ecaa2d7131744375e8b14200674fa1e5

URL: https://github.com/llvm/llvm-project/commit/9cb10296ecaa2d7131744375e8b14200674fa1e5
DIFF: https://github.com/llvm/llvm-project/commit/9cb10296ecaa2d7131744375e8b14200674fa1e5.diff

LOG: [mlir] Add support for lowering tanh to LLVMIR.

Summary:
Fixed build of D81618

Add a pattern for expanding tanh op into exp form.
A `tanh` is expanded into:
   1) (1 - exp(-2x)) / (1 + exp(-2x)), if x >= 0
   2) (exp(2x) - 1) / (exp(2x) + 1),   if x < 0.

Differential Revision: https://reviews.llvm.org/D82040

Added: 
    mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp
    mlir/test/Dialect/Standard/expand-tanh.mlir
    mlir/test/lib/Transforms/TestExpandTanh.cpp

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index c0622e529564..aadc41d2790d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -20,10 +20,15 @@
 namespace mlir {
 
 class Pass;
+class MLIRContext;
+class OwningRewritePatternList;
 
 /// Creates an instance of the ExpandAtomic pass.
 std::unique_ptr<Pass> createExpandAtomicPass();
 
+/// Appends a pattern that expands `tanh` into exp, arithmetic and select ops.
+void populateExpandTanhPattern(OwningRewritePatternList &patterns,
+                               MLIRContext *ctx);
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index 0e2ef2dcc36c..299fc2bd3ccd 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
   ExpandAtomic.cpp
+  ExpandTanh.cpp
   FuncConversions.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp
new file mode 100644
index 000000000000..48cfc4787541
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandTanh.cpp
@@ -0,0 +1,70 @@
+//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of tanh op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Expands scalar float tanh into exp form, selecting the overflow-safe branch:
+///   1) (1 - exp(-2x)) / (1 + exp(-2x)), if x >= 0
+///   2) (exp(2x) - 1) / (exp(2x) + 1),   if x < 0
+struct TanhOpConverter : public OpRewritePattern<TanhOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TanhOp op,
+                                PatternRewriter &rewriter) const final {
+    Type floatType = op.operand().getType();
+    // getFloatAttr() needs a scalar FloatType; skip vector/tensor operands.
+    if (!floatType.isa<FloatType>())
+      return failure();
+    Location loc = op.getLoc();
+    auto attr = [&](double v) { return rewriter.getFloatAttr(floatType, v); };
+    Value one = rewriter.create<ConstantOp>(loc, attr(1.0));
+    Value two = rewriter.create<ConstantOp>(loc, attr(2.0));
+    Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
+
+    // Case 1 (x >= 0): tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
+    Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
+    Value expNeg2x = rewriter.create<ExpOp>(loc, negDoubledX);
+    Value dividend = rewriter.create<SubFOp>(loc, one, expNeg2x);
+    Value divisor = rewriter.create<AddFOp>(loc, one, expNeg2x);
+    Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+    // Case 2 (x < 0): tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
+    Value exp2x = rewriter.create<ExpOp>(loc, doubledX);
+    dividend = rewriter.create<SubFOp>(loc, exp2x, one);
+    divisor = rewriter.create<AddFOp>(loc, exp2x, one);
+    Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
+
+    // tanh(x) = x >= 0 ? positiveRes : negativeRes (NaN fails OGE -> case 2)
+    Value zero = rewriter.create<ConstantOp>(loc, attr(0.0));
+    Value cmpRes =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the scalar tanh -> exp expansion pattern to `patterns`.
+void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns,
+                                     MLIRContext *ctx) {
+  patterns.insert<TanhOpConverter>(ctx);
+}

diff  --git a/mlir/test/Dialect/Standard/expand-tanh.mlir b/mlir/test/Dialect/Standard/expand-tanh.mlir
new file mode 100644
index 000000000000..557d1d0a808a
--- /dev/null
+++ b/mlir/test/Dialect/Standard/expand-tanh.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -test-expand-tanh | FileCheck %s
+
+// CHECK-LABEL: func @tanh
+func @tanh(%x: f32) -> f32 {
+  %y = tanh %x : f32
+  return %y : f32
+}
+// CHECK-DAG: %[[ZERO:.+]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[ONE:.+]] = constant 1.000000e+00 : f32
+// CHECK-DAG: %[[TWO:.+]] = constant 2.000000e+00 : f32
+// CHECK: %[[DOUBLEDX:.+]] = mulf %arg0, %[[TWO]] : f32
+// CHECK: %[[NEGDOUBLEDX:.+]] = negf %[[DOUBLEDX]] : f32
+// CHECK: %[[EXP1:.+]] = exp %[[NEGDOUBLEDX]] : f32
+// CHECK: %[[DIVIDEND1:.+]] = subf %[[ONE]], %[[EXP1]] : f32
+// CHECK: %[[DIVISOR1:.+]] = addf %[[ONE]], %[[EXP1]] : f32
+// CHECK: %[[RES1:.+]] = divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
+// CHECK: %[[EXP2:.+]] = exp %[[DOUBLEDX]] : f32
+// CHECK: %[[DIVIDEND2:.+]] = subf %[[EXP2]], %[[ONE]] : f32
+// CHECK: %[[DIVISOR2:.+]] = addf %[[EXP2]], %[[ONE]] : f32
+// CHECK: %[[RES2:.+]] = divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
+// CHECK: %[[COND:.+]] = cmpf "oge", %arg0, %[[ZERO]] : f32
+// CHECK: %[[RESULT:.+]] = select %[[COND]], %[[RES1]], %[[RES2]] : f32
+// CHECK: return %[[RESULT]]

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 044270276c11..db864999a440 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRTestTransforms
   TestAllReduceLowering.cpp
   TestBufferPlacement.cpp
+  TestExpandTanh.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp
   TestConvertGPUKernelToCubin.cpp

diff  --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp
new file mode 100644
index 000000000000..c5f6e3a5ce30
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp
@@ -0,0 +1,37 @@
+//===- TestExpandTanh.cpp - Test expanding tanh op into exp form ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for expanding tanh.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies the tanh expansion pattern to a function.
+struct TestExpandTanhPass : PassWrapper<TestExpandTanhPass, FunctionPass> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestExpandTanhPass::runOnFunction() {
+  OwningRewritePatternList tanhPatterns;
+  populateExpandTanhPattern(tanhPatterns, &getContext());
+  (void)applyPatternsAndFoldGreedily(getOperation(), tanhPatterns);
+}
+
+namespace mlir {
+void registerTestExpandTanhPass() {
+  PassRegistration<TestExpandTanhPass> registration(
+      "test-expand-tanh", "Test expanding tanh");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 067a2156c6fb..4515094eb105 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -48,6 +48,7 @@ void registerTestConstantFold();
 void registerTestConvertGPUKernelToCubinPass();
 void registerTestConvertGPUKernelToHsacoPass();
 void registerTestDominancePass();
+void registerTestExpandTanhPass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestLinalgHoisting();
@@ -122,6 +123,7 @@ void registerTestPasses() {
   registerTestBufferPlacementPreparationPass();
   registerTestDominancePass();
   registerTestFunc();
+  registerTestExpandTanhPass();
   registerTestGpuMemoryPromotionPass();
   registerTestLinalgHoisting();
   registerTestLinalgTransforms();


        


More information about the Mlir-commits mailing list