[Mlir-commits] [mlir] 93537fa - [mlir] Add lowering from math.expm1 to LLVM.
Adrian Kuegel
llvmlistbot at llvm.org
Tue May 4 05:22:40 PDT 2021
Author: Adrian Kuegel
Date: 2021-05-04T14:22:10+02:00
New Revision: 93537fabcee8fcfa3316d7abd5e935f7fe9b468f
URL: https://github.com/llvm/llvm-project/commit/93537fabcee8fcfa3316d7abd5e935f7fe9b468f
DIFF: https://github.com/llvm/llvm-project/commit/93537fabcee8fcfa3316d7abd5e935f7fe9b468f.diff
LOG: [mlir] Add lowering from math.expm1 to LLVM.
Differential Revision: https://reviews.llvm.org/D96776
Added:
Modified:
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 8df65042e154..5f94804b6252 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2352,6 +2352,60 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
}
};
+// An `expm1` is converted into `exp(x) - 1`.
+//
+// Scalars and 1-D vectors (already LLVM-compatible after type conversion) are
+// rewritten in place. n-D vectors, whose converted operand is an LLVM array of
+// 1-D vectors, are unrolled via `handleMultidimensionalVectors`.
+//
+// NOTE(review): this direct expansion loses precision when exp(x) is close to
+// 1 (|x| near 0) due to cancellation in the subtraction, which is exactly the
+// regime math.expm1 is meant to handle accurately. TODO: consider a
+// numerically stable expansion.
+struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
+  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::ExpM1Op::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    // Bail out unless the converted operand has an LLVM-compatible type.
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return failure();
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    // Scalar or 1-D vector: emit `exp(x) - 1` directly on the converted type.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp one;
+      if (LLVM::isCompatibleVectorType(operandType)) {
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto exp = rewriter.create<LLVM::ExpOp>(loc, operandType,
+                                              transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
+      return success();
+    }
+
+    // n-D vector: the operand was converted to an LLVM array, so the original
+    // result must be a vector for the per-1-D-slice unrolling below.
+    if (!resultType.isa<VectorType>())
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange sliceOperands) {
+          // Splat of 1.0 sized to the current 1-D slice.
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto exp = rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy,
+                                                  sliceOperands[0]);
+          return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
+        },
+        rewriter);
+  }
+};
+
// A `log1p` is converted into `log(1 + ...)`.
struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
@@ -3924,6 +3978,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
DivFOpLowering,
ExpOpLowering,
Exp2OpLowering,
+ ExpM1OpLowering,
FloorFOpLowering,
FmaFOpLowering,
GenericAtomicRMWOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 4a2983b44eb9..a743e31971a3 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -37,6 +37,18 @@ func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @expm1(
+// CHECK-SAME: f32
+func @expm1(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (f32) -> f32
+  // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
+  %0 = math.expm1 %arg0 : f32
+  std.return
+}
+
+// -----
+
+// Exercises the 1-D vector path, which materializes the 1.0 as a splat.
+// CHECK-LABEL: func @expm1_vector(
+// CHECK-SAME: vector<4xf32>
+func @expm1_vector(%arg0 : vector<4xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+  // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (vector<4xf32>) -> vector<4xf32>
+  // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<4xf32>
+  %0 = math.expm1 %arg0 : vector<4xf32>
+  std.return
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt(
// CHECK-SAME: f32
func @rsqrt(%arg0 : f32) {
More information about the Mlir-commits mailing list