[Mlir-commits] [mlir] 2b529a3 - [mlir] Removed TanhOp lowering from ConvertStandardToLLVM since there is no reasonable tanh representation in LLVM.

Marcel Koester llvmlistbot at llvm.org
Wed Mar 25 08:47:18 PDT 2020


Author: Marcel Koester
Date: 2020-03-25T16:43:45+01:00
New Revision: 2b529a396d73e793615ce2cca525d68f9f738566

URL: https://github.com/llvm/llvm-project/commit/2b529a396d73e793615ce2cca525d68f9f738566
DIFF: https://github.com/llvm/llvm-project/commit/2b529a396d73e793615ce2cca525d68f9f738566.diff

LOG: [mlir] Removed TanhOp lowering from ConvertStandardToLLVM since there is no reasonable tanh representation in LLVM.

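Summary: The current ConvertStandardToLLVM pass lowers the standard `TanhOp` to calls to the external functions `tanhf` (for f32) and `tanh` (for f64). The pass only declares these functions and never defines them. This is misleading, because the resulting symbols are not defined anywhere. This commit makes the following changes:
- It removes the `TanhOp` lowering from ConvertStandardToLLVM and marks `TanhOp` as illegal in `LLVMConversionTarget`.
- It adapts the LowerGpuOpsToNVVMOps and LowerGpuOpsToROCDLOps passes. They no longer need to filter out calls to `tanh`/`tanhf`, so the `gpu::filterIllegalLLVMIntrinsics` helper is removed.
- It adjusts the affected test cases.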

Reviewers: mravishankar, herhut

Subscribers: jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, csigg, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75509

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index eb5628d18e46..c7bbb6d989c8 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -95,24 +95,6 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
   const std::string f64Func;
 };
 
-namespace gpu {
-/// Returns a predicate to be used with addDynamicallyLegalOp. The predicate
-/// returns false for calls to the provided intrinsics and true otherwise.
-inline std::function<bool(Operation *)>
-filterIllegalLLVMIntrinsics(ArrayRef<StringRef> intrinsics, MLIRContext *ctx) {
-  SmallVector<StringRef, 4> illegalIds(intrinsics.begin(), intrinsics.end());
-  return [illegalIds](Operation *op) -> bool {
-    LLVM::CallOp callOp = dyn_cast<LLVM::CallOp>(op);
-    if (!callOp || !callOp.callee())
-      return true;
-    StringRef callee = callOp.callee().getValue();
-    return !llvm::any_of(illegalIds, [callee](StringRef intrinsic) {
-      return callee.equals(intrinsic);
-    });
-  };
-}
-} // namespace gpu
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e929caac6133..18aeeb845b30 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -279,8 +279,6 @@ class LowerGpuOpsToNVVMOpsPass
                         LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>();
     target.addIllegalOp<FuncOp>();
     target.addLegalDialect<NVVM::NVVMDialect>();
-    target.addDynamicallyLegalOp<mlir::LLVM::CallOp>(
-        gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
     // TODO(csigg): Remove once we support replacing non-root ops.
     target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
     if (failed(applyPartialConversion(m, target, patterns, &converter)))

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 238821ec8dc3..79fb3771aff6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -71,8 +71,6 @@ class LowerGpuOpsToROCDLOpsPass
     target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
     target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp,
                         LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>();
-    target.addDynamicallyLegalOp<LLVM::CallOp>(
-        gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
     target.addIllegalOp<FuncOp>();
     if (failed(applyPartialConversion(m, target, patterns, &converter)))
       signalPassFailure();

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index e353e933a8ae..d37a7733a713 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1737,56 +1737,6 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
   }
 };
 
-// A `tanh` is converted into a call to the `tanh` function.
-struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
-  using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    using LLVMFuncOpT = LLVM::LLVMFuncOp;
-    using LLVMTypeT = LLVM::LLVMType;
-
-    OperandAdaptor<TanhOp> transformed(operands);
-    LLVMTypeT operandType =
-        transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
-
-    if (!operandType)
-      return failure();
-
-    std::string functionName;
-    if (operandType.isFloatTy())
-      functionName = "tanhf";
-    else if (operandType.isDoubleTy())
-      functionName = "tanh";
-    else
-      return failure();
-
-    // Get a reference to the tanh function, inserting it if necessary.
-    Operation *tanhFunc =
-        SymbolTable::lookupNearestSymbolFrom(op, functionName);
-
-    LLVMFuncOpT tanhLLVMFunc;
-    if (tanhFunc) {
-      tanhLLVMFunc = cast<LLVMFuncOpT>(tanhFunc);
-    } else {
-      PatternRewriter::InsertionGuard insertGuard(rewriter);
-      auto module = op->getParentOfType<ModuleOp>();
-      rewriter.setInsertionPointToStart(module.getBody());
-      tanhLLVMFunc = rewriter.create<LLVMFuncOpT>(
-          module.getLoc(), functionName,
-          LLVMTypeT::getFunctionTy(operandType, operandType,
-                                   /*isVarArg=*/false));
-    }
-
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
-        op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
-        transformed.operand());
-    return success();
-  }
-};
-
 struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
   using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
 
@@ -2833,7 +2783,6 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       SqrtOpLowering,
       SubFOpLowering,
       SubIOpLowering,
-      TanhOpLowering,
       TruncateIOpLowering,
       UnsignedDivIOpLowering,
       UnsignedRemIOpLowering,
@@ -3022,6 +2971,7 @@ mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
     : ConversionTarget(ctx) {
   this->addLegalDialect<LLVM::LLVMDialect>();
   this->addIllegalOp<LLVM::DialectCastOp>();
+  this->addIllegalOp<TanhOp>();
 }
 
 std::unique_ptr<OpPassBase<ModuleOp>>

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 699ea31836a5..9c072a6e9da0 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -407,43 +407,39 @@ func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
 // CHECK-NEXT:  %2 = llvm.icmp "slt" %arg2, %1 : !llvm.i32
   %2 = cmpi "slt", %arg2, %1 : i32
 // CHECK-NEXT:  %3 = llvm.sdiv %arg2, %arg3 : !llvm.i32
-  %4 = divi_signed %arg2, %arg3 : i32
+  %3 = divi_signed %arg2, %arg3 : i32
 // CHECK-NEXT:  %4 = llvm.udiv %arg2, %arg3 : !llvm.i32
-  %5 = divi_unsigned %arg2, %arg3 : i32
+  %4 = divi_unsigned %arg2, %arg3 : i32
 // CHECK-NEXT:  %5 = llvm.srem %arg2, %arg3 : !llvm.i32
-  %6 = remi_signed %arg2, %arg3 : i32
+  %5 = remi_signed %arg2, %arg3 : i32
 // CHECK-NEXT:  %6 = llvm.urem %arg2, %arg3 : !llvm.i32
-  %7 = remi_unsigned %arg2, %arg3 : i32
+  %6 = remi_unsigned %arg2, %arg3 : i32
 // CHECK-NEXT:  %7 = llvm.select %2, %arg2, %arg3 : !llvm.i1, !llvm.i32
-  %8 = select %2, %arg2, %arg3 : i32
+  %7 = select %2, %arg2, %arg3 : i32
 // CHECK-NEXT:  %8 = llvm.fdiv %arg0, %arg1 : !llvm.float
-  %9 = divf %arg0, %arg1 : f32
+  %8 = divf %arg0, %arg1 : f32
 // CHECK-NEXT:  %9 = llvm.frem %arg0, %arg1 : !llvm.float
-  %10 = remf %arg0, %arg1 : f32
+  %9 = remf %arg0, %arg1 : f32
 // CHECK-NEXT: %10 = llvm.and %arg2, %arg3 : !llvm.i32
-  %11 = and %arg2, %arg3 : i32
+  %10 = and %arg2, %arg3 : i32
 // CHECK-NEXT: %11 = llvm.or %arg2, %arg3 : !llvm.i32
-  %12 = or %arg2, %arg3 : i32
+  %11 = or %arg2, %arg3 : i32
 // CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
-  %13 = xor %arg2, %arg3 : i32
+  %12 = xor %arg2, %arg3 : i32
 // CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
-  %14 = std.exp %arg0 : f32
-// CHECK-NEXT: %14 = llvm.call @tanhf(%arg0) : (!llvm.float) -> !llvm.float
-  %15 = std.tanh %arg0 : f32
-// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
-  %16 = constant 7.9e-01 : f64
-// CHECK-NEXT: %16 = llvm.call @tanh(%15) : (!llvm.double) -> !llvm.double
-  %17 = std.tanh %16 : f64
-// CHECK-NEXT: %17 = llvm.shl %arg2, %arg3 : !llvm.i32
-  %18 = shift_left %arg2, %arg3 : i32
-// CHECK-NEXT: %18 = llvm.ashr %arg2, %arg3 : !llvm.i32
-  %19 = shift_right_signed %arg2, %arg3 : i32
-// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
-  %20 = shift_right_unsigned %arg2, %arg3 : i32
+  %13 = std.exp %arg0 : f32
+// CHECK-NEXT: %14 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
+  %14 = constant 7.9e-01 : f64
+// CHECK-NEXT: %15 = llvm.shl %arg2, %arg3 : !llvm.i32
+  %15 = shift_left %arg2, %arg3 : i32
+// CHECK-NEXT: %16 = llvm.ashr %arg2, %arg3 : !llvm.i32
+  %16 = shift_right_signed %arg2, %arg3 : i32
+// CHECK-NEXT: %17 = llvm.lshr %arg2, %arg3 : !llvm.i32
+  %17 = shift_right_unsigned %arg2, %arg3 : i32
 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
-  %21 = std.sqrt %arg0 : f32
+  %18 = std.sqrt %arg0 : f32
 // CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
-  %22 = std.sqrt %arg4 : f64
+  %19 = std.sqrt %arg4 : f64
   return %0, %4 : f32, i32
 }
 
@@ -853,22 +849,6 @@ func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1)
 
 // -----
 
-module {
-  func @check_tanh_func_added_only_once_to_symbol_table(%f: f32, %lf: f64) -> () {
-    %f0 = std.tanh %f : f32
-    %f1 = std.tanh %f0 : f32
-    %lf0 = std.tanh %lf : f64
-    %lf1 = std.tanh %lf0 : f64
-    return
-  }
-// CHECK: module {
-// CHECK: llvm.func @tanh(!llvm.double) -> !llvm.double
-// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
-// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
-}
-
-// -----
-
 // CHECK-LABEL: func @atomic_rmw
 func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
   atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32


        


More information about the Mlir-commits mailing list