[Mlir-commits] [mlir] 2b529a3 - [mlir] Removed TanHOp lowering from ConvertStandardToLLVM since there is no reasonable TanH representation in LLVM.
Marcel Koester
llvmlistbot at llvm.org
Wed Mar 25 08:47:18 PDT 2020
Author: Marcel Koester
Date: 2020-03-25T16:43:45+01:00
New Revision: 2b529a396d73e793615ce2cca525d68f9f738566
URL: https://github.com/llvm/llvm-project/commit/2b529a396d73e793615ce2cca525d68f9f738566
DIFF: https://github.com/llvm/llvm-project/commit/2b529a396d73e793615ce2cca525d68f9f738566.diff
LOG: [mlir] Removed TanhOp lowering from ConvertStandardToLLVM since there is no reasonable TanH representation in LLVM.
Summary: The current ConvertStandardToLLVM phase lowers the standard TanhOp to calls to the external `tanh` (f64) and `tanhf` (f32) symbols. However, this is misleading, since these external symbols are not defined anywhere. This commit removes the TanhOp lowering from ConvertStandardToLLVM, adapts the LowerGpuOpsToNVVMOps and LowerGpuOpsToROCDLOps passes accordingly, and adjusts the affected test cases.
Reviewers: mravishankar, herhut
Subscribers: jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, csigg, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D75509
Added:
Modified:
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index eb5628d18e46..c7bbb6d989c8 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -95,24 +95,6 @@ struct OpToFuncCallLowering : public ConvertToLLVMPattern {
const std::string f64Func;
};
-namespace gpu {
-/// Returns a predicate to be used with addDynamicallyLegalOp. The predicate
-/// returns false for calls to the provided intrinsics and true otherwise.
-inline std::function<bool(Operation *)>
-filterIllegalLLVMIntrinsics(ArrayRef<StringRef> intrinsics, MLIRContext *ctx) {
- SmallVector<StringRef, 4> illegalIds(intrinsics.begin(), intrinsics.end());
- return [illegalIds](Operation *op) -> bool {
- LLVM::CallOp callOp = dyn_cast<LLVM::CallOp>(op);
- if (!callOp || !callOp.callee())
- return true;
- StringRef callee = callOp.callee().getValue();
- return !llvm::any_of(illegalIds, [callee](StringRef intrinsic) {
- return callee.equals(intrinsic);
- });
- };
-}
-} // namespace gpu
-
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e929caac6133..18aeeb845b30 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -279,8 +279,6 @@ class LowerGpuOpsToNVVMOpsPass
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>();
target.addIllegalOp<FuncOp>();
target.addLegalDialect<NVVM::NVVMDialect>();
- target.addDynamicallyLegalOp<mlir::LLVM::CallOp>(
- gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
// TODO(csigg): Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
if (failed(applyPartialConversion(m, target, patterns, &converter)))
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 238821ec8dc3..79fb3771aff6 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -71,8 +71,6 @@ class LowerGpuOpsToROCDLOpsPass
target.addLegalDialect<LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp,
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op>();
- target.addDynamicallyLegalOp<LLVM::CallOp>(
- gpu::filterIllegalLLVMIntrinsics({"tanh", "tanhf"}, m.getContext()));
target.addIllegalOp<FuncOp>();
if (failed(applyPartialConversion(m, target, patterns, &converter)))
signalPassFailure();
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index e353e933a8ae..d37a7733a713 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1737,56 +1737,6 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
}
};
-// A `tanh` is converted into a call to the `tanh` function.
-struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
- using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
-
- using LLVMFuncOpT = LLVM::LLVMFuncOp;
- using LLVMTypeT = LLVM::LLVMType;
-
- OperandAdaptor<TanhOp> transformed(operands);
- LLVMTypeT operandType =
- transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
-
- if (!operandType)
- return failure();
-
- std::string functionName;
- if (operandType.isFloatTy())
- functionName = "tanhf";
- else if (operandType.isDoubleTy())
- functionName = "tanh";
- else
- return failure();
-
- // Get a reference to the tanh function, inserting it if necessary.
- Operation *tanhFunc =
- SymbolTable::lookupNearestSymbolFrom(op, functionName);
-
- LLVMFuncOpT tanhLLVMFunc;
- if (tanhFunc) {
- tanhLLVMFunc = cast<LLVMFuncOpT>(tanhFunc);
- } else {
- PatternRewriter::InsertionGuard insertGuard(rewriter);
- auto module = op->getParentOfType<ModuleOp>();
- rewriter.setInsertionPointToStart(module.getBody());
- tanhLLVMFunc = rewriter.create<LLVMFuncOpT>(
- module.getLoc(), functionName,
- LLVMTypeT::getFunctionTy(operandType, operandType,
- /*isVarArg=*/false));
- }
-
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(
- op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
- transformed.operand());
- return success();
- }
-};
-
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
@@ -2833,7 +2783,6 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
SqrtOpLowering,
SubFOpLowering,
SubIOpLowering,
- TanhOpLowering,
TruncateIOpLowering,
UnsignedDivIOpLowering,
UnsignedRemIOpLowering,
@@ -3022,6 +2971,7 @@ mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
this->addLegalDialect<LLVM::LLVMDialect>();
this->addIllegalOp<LLVM::DialectCastOp>();
+ this->addIllegalOp<TanhOp>();
}
std::unique_ptr<OpPassBase<ModuleOp>>
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 699ea31836a5..9c072a6e9da0 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -407,43 +407,39 @@ func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
// CHECK-NEXT: %2 = llvm.icmp "slt" %arg2, %1 : !llvm.i32
%2 = cmpi "slt", %arg2, %1 : i32
// CHECK-NEXT: %3 = llvm.sdiv %arg2, %arg3 : !llvm.i32
- %4 = divi_signed %arg2, %arg3 : i32
+ %3 = divi_signed %arg2, %arg3 : i32
// CHECK-NEXT: %4 = llvm.udiv %arg2, %arg3 : !llvm.i32
- %5 = divi_unsigned %arg2, %arg3 : i32
+ %4 = divi_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %5 = llvm.srem %arg2, %arg3 : !llvm.i32
- %6 = remi_signed %arg2, %arg3 : i32
+ %5 = remi_signed %arg2, %arg3 : i32
// CHECK-NEXT: %6 = llvm.urem %arg2, %arg3 : !llvm.i32
- %7 = remi_unsigned %arg2, %arg3 : i32
+ %6 = remi_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %7 = llvm.select %2, %arg2, %arg3 : !llvm.i1, !llvm.i32
- %8 = select %2, %arg2, %arg3 : i32
+ %7 = select %2, %arg2, %arg3 : i32
// CHECK-NEXT: %8 = llvm.fdiv %arg0, %arg1 : !llvm.float
- %9 = divf %arg0, %arg1 : f32
+ %8 = divf %arg0, %arg1 : f32
// CHECK-NEXT: %9 = llvm.frem %arg0, %arg1 : !llvm.float
- %10 = remf %arg0, %arg1 : f32
+ %9 = remf %arg0, %arg1 : f32
// CHECK-NEXT: %10 = llvm.and %arg2, %arg3 : !llvm.i32
- %11 = and %arg2, %arg3 : i32
+ %10 = and %arg2, %arg3 : i32
// CHECK-NEXT: %11 = llvm.or %arg2, %arg3 : !llvm.i32
- %12 = or %arg2, %arg3 : i32
+ %11 = or %arg2, %arg3 : i32
// CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
- %13 = xor %arg2, %arg3 : i32
+ %12 = xor %arg2, %arg3 : i32
// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
- %14 = std.exp %arg0 : f32
-// CHECK-NEXT: %14 = llvm.call @tanhf(%arg0) : (!llvm.float) -> !llvm.float
- %15 = std.tanh %arg0 : f32
-// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
- %16 = constant 7.9e-01 : f64
-// CHECK-NEXT: %16 = llvm.call @tanh(%15) : (!llvm.double) -> !llvm.double
- %17 = std.tanh %16 : f64
-// CHECK-NEXT: %17 = llvm.shl %arg2, %arg3 : !llvm.i32
- %18 = shift_left %arg2, %arg3 : i32
-// CHECK-NEXT: %18 = llvm.ashr %arg2, %arg3 : !llvm.i32
- %19 = shift_right_signed %arg2, %arg3 : i32
-// CHECK-NEXT: %19 = llvm.lshr %arg2, %arg3 : !llvm.i32
- %20 = shift_right_unsigned %arg2, %arg3 : i32
+ %13 = std.exp %arg0 : f32
+// CHECK-NEXT: %14 = llvm.mlir.constant(7.900000e-01 : f64) : !llvm.double
+ %14 = constant 7.9e-01 : f64
+// CHECK-NEXT: %15 = llvm.shl %arg2, %arg3 : !llvm.i32
+ %15 = shift_left %arg2, %arg3 : i32
+// CHECK-NEXT: %16 = llvm.ashr %arg2, %arg3 : !llvm.i32
+ %16 = shift_right_signed %arg2, %arg3 : i32
+// CHECK-NEXT: %17 = llvm.lshr %arg2, %arg3 : !llvm.i32
+ %17 = shift_right_unsigned %arg2, %arg3 : i32
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
- %21 = std.sqrt %arg0 : f32
+ %18 = std.sqrt %arg0 : f32
// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (!llvm.double) -> !llvm.double
- %22 = std.sqrt %arg4 : f64
+ %19 = std.sqrt %arg4 : f64
return %0, %4 : f32, i32
}
@@ -853,22 +849,6 @@ func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1)
// -----
-module {
- func @check_tanh_func_added_only_once_to_symbol_table(%f: f32, %lf: f64) -> () {
- %f0 = std.tanh %f : f32
- %f1 = std.tanh %f0 : f32
- %lf0 = std.tanh %lf : f64
- %lf1 = std.tanh %lf0 : f64
- return
- }
-// CHECK: module {
-// CHECK: llvm.func @tanh(!llvm.double) -> !llvm.double
-// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
-// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
-}
-
-// -----
-
// CHECK-LABEL: func @atomic_rmw
func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
More information about the Mlir-commits
mailing list