[Mlir-commits] [mlir] 26e59cc - [mlir] factor math-to-llvm out of standard-to-llvm
Alex Zinenko
llvmlistbot at llvm.org
Mon Jul 12 02:09:49 PDT 2021
Author: Alex Zinenko
Date: 2021-07-12T11:09:42+02:00
New Revision: 26e59cc19f8646ca59e2d699882c611980f2b563
URL: https://github.com/llvm/llvm-project/commit/26e59cc19f8646ca59e2d699882c611980f2b563
DIFF: https://github.com/llvm/llvm-project/commit/26e59cc19f8646ca59e2d699882c611980f2b563.diff
LOG: [mlir] factor math-to-llvm out of standard-to-llvm
After the Math dialect was split out of the Standard dialect, the
conversion to the LLVM dialect remained a single huge monolithic pass.
This is undesirable for the same complexity-management reasons as having
a huge Standard dialect itself, and is even more confusing given the
existence of a separate Math dialect. Extract the conversion of the Math
dialect operations to LLVM into a separate library and a separate
conversion pass.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D105702
Added:
mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
new file mode 100644
index 0000000000000..d03bc29292693
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -0,0 +1,26 @@
+//===- MathToLLVM.h - Math to LLVM dialect conversion -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
+#define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+/// Populates `patterns` with the patterns that lower Math dialect operations
+/// to the LLVM dialect. The patterns are constructed with `converter`.
+void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+
+/// Creates the pass registered as `convert-math-to-llvm`, which converts
+/// supported Math dialect operations to the LLVM dialect.
+std::unique_ptr<Pass> createConvertMathToLLVMPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 9b3b9701ce6c9..dc4f19036c940 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -22,6 +22,7 @@
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 9fa437a895b12..2aa92f31ab093 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -255,6 +255,19 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
}
+//===----------------------------------------------------------------------===//
+// MathToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToLLVM : FunctionPass<"convert-math-to-llvm"> {
+ let summary = "Convert Math dialect to LLVM dialect";
+ let description = [{
+ This pass converts supported Math ops to LLVM dialect intrinsics.
+ }];
+ let constructor = "mlir::createConvertMathToLLVMPass()";
+ let dependentDialects = ["LLVM::LLVMDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MemRefToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 18c4850fd5ac0..07908183cdc19 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -13,6 +13,7 @@ add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToLibm)
+add_subdirectory(MathToLLVM)
add_subdirectory(MemRefToLLVM)
add_subdirectory(OpenACCToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..fe1c172f5e4bd
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMathToLLVM
+ MathToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRLLVMCommonConversion
+ MLIRLLVMIR
+ MLIRMath
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
new file mode 100644
index 0000000000000..d15377e5c7fda
--- /dev/null
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -0,0 +1,234 @@
+//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "../PassDetail.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+namespace {
+// Direct lowerings: each alias instantiates VectorConvertToLLVMPattern with a
+// Math dialect op and the LLVM dialect op that replaces it. All pairs share
+// the same op name except `math::PowFOp`, which is paired with `LLVM::PowOp`.
+using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
+using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
+using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
+using Log10OpLowering =
+ VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
+using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
+using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
+using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
+using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
+using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+
+// A `expm1` is converted into `exp - 1`.
+//
+// The rewrite emits, in this order, a constant one, an `LLVM::ExpOp` of the
+// converted operand and an `LLVM::FSubOp` of the two. Operands whose converted
+// type is an LLVM array are dispatched to `handleMultidimensionalVectors`,
+// which is given a callback that emits the same three ops for a 1-D vector
+// type. Every rejection is reported through `notifyMatchFailure` so that the
+// reason is available to the conversion driver.
+struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
+  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::ExpM1Op::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    // Same diagnostic as Log1pOpLowering instead of a silent failure.
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    // Anything that is not an LLVM array is rewritten in place.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp one;
+      if (LLVM::isCompatibleVectorType(operandType)) {
+        // Vector operand: the constant is a splat shaped like the result.
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
+      return success();
+    }
+
+    // The array path requires the original result to be a vector.
+    if (!resultType.isa<VectorType>())
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          // Splat of one with as many elements as the 1-D LLVM vector type.
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto exp =
+              rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
+          return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
+        },
+        rewriter);
+  }
+};
+
+// Expands `math.log1p` into `log(1 + x)`: a constant one, an `LLVM::FAddOp`
+// of that constant with the converted operand, and an `LLVM::LogOp` of the
+// sum, created in that order. Operands converted to an LLVM array type are
+// forwarded to `handleMultidimensionalVectors` with a callback that builds
+// the same sequence for a 1-D vector type.
+struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
+  using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::Log1pOp::Adaptor adaptor(operands);
+    Type convertedTy = adaptor.operand().getType();
+    if (!convertedTy || !LLVM::isCompatibleType(convertedTy))
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+    auto loc = op.getLoc();
+    Type sourceResultTy = op.getResult().getType();
+    FloatType elementTy =
+        getElementTypeOrSelf(sourceResultTy).cast<FloatType>();
+    auto oneAttr = rewriter.getFloatAttr(elementTy, 1.0);
+
+    // LLVM array operand: build the expansion once per 1-D vector type.
+    if (convertedTy.isa<LLVM::LLVMArrayType>()) {
+      if (!sourceResultTy.isa<VectorType>())
+        return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+      auto buildFor1DVector = [&](Type llvm1DVectorTy, ValueRange slice) {
+        auto numElements =
+            LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue();
+        auto splatOne = SplatElementsAttr::get(
+            mlir::VectorType::get({numElements}, elementTy), oneAttr);
+        auto cst =
+            rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatOne);
+        auto sum =
+            rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, cst, slice[0]);
+        return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, sum);
+      };
+      return LLVM::detail::handleMultidimensionalVectors(
+          op.getOperation(), operands, *getTypeConverter(), buildFor1DVector,
+          rewriter);
+    }
+
+    // Any other operand type: rewrite the op in place.
+    LLVM::ConstantOp cst;
+    if (LLVM::isCompatibleVectorType(convertedTy)) {
+      cst = rewriter.create<LLVM::ConstantOp>(
+          loc, convertedTy,
+          SplatElementsAttr::get(sourceResultTy.cast<ShapedType>(), oneAttr));
+    } else {
+      cst = rewriter.create<LLVM::ConstantOp>(loc, convertedTy, oneAttr);
+    }
+    auto sum = rewriter.create<LLVM::FAddOp>(loc, convertedTy, cst,
+                                             adaptor.operand());
+    rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, convertedTy, sum);
+    return success();
+  }
+};
+
+// A `rsqrt` is converted into `1 / sqrt`.
+//
+// The rewrite emits, in this order, a constant one, an `LLVM::SqrtOp` of the
+// converted operand and an `LLVM::FDivOp` of the constant by the square root.
+// Operands whose converted type is an LLVM array are dispatched to
+// `handleMultidimensionalVectors`, which is given a callback that emits the
+// same three ops for a 1-D vector type. Every rejection is reported through
+// `notifyMatchFailure`, matching the other expansion patterns in this file.
+struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
+  using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    math::RsqrtOp::Adaptor transformed(operands);
+    auto operandType = transformed.operand().getType();
+
+    // Same diagnostic as Log1pOpLowering instead of a silent failure.
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+    // Anything that is not an LLVM array is rewritten in place.
+    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp one;
+      if (LLVM::isCompatibleVectorType(operandType)) {
+        // Vector operand: the constant is a splat shaped like the result.
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
+      return success();
+    }
+
+    // The array path requires the original result to be a vector.
+    if (!resultType.isa<VectorType>())
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), operands, *getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          // Splat of one with as many elements as the 1-D LLVM vector type.
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
+          auto sqrt =
+              rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
+          return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
+        },
+        rewriter);
+  }
+};
+
+// Function pass that applies the Math-to-LLVM patterns as a partial
+// conversion, keeping `LLVM::DialectCastOp` legal. The pass is marked as
+// failed when the conversion fails.
+struct ConvertMathToLLVMPass
+    : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
+  ConvertMathToLLVMPass() = default;
+
+  void runOnFunction() override {
+    auto &context = getContext();
+
+    RewritePatternSet mathPatterns(&context);
+    LLVMTypeConverter typeConverter(&context);
+    populateMathToLLVMConversionPatterns(typeConverter, mathPatterns);
+
+    LLVMConversionTarget llvmTarget(context);
+    llvmTarget.addLegalOp<LLVM::DialectCastOp>();
+
+    LogicalResult status = applyPartialConversion(getFunction(), llvmTarget,
+                                                  std::move(mathPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+// Registers every Math-to-LLVM lowering pattern defined in this file with
+// `patterns`; each pattern is constructed from `converter`.
+void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ // clang-format is switched off so the list keeps one pattern per line.
+ // clang-format off
+ patterns.add<
+ CosOpLowering,
+ ExpOpLowering,
+ Exp2OpLowering,
+ ExpM1OpLowering,
+ Log10OpLowering,
+ Log1pOpLowering,
+ Log2OpLowering,
+ LogOpLowering,
+ PowFOpLowering,
+ RsqrtOpLowering,
+ SinOpLowering,
+ SqrtOpLowering
+ >(converter);
+ // clang-format on
+}
+
+// Factory for the Math-to-LLVM pass: returns a newly allocated
+// ConvertMathToLLVMPass owned by the caller.
+std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
+ return std::make_unique<ConvertMathToLLVMPass>();
+}
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 70b7adb9c7ddb..6be363a683b90 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -373,25 +373,17 @@ using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
using CopySignOpLowering =
VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
-using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
-using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
-using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
using FPTruncOpLowering = VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
-using Log10OpLowering =
- VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
-using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
-using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
using MulIOpLowering = VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp>;
using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
-using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
using SIToFPOpLowering = VectorConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp>;
using SelectOpLowering = VectorConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
@@ -405,8 +397,6 @@ using SignedRemIOpLowering =
VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp>;
using SignedShiftRightOpLowering =
OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
-using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
-using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
using TruncateIOpLowering = VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
@@ -656,169 +646,6 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
using Super::Super;
};
-// A `expm1` is converted into `exp - 1`.
-struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
- using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(math::ExpM1Op op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- math::ExpM1Op::Adaptor transformed(operands);
- auto operandType = transformed.operand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
- return failure();
-
- auto loc = op.getLoc();
- auto resultType = op.getResult().getType();
- auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
- auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
-
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
- LLVM::ConstantOp one;
- if (LLVM::isCompatibleVectorType(operandType)) {
- one = rewriter.create<LLVM::ConstantOp>(
- loc, operandType,
- SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
- } else {
- one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
- }
- auto exp = rewriter.create<LLVM::ExpOp>(loc, transformed.operand());
- rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
- return success();
- }
-
- auto vectorType = resultType.dyn_cast<VectorType>();
- if (!vectorType)
- return rewriter.notifyMatchFailure(op, "expected vector result type");
-
- return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), operands, *getTypeConverter(),
- [&](Type llvm1DVectorTy, ValueRange operands) {
- auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
- floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto exp =
- rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
- return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
- },
- rewriter);
- }
-};
-
-// A `log1p` is converted into `log(1 + ...)`.
-struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
- using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(math::Log1pOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- math::Log1pOp::Adaptor transformed(operands);
- auto operandType = transformed.operand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
- return rewriter.notifyMatchFailure(op, "unsupported operand type");
-
- auto loc = op.getLoc();
- auto resultType = op.getResult().getType();
- auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
- auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
-
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
- LLVM::ConstantOp one =
- LLVM::isCompatibleVectorType(operandType)
- ? rewriter.create<LLVM::ConstantOp>(
- loc, operandType,
- SplatElementsAttr::get(resultType.cast<ShapedType>(),
- floatOne))
- : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
-
- auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
- transformed.operand());
- rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
- return success();
- }
-
- auto vectorType = resultType.dyn_cast<VectorType>();
- if (!vectorType)
- return rewriter.notifyMatchFailure(op, "expected vector result type");
-
- return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), operands, *getTypeConverter(),
- [&](Type llvm1DVectorTy, ValueRange operands) {
- auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
- floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
- operands[0]);
- return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
- },
- rewriter);
- }
-};
-
-// A `rsqrt` is converted into `1 / sqrt`.
-struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
- using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(math::RsqrtOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- math::RsqrtOp::Adaptor transformed(operands);
- auto operandType = transformed.operand().getType();
-
- if (!operandType || !LLVM::isCompatibleType(operandType))
- return failure();
-
- auto loc = op.getLoc();
- auto resultType = op.getResult().getType();
- auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
- auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
-
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
- LLVM::ConstantOp one;
- if (LLVM::isCompatibleVectorType(operandType)) {
- one = rewriter.create<LLVM::ConstantOp>(
- loc, operandType,
- SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
- } else {
- one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
- }
- auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
- rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
- return success();
- }
-
- auto vectorType = resultType.dyn_cast<VectorType>();
- if (!vectorType)
- return failure();
-
- return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), operands, *getTypeConverter(),
- [&](Type llvm1DVectorTy, ValueRange operands) {
- auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
- floatOne);
- auto one =
- rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
- auto sqrt =
- rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
- return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
- },
- rewriter);
- }
-};
-
struct DialectCastOpLowering
: public ConvertOpToLLVMPattern<LLVM::DialectCastOp> {
using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern;
@@ -1375,20 +1202,12 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
CmpIOpLowering,
CondBranchOpLowering,
CopySignOpLowering,
- CosOpLowering,
ConstantOpLowering,
DialectCastOpLowering,
DivFOpLowering,
- ExpOpLowering,
- Exp2OpLowering,
- ExpM1OpLowering,
FloorFOpLowering,
FmaFOpLowering,
GenericAtomicRMWOpLowering,
- LogOpLowering,
- Log10OpLowering,
- Log1pOpLowering,
- Log2OpLowering,
FPExtOpLowering,
FPToSIOpLowering,
FPToUIOpLowering,
@@ -1398,11 +1217,9 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
MulIOpLowering,
NegFOpLowering,
OrOpLowering,
- PowFOpLowering,
RemFOpLowering,
RankOpLowering,
ReturnOpLowering,
- RsqrtOpLowering,
SIToFPOpLowering,
SelectOpLowering,
ShiftLeftOpLowering,
@@ -1410,10 +1227,8 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
SignedDivIOpLowering,
SignedRemIOpLowering,
SignedShiftRightOpLowering,
- SinOpLowering,
SplatOpLowering,
SplatNdOpLowering,
- SqrtOpLowering,
SubFOpLowering,
SubIOpLowering,
SwitchOpLowering,
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 6fa090674e6fc..15f6a314e5668 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-complex-to-standard -convert-complex-to-llvm -convert-std-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-complex-to-standard -convert-complex-to-llvm -convert-math-to-llvm -convert-std-to-llvm | FileCheck %s
// CHECK-LABEL: llvm.func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: ![[C_TY:.*]])
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
new file mode 100644
index 0000000000000..3eeaea236bc5c
--- /dev/null
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir-opt %s -split-input-file -convert-math-to-llvm | FileCheck %s
+
+// CHECK-LABEL: @ops
+func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
+// CHECK: = "llvm.intr.exp"(%{{.*}}) : (f32) -> f32
+ %13 = math.exp %arg0 : f32
+// CHECK: = "llvm.intr.exp2"(%{{.*}}) : (f32) -> f32
+ %14 = math.exp2 %arg0 : f32
+// CHECK: = "llvm.intr.sqrt"(%{{.*}}) : (f32) -> f32
+ %19 = math.sqrt %arg0 : f32
+// CHECK: = "llvm.intr.sqrt"(%{{.*}}) : (f64) -> f64
+ %20 = math.sqrt %arg4 : f64
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @log1p(
+// CHECK-SAME: f32
+func @log1p(%arg0 : f32) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 : f32
+ // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (f32) -> f32
+ %0 = math.log1p %arg0 : f32
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @log1p_2dvector(
+func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32>
+ // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[EXTRACT]] : vector<3xf32>
+ // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (vector<3xf32>) -> vector<3xf32>
+ // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[LOG]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ %0 = math.log1p %arg0 : vector<4x3xf32>
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @expm1(
+// CHECK-SAME: f32
+func @expm1(%arg0 : f32) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (f32) -> f32
+ // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
+ %0 = math.expm1 %arg0 : f32
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt(
+// CHECK-SAME: f32
+func @rsqrt(%arg0 : f32) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f32) -> f32
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f32
+ %0 = math.rsqrt %arg0 : f32
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @sine(
+// CHECK-SAME: f32
+func @sine(%arg0 : f32) {
+ // CHECK: "llvm.intr.sin"(%arg0) : (f32) -> f32
+ %0 = math.sin %arg0 : f32
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_double(
+// CHECK-SAME: f64
+func @rsqrt_double(%arg0 : f64) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f64) -> f64
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f64
+ %0 = math.rsqrt %arg0 : f64
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_vector(
+// CHECK-SAME: vector<4xf32>
+func @rsqrt_vector(%arg0 : vector<4xf32>) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (vector<4xf32>) -> vector<4xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<4xf32>
+ %0 = math.rsqrt %arg0 : vector<4xf32>
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_multidim_vector(
+func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32>
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%[[EXTRACT]]) : (vector<3xf32>) -> vector<3xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<3xf32>
+ // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ %0 = math.rsqrt %arg0 : vector<4x3xf32>
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @powf(
+// CHECK-SAME: f64
+func @powf(%arg0 : f64) {
+ // CHECK: %[[POWF:.*]] = "llvm.intr.pow"(%arg0, %arg0) : (f64, f64) -> f64
+ %0 = math.powf %arg0, %arg0 : f64
+ std.return
+}
+
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 281fe0c2b24c2..d68ce8cb353c7 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -463,48 +463,40 @@ func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>
// CHECK-LABEL: @ops
func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
-// CHECK-NEXT: %0 = llvm.fsub %arg0, %arg1 : f32
+// CHECK: = llvm.fsub %arg0, %arg1 : f32
%0 = subf %arg0, %arg1: f32
-// CHECK-NEXT: %1 = llvm.sub %arg2, %arg3 : i32
+// CHECK: = llvm.sub %arg2, %arg3 : i32
%1 = subi %arg2, %arg3: i32
-// CHECK-NEXT: %2 = llvm.icmp "slt" %arg2, %1 : i32
+// CHECK: = llvm.icmp "slt" %arg2, %1 : i32
%2 = cmpi slt, %arg2, %1 : i32
-// CHECK-NEXT: %3 = llvm.sdiv %arg2, %arg3 : i32
+// CHECK: = llvm.sdiv %arg2, %arg3 : i32
%3 = divi_signed %arg2, %arg3 : i32
-// CHECK-NEXT: %4 = llvm.udiv %arg2, %arg3 : i32
+// CHECK: = llvm.udiv %arg2, %arg3 : i32
%4 = divi_unsigned %arg2, %arg3 : i32
-// CHECK-NEXT: %5 = llvm.srem %arg2, %arg3 : i32
+// CHECK: = llvm.srem %arg2, %arg3 : i32
%5 = remi_signed %arg2, %arg3 : i32
-// CHECK-NEXT: %6 = llvm.urem %arg2, %arg3 : i32
+// CHECK: = llvm.urem %arg2, %arg3 : i32
%6 = remi_unsigned %arg2, %arg3 : i32
-// CHECK-NEXT: %7 = llvm.select %2, %arg2, %arg3 : i1, i32
+// CHECK: = llvm.select %2, %arg2, %arg3 : i1, i32
%7 = select %2, %arg2, %arg3 : i32
-// CHECK-NEXT: %8 = llvm.fdiv %arg0, %arg1 : f32
+// CHECK: = llvm.fdiv %arg0, %arg1 : f32
%8 = divf %arg0, %arg1 : f32
-// CHECK-NEXT: %9 = llvm.frem %arg0, %arg1 : f32
+// CHECK: = llvm.frem %arg0, %arg1 : f32
%9 = remf %arg0, %arg1 : f32
-// CHECK-NEXT: %10 = llvm.and %arg2, %arg3 : i32
+// CHECK: = llvm.and %arg2, %arg3 : i32
%10 = and %arg2, %arg3 : i32
-// CHECK-NEXT: %11 = llvm.or %arg2, %arg3 : i32
+// CHECK: = llvm.or %arg2, %arg3 : i32
%11 = or %arg2, %arg3 : i32
-// CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : i32
+// CHECK: = llvm.xor %arg2, %arg3 : i32
%12 = xor %arg2, %arg3 : i32
-// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (f32) -> f32
- %13 = math.exp %arg0 : f32
-// CHECK-NEXT: %14 = "llvm.intr.exp2"(%arg0) : (f32) -> f32
- %14 = math.exp2 %arg0 : f32
-// CHECK-NEXT: %15 = llvm.mlir.constant(7.900000e-01 : f64) : f64
+// CHECK: = llvm.mlir.constant(7.900000e-01 : f64) : f64
%15 = constant 7.9e-01 : f64
-// CHECK-NEXT: %16 = llvm.shl %arg2, %arg3 : i32
+// CHECK: = llvm.shl %arg2, %arg3 : i32
%16 = shift_left %arg2, %arg3 : i32
-// CHECK-NEXT: %17 = llvm.ashr %arg2, %arg3 : i32
+// CHECK: = llvm.ashr %arg2, %arg3 : i32
%17 = shift_right_signed %arg2, %arg3 : i32
-// CHECK-NEXT: %18 = llvm.lshr %arg2, %arg3 : i32
+// CHECK: = llvm.lshr %arg2, %arg3 : i32
%18 = shift_right_unsigned %arg2, %arg3 : i32
-// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg0) : (f32) -> f32
- %19 = math.sqrt %arg0 : f32
-// CHECK-NEXT: %{{[0-9]+}} = "llvm.intr.sqrt"(%arg4) : (f64) -> f64
- %20 = math.sqrt %arg4 : f64
return %0, %4 : f32, i32
}
@@ -859,66 +851,6 @@ func @rank_of_ranked(%ranked: memref<?xi32>) {
// CHECK: llvm.mlir.constant(1 : index) : i64
// CHECK32: llvm.mlir.constant(1 : index) : i32
-
-// -----
-
-// CHECK-LABEL: func @log1p(
-// CHECK-SAME: f32
-func @log1p(%arg0 : f32) {
- // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
- // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 : f32
- // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (f32) -> f32
- %0 = math.log1p %arg0 : f32
- std.return
-}
-
-// -----
-
-// CHECK-LABEL: func @log1p_2dvector(
-func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
- // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xf32>>
- // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32>
- // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[EXTRACT]] : vector<3xf32>
- // CHECK: %[[LOG:.*]] = "llvm.intr.log"(%[[ADD]]) : (vector<3xf32>) -> vector<3xf32>
- // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[LOG]], %0[0] : !llvm.array<4 x vector<3xf32>>
- %0 = math.log1p %arg0 : vector<4x3xf32>
- std.return
-}
-
-// -----
-
-// CHECK-LABEL: func @expm1(
-// CHECK-SAME: f32
-func @expm1(%arg0 : f32) {
- // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
- // CHECK: %[[EXP:.*]] = "llvm.intr.exp"(%arg0) : (f32) -> f32
- // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
- %0 = math.expm1 %arg0 : f32
- std.return
-}
-
-// -----
-
-// CHECK-LABEL: func @rsqrt(
-// CHECK-SAME: f32
-func @rsqrt(%arg0 : f32) {
- // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
- // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f32) -> f32
- // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f32
- %0 = math.rsqrt %arg0 : f32
- std.return
-}
-
-// -----
-
-// CHECK-LABEL: func @sine(
-// CHECK-SAME: f32
-func @sine(%arg0 : f32) {
- // CHECK: "llvm.intr.sin"(%arg0) : (f32) -> f32
- %0 = math.sin %arg0 : f32
- std.return
-}
-
// -----
// CHECK-LABEL: func @ceilf(
@@ -941,45 +873,6 @@ func @floorf(%arg0 : f32) {
// -----
-
-// CHECK-LABEL: func @rsqrt_double(
-// CHECK-SAME: f64
-func @rsqrt_double(%arg0 : f64) {
- // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64
- // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (f64) -> f64
- // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : f64
- %0 = math.rsqrt %arg0 : f64
- std.return
-}
-
-// -----
-
-// CHECK-LABEL: func @rsqrt_vector(
-// CHECK-SAME: vector<4xf32>
-func @rsqrt_vector(%arg0 : vector<4xf32>) {
- // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
- // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (vector<4xf32>) -> vector<4xf32>
- // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<4xf32>
- %0 = math.rsqrt %arg0 : vector<4xf32>
- std.return
-}
-
-// -----
-
-// CHECK-LABEL: func @rsqrt_multidim_vector(
-// CHECK-SAME: !llvm.array<4 x vector<3xf32>>
-func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
- // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<3xf32>>
- // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32>
- // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%[[EXTRACT]]) : (vector<3xf32>) -> vector<3xf32>
- // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<3xf32>
- // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %0[0] : !llvm.array<4 x vector<3xf32>>
- %0 = math.rsqrt %arg0 : vector<4x3xf32>
- std.return
-}
-
-// -----
-
// Lowers `assert` to a function call to `abort` if the assertion is violated.
// CHECK: llvm.func @abort()
// CHECK-LABEL: @assert_test_function
@@ -1010,16 +903,6 @@ func private @zero_result_func()
// -----
-// CHECK-LABEL: func @powf(
-// CHECK-SAME: f64
-func @powf(%arg0 : f64) {
- // CHECK: %[[POWF:.*]] = "llvm.intr.pow"(%arg0, %arg0) : (f64, f64) -> f64
- %0 = math.powf %arg0, %arg0 : f64
- std.return
-}
-
-// -----
-
// CHECK-LABEL: func @fmaf(
// CHECK-SAME: %[[ARG0:.*]]: f32
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
More information about the Mlir-commits mailing list