[Mlir-commits] [mlir] 22702cc - [mlir][math] Added math::FPowI conversion to calls of outlined implementations.
Slava Zakharin
llvmlistbot at llvm.org
Tue Dec 13 12:15:45 PST 2022
Author: Slava Zakharin
Date: 2022-12-13T12:15:35-08:00
New Revision: 22702cc76c4d6fcd3ee0e37ca826c539af146494
URL: https://github.com/llvm/llvm-project/commit/22702cc76c4d6fcd3ee0e37ca826c539af146494
DIFF: https://github.com/llvm/llvm-project/commit/22702cc76c4d6fcd3ee0e37ca826c539af146494.diff
LOG: [mlir][math] Added math::FPowI conversion to calls of outlined implementations.
Power functions are implemented as linkonce_odr scalar functions
for the FPowI operations encountered in a module.
Vector form of FPowI is linearized into a sequence of calls
of the scalar functions.
Option {min-width-of-fpowi-exponent} controls which FPowI operations
are converted by MathToFuncs: if the width of the exponent's integer
type is less than the specified value, then the operation is not converted.
Flang will specify {min-width-of-fpowi-exponent=33} to make sure that
math::FPowI operations with exponents wider than 32 bits will be converted
by MathToFuncs, and operations with narrower exponents will be left
for MathToLLVM to convert to LLVM::PowIOp.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D139804
Added:
mlir/test/Conversion/MathToFuncs/fpowi.mlir
mlir/test/Conversion/MathToFuncs/ipowi.mlir
Modified:
flang/lib/Optimizer/CodeGen/CodeGen.cpp
mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
Removed:
mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ad0da44d48b6b..85f2f2173a5c3 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3628,7 +3628,7 @@ class FIRToLLVMLowering
// function operations in it. We have to run such conversions
// as passes here.
mlir::OpPassManager mathConvertionPM("builtin.module");
- mathConvertionPM.addPass(mlir::createConvertMathToFuncsPass());
+ mathConvertionPM.addPass(mlir::createConvertMathToFuncs());
mathConvertionPM.addPass(mlir::createConvertComplexToStandardPass());
if (mlir::failed(runPipeline(mathConvertionPM, mod)))
return signalPassFailure();
diff --git a/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h b/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
index 8b0247b4dfa83..d2fa66dd655c4 100644
--- a/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
+++ b/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
@@ -17,9 +17,6 @@ class Pass;
#define GEN_PASS_DECL_CONVERTMATHTOFUNCS
#include "mlir/Conversion/Passes.h.inc"
-// Pass to convert some Math operations into calls of functions
-// containing software implementation of these operations.
-std::unique_ptr<Pass> createConvertMathToFuncsPass();
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 024f527f1c083..40ade95a09f95 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -536,7 +536,6 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
functions implementing these operations in software.
The LLVM dialect is used for LinkonceODR linkage of the generated functions.
}];
- let constructor = "mlir::createConvertMathToFuncsPass()";
let dependentDialects = [
"arith::ArithDialect",
"cf::ControlFlowDialect",
@@ -544,6 +543,12 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
"vector::VectorDialect",
"LLVM::LLVMDialect",
];
+ let options = [
+ Option<"minWidthOfFPowIExponent", "min-width-of-fpowi-exponent", "unsigned",
+ /*default=*/"1",
+ "Convert FPowI only if the width of its exponent's integer type "
+ "is greater than or equal to this value">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index d88e7795f0de4..d24a54c7cbc88 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -46,11 +46,7 @@ using GetPowerFuncCallbackTy = function_ref<func::FuncOp(Type)>;
// Pattern to convert scalar IPowIOp into a call of outlined
// software implementation.
-struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
-
-private:
- GetPowerFuncCallbackTy getFuncOpCallback;
-
+class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
public:
IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
: OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
@@ -60,6 +56,26 @@ struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
/// so vector forms of IPowI are linearized.
LogicalResult matchAndRewrite(math::IPowIOp op,
PatternRewriter &rewriter) const final;
+
+private:
+ GetPowerFuncCallbackTy getFuncOpCallback;
+};
+
+// Rewrite pattern that replaces a scalar math::FPowIOp with a call
+// to an outlined software implementation of the power operation.
+class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
+  using Base = OpRewritePattern<math::FPowIOp>;
+
+public:
+  /// \p getFuncOp is stored and queried by matchAndRewrite to obtain
+  /// the function implementing the operation being rewritten.
+  FPowIOpLowering(MLIRContext *ctx, GetPowerFuncCallbackTy getFuncOp)
+      : Base(ctx), getFuncOpCallback(getFuncOp) {}
+
+  /// Replace FPowI with a call to a local function that implements
+  /// the power operation. That function produces a scalar result,
+  /// which is why vector forms of FPowI have to be linearized.
+  LogicalResult matchAndRewrite(math::FPowIOp op,
+                                PatternRewriter &rewriter) const final;
+
+private:
+  // Callback returning the outlined implementation to call.
+  GetPowerFuncCallbackTy getFuncOpCallback;
};
} // namespace
@@ -77,9 +93,14 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
ArrayRef<int64_t> shape = vecType.getShape();
int64_t numElements = vecType.getNumElements();
+ Type resultElementType = vecType.getElementType();
+ Attribute initValueAttr;
+ if (resultElementType.isa<FloatType>())
+ initValueAttr = FloatAttr::get(resultElementType, 0.0);
+ else
+ initValueAttr = IntegerAttr::get(resultElementType, 0);
Value result = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(
- vecType, IntegerAttr::get(vecType.getElementType(), 0)));
+ loc, DenseElementsAttr::get(vecType, initValueAttr));
SmallVector<int64_t> strides = computeStrides(shape);
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(strides, linearIndex);
@@ -96,8 +117,21 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
return success();
}
+/// Build the FunctionType of an elemental counterpart of \p op:
+/// every operand type and every result type of \p op is mapped
+/// through getElementTypeOrSelf, preserving their order.
+static FunctionType getElementalFuncTypeForOp(Operation *op) {
+  SmallVector<Type, 2> inputTys;
+  inputTys.reserve(op->getNumOperands());
+  for (Type operandTy : op->getOperandTypes())
+    inputTys.push_back(getElementTypeOrSelf(operandTy));
+
+  SmallVector<Type, 1> resultTys;
+  resultTys.reserve(op->getNumResults());
+  for (Type resultTy : op->getResultTypes())
+    resultTys.push_back(getElementTypeOrSelf(resultTy));
+
+  return FunctionType::get(op->getContext(), inputTys, resultTys);
+}
+
/// Create linkonce_odr function to implement the power function with
-/// the given \p funcType type inside \p module. \p funcType must be
+/// the given \p elementType type inside \p module. The \p elementType
+/// must be IntegerType, and the created function has
/// 'IntegerType (*)(IntegerType, IntegerType)' function type.
///
/// template <typename T>
@@ -130,7 +164,6 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
assert(elementType.isa<IntegerType>() &&
"non-integer element type for IPowIOp");
- // IntegerType elementType = funcType.getInput(0).cast<IntegerType>();
ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
@@ -321,14 +354,246 @@ IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
return success();
}
+/// Create a private linkonce_odr function inside \p module that implements
+/// the power operation for \p funcType, which must be a function type of
+/// the form 'FloatType (*)(FloatType, IntegerType)'. The function is named
+/// "__mlir_math_fpowi_<base type>_<power type>" and is equivalent to:
+/// template <typename Tb, typename Tp>
+/// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
+///   if (p == Tp{0})
+///     return Tb{1};
+///   bool isNegativePower{p < Tp{0}};
+///   bool isMin{p == std::numeric_limits<Tp>::min()};
+///   if (isMin) {
+///     p = std::numeric_limits<Tp>::max();
+///   } else if (isNegativePower) {
+///     p = -p;
+///   }
+///   Tb result = Tb{1};
+///   Tb origBase = Tb{b};
+///   while (true) {
+///     if (p & Tp{1})
+///       result *= b;
+///     p >>= Tp{1};
+///     if (p == Tp{0})
+///       break;
+///     b *= b;
+///   }
+///   if (isMin) {
+///     result *= origBase;
+///   }
+///   if (isNegativePower) {
+///     result = Tb{1} / result;
+///   }
+///   return result;
+/// }
+static func::FuncOp createElementFPowIFunc(ModuleOp *module,
+                                           FunctionType funcType) {
+  auto baseType = funcType.getInput(0).cast<FloatType>();
+  auto powType = funcType.getInput(1).cast<IntegerType>();
+  ImplicitLocOpBuilder builder =
+      ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
+
+  std::string funcName("__mlir_math_fpowi");
+  llvm::raw_string_ostream nameOS(funcName);
+  nameOS << '_' << baseType;
+  nameOS << '_' << powType;
+  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
+  Attribute linkage =
+      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+  funcOp->setAttr("llvm.linkage", linkage);
+  funcOp.setPrivate();
+
+  Block *entryBlock = funcOp.addEntryBlock();
+  Region *funcBody = entryBlock->getParent();
+
+  Value bArg = funcOp.getArgument(0); // The base 'b'.
+  Value pArg = funcOp.getArgument(1); // The power 'p'.
+  builder.setInsertionPointToEnd(entryBlock);
+  Value oneBValue = builder.create<arith::ConstantOp>(
+      baseType, builder.getFloatAttr(baseType, 1.0));
+  Value zeroPValue = builder.create<arith::ConstantOp>(
+      powType, builder.getIntegerAttr(powType, 0));
+  Value onePValue = builder.create<arith::ConstantOp>(
+      powType, builder.getIntegerAttr(powType, 1));
+  Value minPValue = builder.create<arith::ConstantOp>(
+      powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
+                                                   powType.getWidth())));
+  Value maxPValue = builder.create<arith::ConstantOp>(
+      powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
+                                                   powType.getWidth())));
+
+  // if (p == Tp{0})
+  //   return Tb{1};
+  auto pIsZero =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
+  Block *thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(oneBValue);
+  Block *fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (p == Tp{0}).
+  builder.setInsertionPointToEnd(pIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  // bool isNegativePower{p < Tp{0}}; 'sle' acts as 'slt' here, since p != 0.
+  auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
+                                              zeroPValue);
+  // bool isMin{p == std::numeric_limits<Tp>::min()};
+  auto pIsMin =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
+
+  // if (isMin) {
+  //   p = std::numeric_limits<Tp>::max();
+  // } else if (isNegativePower) {
+  //   p = -p;
+  // }
+  Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg); // -p
+  auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
+  pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit); // isMin wins.
+
+  // Tb result = Tb{1};
+  // Tb origBase = Tb{b};
+  // while (true) {
+  //   if (p & Tp{1})
+  //     result *= b;
+  //   p >>= Tp{1};
+  //   if (p == Tp{0})
+  //     break;
+  //   b *= b;
+  // }
+  Block *loopHeader = builder.createBlock(
+      funcBody, funcBody->end(), {baseType, baseType, powType},
+      {builder.getLoc(), builder.getLoc(), builder.getLoc()});
+  // Set initial values of 'result', 'b' and 'p' for the loop.
+  builder.setInsertionPointToEnd(pInit->getBlock());
+  builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
+
+  // Create loop body.
+  Value resultTmp = loopHeader->getArgument(0);
+  Value baseTmp = loopHeader->getArgument(1);
+  Value powerTmp = loopHeader->getArgument(2);
+  builder.setInsertionPointToEnd(loopHeader);
+
+  // if (p & Tp{1})
+  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
+      arith::CmpIPredicate::ne,
+      builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
+  thenBlock = builder.createBlock(funcBody);
+  //   result *= b;
+  Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
+  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
+                                         builder.getLoc());
+  builder.setInsertionPointToEnd(thenBlock);
+  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+  // Set up conditional branch for (p & Tp{1}).
+  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
+  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
+                                   resultTmp);
+  // Merged 'result'.
+  newResultTmp = fallthroughBlock->getArgument(0);
+
+  // p >>= Tp{1};
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
+
+  // if (p == Tp{0})
+  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                      newPowerTmp, zeroPValue);
+  //   break;
+  //
+  // The conditional branch is finalized below with a jump to
+  // the loop exit block.
+  fallthroughBlock = builder.createBlock(funcBody);
+
+  //   b *= b;
+  // }
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
+  // Pass new values for 'result', 'b' and 'p' to the loop header.
+  builder.create<cf::BranchOp>(
+      ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+
+  // Set up conditional branch for early loop exit:
+  //   if (p == Tp{0})
+  //     break;
+  Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
+                                        builder.getLoc());
+  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
+                                   fallthroughBlock, ValueRange{});
+
+  // if (isMin) {
+  //   result *= origBase;
+  // }
+  newResultTmp = loopExit->getArgument(0);
+  thenBlock = builder.createBlock(funcBody);
+  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
+                                         builder.getLoc());
+  builder.setInsertionPointToEnd(loopExit);
+  builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
+                                   newResultTmp);
+  builder.setInsertionPointToEnd(thenBlock);
+  newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
+  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+
+  // if (isNegativePower) {
+  //   result = Tb{1} / result;
+  // }
+  newResultTmp = fallthroughBlock->getArgument(0);
+  thenBlock = builder.createBlock(funcBody);
+  Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
+                                           builder.getLoc());
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
+                                   newResultTmp);
+  builder.setInsertionPointToEnd(thenBlock);
+  newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
+  builder.create<cf::BranchOp>(newResultTmp, returnBlock);
+
+  // return result;
+  builder.setInsertionPointToEnd(returnBlock);
+  builder.create<func::ReturnOp>(returnBlock->getArgument(0));
+
+  return funcOp;
+}
+
+/// Replace a scalar FPowI with a call to the local function that
+/// implements the power operation. Vector forms are rejected here,
+/// because the local function computes a scalar result only.
+LogicalResult
+FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
+                                 PatternRewriter &rewriter) const {
+  // Only scalar operations are handled by this pattern.
+  if (op.getType().isa<VectorType>())
+    return rewriter.notifyMatchFailure(op, "non-scalar operation");
+
+  // The outlined software implementation is looked up by the elemental
+  // function type of the operation; it is expected to exist already.
+  if (func::FuncOp callee = getFuncOpCallback(getElementalFuncTypeForOp(op))) {
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, callee, op.getOperands());
+    return success();
+  }
+
+  return rewriter.notifyMatchFailure(op, "missing software implementation");
+}
+
namespace {
struct ConvertMathToFuncsPass
: public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
ConvertMathToFuncsPass() = default;
+ ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
+ : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
void runOnOperation() override;
private:
+ // Return true, if this FPowI operation must be converted
+ // because the width of its exponent's type is greater than
+ // or equal to minWidthOfFPowIExponent option value.
+ bool isFPowIConvertible(math::FPowIOp op);
+
// Generate outlined implementations for power operations
// and store them in powerFuncs map.
void preprocessPowOperations();
@@ -340,19 +605,40 @@ struct ConvertMathToFuncsPass
};
} // namespace
+// Return true if the element type of the exponent of \p op is an integer
+// type at least minWidthOfFPowIExponent bits wide.
+bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
+  Type exponentElemTy = getElementTypeOrSelf(op.getRhs().getType());
+  auto exponentIntTy = exponentElemTy.dyn_cast<IntegerType>();
+  if (!exponentIntTy)
+    return false;
+  return exponentIntTy.getWidth() >= minWidthOfFPowIExponent;
+}
+
void ConvertMathToFuncsPass::preprocessPowOperations() {
ModuleOp module = getOperation();
module.walk([&](Operation *op) {
- TypeSwitch<Operation *>(op).Case<math::IPowIOp>([&](math::IPowIOp op) {
- Type resultType = getElementTypeOrSelf(op.getResult().getType());
-
- // Generate the software implementation of this operation,
- // if it has not been generated yet.
- auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
- if (entry.second)
- entry.first->second = createElementIPowIFunc(&module, resultType);
- });
+ TypeSwitch<Operation *>(op)
+ .Case<math::IPowIOp>([&](math::IPowIOp op) {
+ Type resultType = getElementTypeOrSelf(op.getResult().getType());
+
+ // Generate the software implementation of this operation,
+ // if it has not been generated yet.
+ auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
+ if (entry.second)
+ entry.first->second = createElementIPowIFunc(&module, resultType);
+ })
+ .Case<math::FPowIOp>([&](math::FPowIOp op) {
+ if (!isFPowIConvertible(op))
+ return;
+
+ FunctionType funcType = getElementalFuncTypeForOp(op);
+
+ // Generate the software implementation of this operation,
+ // if it has not been generated yet.
+ // FPowI implementations are mapped via the FunctionType
+ // created from the operation's result and operands.
+ auto entry = powerFuncs.try_emplace(funcType, func::FuncOp{});
+ if (entry.second)
+ entry.first->second = createElementFPowIFunc(&module, funcType);
+ });
});
}
@@ -363,7 +649,8 @@ void ConvertMathToFuncsPass::runOnOperation() {
preprocessPowOperations();
RewritePatternSet patterns(&getContext());
- patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext());
+ patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>>(
+ patterns.getContext());
// For the given Type Returns FuncOp stored in powerFuncs map.
auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp {
@@ -373,16 +660,15 @@ void ConvertMathToFuncsPass::runOnOperation() {
return it->second;
};
- patterns.add<IPowIOpLowering>(patterns.getContext(), getPowerFuncOpByType);
+ patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
+ getPowerFuncOpByType);
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
func::FuncDialect, vector::VectorDialect>();
target.addIllegalOp<math::IPowIOp>();
+ target.addDynamicallyLegalOp<math::FPowIOp>(
+ [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}
-
-std::unique_ptr<Pass> mlir::createConvertMathToFuncsPass() {
- return std::make_unique<ConvertMathToFuncsPass>();
-}
diff --git a/mlir/test/Conversion/MathToFuncs/fpowi.mlir b/mlir/test/Conversion/MathToFuncs/fpowi.mlir
new file mode 100644
index 0000000000000..bae791707f2bb
--- /dev/null
+++ b/mlir/test/Conversion/MathToFuncs/fpowi.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(convert-math-to-funcs{min-width-of-fpowi-exponent=33})" | FileCheck %s
+
+// -----
+
+// Check that i32 exponent case is not converted
+// due to {min-width-of-fpowi-exponent=33}:
+
+// CHECK-LABEL: func.func @fpowi32(
+// CHECK-SAME: %[[VAL_0:.*]]: f64,
+// CHECK-SAME: %[[VAL_1:.*]]: i32) {
+// CHECK: %[[VAL_2:.*]] = math.fpowi %[[VAL_0]], %[[VAL_1]] : f64, i32
+// CHECK: return
+// CHECK: }
+func.func @fpowi32(%arg0: f64, %arg1: i32) {
+ %0 = math.fpowi %arg0, %arg1 : f64, i32
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fpowi64(
+// CHECK-SAME: %[[VAL_0:.*]]: f64,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) {
+// CHECK: %[[VAL_2:.*]] = call @__mlir_math_fpowi_f64_i64(%[[VAL_0]], %[[VAL_1]]) : (f64, i64) -> f64
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func private @__mlir_math_fpowi_f64_i64(
+// CHECK-SAME: %[[VAL_0:.*]]: f64,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_5:.*]] = arith.constant -9223372036854775808 : i64
+// CHECK: %[[VAL_6:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK: cf.cond_br %[[VAL_7]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[VAL_2]] : f64
+// CHECK: ^bb2:
+// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_5]] : i64
+// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_3]], %[[VAL_1]] : i64
+// CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_8]], %[[VAL_10]], %[[VAL_1]] : i64
+// CHECK: %[[VAL_12:.*]] = arith.select %[[VAL_9]], %[[VAL_6]], %[[VAL_11]] : i64
+// CHECK: cf.br ^bb3(%[[VAL_2]], %[[VAL_0]], %[[VAL_12]] : f64, f64, i64)
+// CHECK: ^bb3(%[[VAL_13:.*]]: f64, %[[VAL_14:.*]]: f64, %[[VAL_15:.*]]: i64):
+// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_4]] : i64
+// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_3]] : i64
+// CHECK: cf.cond_br %[[VAL_17]], ^bb4, ^bb5(%[[VAL_13]] : f64)
+// CHECK: ^bb4:
+// CHECK: %[[VAL_18:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f64
+// CHECK: cf.br ^bb5(%[[VAL_18]] : f64)
+// CHECK: ^bb5(%[[VAL_19:.*]]: f64):
+// CHECK: %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_4]] : i64
+// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_3]] : i64
+// CHECK: cf.cond_br %[[VAL_21]], ^bb7(%[[VAL_19]] : f64), ^bb6
+// CHECK: ^bb6:
+// CHECK: %[[VAL_22:.*]] = arith.mulf %[[VAL_14]], %[[VAL_14]] : f64
+// CHECK: cf.br ^bb3(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : f64, f64, i64)
+// CHECK: ^bb7(%[[VAL_23:.*]]: f64):
+// CHECK: cf.cond_br %[[VAL_9]], ^bb8, ^bb9(%[[VAL_23]] : f64)
+// CHECK: ^bb8:
+// CHECK: %[[VAL_24:.*]] = arith.mulf %[[VAL_23]], %[[VAL_0]] : f64
+// CHECK: cf.br ^bb9(%[[VAL_24]] : f64)
+// CHECK: ^bb9(%[[VAL_25:.*]]: f64):
+// CHECK: cf.cond_br %[[VAL_8]], ^bb10, ^bb11(%[[VAL_25]] : f64)
+// CHECK: ^bb10:
+// CHECK: %[[VAL_26:.*]] = arith.divf %[[VAL_2]], %[[VAL_25]] : f64
+// CHECK: cf.br ^bb11(%[[VAL_26]] : f64)
+// CHECK: ^bb11(%[[VAL_27:.*]]: f64):
+// CHECK: return %[[VAL_27]] : f64
+// CHECK: }
+func.func @fpowi64(%arg0: f64, %arg1: i64) {
+ %0 = math.fpowi %arg0, %arg1 : f64, i64
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fpowi64_vec(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x3xi64>) {
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2x3xf32>
+// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<2x3xi64>
+// CHECK: %[[VAL_5:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_3]], %[[VAL_4]]) : (f32, i64) -> f32
+// CHECK: %[[VAL_6:.*]] = vector.insert %[[VAL_5]], %[[VAL_2]] [0, 0] : f32 into vector<2x3xf32>
+// CHECK: %[[VAL_7:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2x3xf32>
+// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<2x3xi64>
+// CHECK: %[[VAL_9:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_7]], %[[VAL_8]]) : (f32, i64) -> f32
+// CHECK: %[[VAL_10:.*]] = vector.insert %[[VAL_9]], %[[VAL_6]] [0, 1] : f32 into vector<2x3xf32>
+// CHECK: %[[VAL_11:.*]] = vector.extract %[[VAL_0]][0, 2] : vector<2x3xf32>
+// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_1]][0, 2] : vector<2x3xi64>
+// CHECK: %[[VAL_13:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_11]], %[[VAL_12]]) : (f32, i64) -> f32
+// CHECK: %[[VAL_14:.*]] = vector.insert %[[VAL_13]], %[[VAL_10]] [0, 2] : f32 into vector<2x3xf32>
+// CHECK: %[[VAL_15:.*]] = vector.extract %[[VAL_0]][1, 0] : vector<2x3xf32>
+// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_1]][1, 0] : vector<2x3xi64>
+// CHECK: %[[VAL_17:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_15]], %[[VAL_16]]) : (f32, i64) -> f32
+// CHECK: %[[VAL_18:.*]] = vector.insert %[[VAL_17]], %[[VAL_14]] [1, 0] : f32 into vector<2x3xf32>
+// CHECK: %[[VAL_19:.*]] = vector.extract %[[VAL_0]][1, 1] : vector<2x3xf32>
+// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_1]][1, 1] : vector<2x3xi64>
+// CHECK: %[[VAL_21:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_19]], %[[VAL_20]]) : (f32, i64) -> f32
+// CHECK: %[[VAL_22:.*]] = vector.insert %[[VAL_21]], %[[VAL_18]] [1, 1] : f32 into vector<2x3xf32>
+// CHECK: %[[VAL_23:.*]] = vector.extract %[[VAL_0]][1, 2] : vector<2x3xf32>
+// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_1]][1, 2] : vector<2x3xi64>
+// CHECK: %[[VAL_25:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_23]], %[[VAL_24]]) : (f32, i64) -> f32
+// CHECK: %[[VAL_26:.*]] = vector.insert %[[VAL_25]], %[[VAL_22]] [1, 2] : f32 into vector<2x3xf32>
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func private @__mlir_math_fpowi_f32_i64(
+// CHECK-SAME: %[[VAL_0:.*]]: f32,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_5:.*]] = arith.constant -9223372036854775808 : i64
+// CHECK: %[[VAL_6:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK: cf.cond_br %[[VAL_7]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[VAL_2]] : f32
+// CHECK: ^bb2:
+// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_5]] : i64
+// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_3]], %[[VAL_1]] : i64
+// CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_8]], %[[VAL_10]], %[[VAL_1]] : i64
+// CHECK: %[[VAL_12:.*]] = arith.select %[[VAL_9]], %[[VAL_6]], %[[VAL_11]] : i64
+// CHECK: cf.br ^bb3(%[[VAL_2]], %[[VAL_0]], %[[VAL_12]] : f32, f32, i64)
+// CHECK: ^bb3(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: i64):
+// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_4]] : i64
+// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_3]] : i64
+// CHECK: cf.cond_br %[[VAL_17]], ^bb4, ^bb5(%[[VAL_13]] : f32)
+// CHECK: ^bb4:
+// CHECK: %[[VAL_18:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK: cf.br ^bb5(%[[VAL_18]] : f32)
+// CHECK: ^bb5(%[[VAL_19:.*]]: f32):
+// CHECK: %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_4]] : i64
+// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_3]] : i64
+// CHECK: cf.cond_br %[[VAL_21]], ^bb7(%[[VAL_19]] : f32), ^bb6
+// CHECK: ^bb6:
+// CHECK: %[[VAL_22:.*]] = arith.mulf %[[VAL_14]], %[[VAL_14]] : f32
+// CHECK: cf.br ^bb3(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : f32, f32, i64)
+// CHECK: ^bb7(%[[VAL_23:.*]]: f32):
+// CHECK: cf.cond_br %[[VAL_9]], ^bb8, ^bb9(%[[VAL_23]] : f32)
+// CHECK: ^bb8:
+// CHECK: %[[VAL_24:.*]] = arith.mulf %[[VAL_23]], %[[VAL_0]] : f32
+// CHECK: cf.br ^bb9(%[[VAL_24]] : f32)
+// CHECK: ^bb9(%[[VAL_25:.*]]: f32):
+// CHECK: cf.cond_br %[[VAL_8]], ^bb10, ^bb11(%[[VAL_25]] : f32)
+// CHECK: ^bb10:
+// CHECK: %[[VAL_26:.*]] = arith.divf %[[VAL_2]], %[[VAL_25]] : f32
+// CHECK: cf.br ^bb11(%[[VAL_26]] : f32)
+// CHECK: ^bb11(%[[VAL_27:.*]]: f32):
+// CHECK: return %[[VAL_27]] : f32
+// CHECK: }
+func.func @fpowi64_vec(%arg0: vector<2x3xf32>, %arg1: vector<2x3xi64>) {
+ %0 = math.fpowi %arg0, %arg1 : vector<2x3xf32>, vector<2x3xi64>
+ func.return
+}
diff --git a/mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir b/mlir/test/Conversion/MathToFuncs/ipowi.mlir
similarity index 100%
rename from mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
rename to mlir/test/Conversion/MathToFuncs/ipowi.mlir
More information about the Mlir-commits
mailing list