[Mlir-commits] [mlir] 74f6138 - [mlir] Add lowering from math::Log1p to LLVM

Eugene Zhulenev llvmlistbot at llvm.org
Tue Mar 16 15:59:16 PDT 2021


Author: Eugene Zhulenev
Date: 2021-03-16T15:59:09-07:00
New Revision: 74f6138bd98f480be2bd39d8ecc2cf66089739c3

URL: https://github.com/llvm/llvm-project/commit/74f6138bd98f480be2bd39d8ecc2cf66089739c3
DIFF: https://github.com/llvm/llvm-project/commit/74f6138bd98f480be2bd39d8ecc2cf66089739c3.diff

LOG: [mlir] Add lowering from math::Log1p to LLVM

[mlir] Add lowering from math::Log1p to LLVM

Reviewed By: cota

Differential Revision: https://reviews.llvm.org/D98662

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index b3a2bb634f39..de1df34eaa5d 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2303,6 +2303,61 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
   }
 };
 
+// A `log1p` is converted into `log(1 + ...)`.
+//
+// Scalar and 1-D vector operands are rewritten in place into an
+// `llvm.fadd(1.0, x)` followed by `llvm.intr.log`. Operands that the type
+// converter turned into an LLVM array (i.e. n-D vectors) are handed to
+// `handleMultidimensionalVectors`, whose callback emits the same expansion for
+// one 1-D vector type (`llvm1DVectorTy`) at a time.
+//
+// NOTE(review): evaluating `log(1 + x)` literally rounds `1 + x` first, so it
+// can lose precision when |x| is much smaller than 1 compared to a true log1p.
+struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
+  using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::Log1pOp::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    // Scalars and 1-D vectors: emit log(1 + x) directly on the operand.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp one =
+          LLVM::isCompatibleVectorType(operandType)
+              ? rewriter.create<LLVM::ConstantOp>(
+                    loc, operandType,
+                    SplatElementsAttr::get(resultType.cast<ShapedType>(),
+                                           floatOne))
+              : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+
+      auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
+                                               transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
+      return success();
+    }
+
+    // An LLVM array operand is only expected for n-D vector results.
+    if (!resultType.isa<VectorType>())
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange sliceOperands) {
+          // Build the expansion on the 1-D slice passed to this callback.
+          // Using `transformed.operand()` here would feed the whole array
+          // value into an add typed as `llvm1DVectorTy`, producing invalid IR.
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
+                                                   sliceOperands[0]);
+          return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
+        },
+        rewriter);
+  }
+};
+
 // A `rsqrt` is converted into `1 / sqrt`.
 struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
   using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
@@ -3788,6 +3843,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       GenericAtomicRMWOpLowering,
       LogOpLowering,
       Log10OpLowering,
+      Log1pOpLowering,
       Log2OpLowering,
       FPExtLowering,
       FPToSILowering,

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index fcb5b1c8a5a2..5eca81dcad00 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -12,6 +12,18 @@ func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 
 // -----
 
+// Scalar f32 `math.log1p` is expected to lower to an f32 constant 1.0, an
+// `llvm.fadd` of that constant with the argument, and an `llvm.intr.log` of
+// the sum. Only the scalar path is exercised here.
+// CHECK-LABEL: func @log1p(
+// CHECK-SAME: f32
+func @log1p(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+  // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 : f32
+  // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (f32) -> f32
+  %0 = math.log1p %arg0 : f32
+  std.return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt(
 // CHECK-SAME: f32
 func @rsqrt(%arg0 : f32) {


        


More information about the Mlir-commits mailing list