[flang-commits] [flang] 22702cc - [mlir][math] Added math::FPowI conversion to calls of outlined implementations.

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Tue Dec 13 12:15:44 PST 2022


Author: Slava Zakharin
Date: 2022-12-13T12:15:35-08:00
New Revision: 22702cc76c4d6fcd3ee0e37ca826c539af146494

URL: https://github.com/llvm/llvm-project/commit/22702cc76c4d6fcd3ee0e37ca826c539af146494
DIFF: https://github.com/llvm/llvm-project/commit/22702cc76c4d6fcd3ee0e37ca826c539af146494.diff

LOG: [mlir][math] Added math::FPowI conversion to calls of outlined implementations.

Power functions are implemented as linkonce_odr scalar functions
for the FPowI operations encountered in a module.
Vector forms of FPowI are linearized into a sequence of calls
to the scalar functions.

Option {min-width-of-fpowi-exponent} controls which FPowI operations
are converted by MathToFuncs: if the width of the exponent's integer
type is less than the specified value, then the operation is not converted.

Flang will specify {min-width-of-fpowi-exponent=33} to make sure that
math::FPowI operations with exponents wider than 32 bits are converted
by MathToFuncs, while operations with narrower exponents are left
for MathToLLVM to convert to LLVM::PowIOp.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D139804

Added: 
    mlir/test/Conversion/MathToFuncs/fpowi.mlir
    mlir/test/Conversion/MathToFuncs/ipowi.mlir

Modified: 
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp

Removed: 
    mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ad0da44d48b6b..85f2f2173a5c3 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3628,7 +3628,7 @@ class FIRToLLVMLowering
     // function operations in it. We have to run such conversions
     // as passes here.
     mlir::OpPassManager mathConvertionPM("builtin.module");
-    mathConvertionPM.addPass(mlir::createConvertMathToFuncsPass());
+    mathConvertionPM.addPass(mlir::createConvertMathToFuncs());
     mathConvertionPM.addPass(mlir::createConvertComplexToStandardPass());
     if (mlir::failed(runPipeline(mathConvertionPM, mod)))
       return signalPassFailure();

diff  --git a/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h b/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
index 8b0247b4dfa83..d2fa66dd655c4 100644
--- a/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
+++ b/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
@@ -17,9 +17,6 @@ class Pass;
 #define GEN_PASS_DECL_CONVERTMATHTOFUNCS
 #include "mlir/Conversion/Passes.h.inc"
 
-// Pass to convert some Math operations into calls of functions
-// containing software implementation of these operations.
-std::unique_ptr<Pass> createConvertMathToFuncsPass();
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 024f527f1c083..40ade95a09f95 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -536,7 +536,6 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
     functions implementing these operations in software.
     The LLVM dialect is used for LinkonceODR linkage of the generated functions.
   }];
-  let constructor = "mlir::createConvertMathToFuncsPass()";
   let dependentDialects = [
     "arith::ArithDialect",
     "cf::ControlFlowDialect",
@@ -544,6 +543,12 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
     "vector::VectorDialect",
     "LLVM::LLVMDialect",
   ];
+  let options = [
+    Option<"minWidthOfFPowIExponent", "min-width-of-fpowi-exponent", "unsigned",
+           /*default=*/"1",
+           "Convert FPowI only if the width of its exponent's integer type "
+           "is greater than or equal to this value">
+  ];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index d88e7795f0de4..d24a54c7cbc88 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -46,11 +46,7 @@ using GetPowerFuncCallbackTy = function_ref<func::FuncOp(Type)>;
 
 // Pattern to convert scalar IPowIOp into a call of outlined
 // software implementation.
-struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
-
-private:
-  GetPowerFuncCallbackTy getFuncOpCallback;
-
+class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
 public:
   IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
       : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
@@ -60,6 +56,26 @@ struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
   /// so vector forms of IPowI are linearized.
   LogicalResult matchAndRewrite(math::IPowIOp op,
                                 PatternRewriter &rewriter) const final;
+
+private:
+  GetPowerFuncCallbackTy getFuncOpCallback;
+};
+
+// Pattern to convert scalar FPowIOp into a call of outlined
+// software implementation.
+class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
+public:
+  FPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
+      : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
+
+  /// Convert FPowI into a call to a local function implementing
+  /// the power operation. The local function computes a scalar result,
+  /// so vector forms of FPowI are linearized.
+  LogicalResult matchAndRewrite(math::FPowIOp op,
+                                PatternRewriter &rewriter) const final;
+
+private:
+  GetPowerFuncCallbackTy getFuncOpCallback;
 };
 } // namespace
 
@@ -77,9 +93,14 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   ArrayRef<int64_t> shape = vecType.getShape();
   int64_t numElements = vecType.getNumElements();
 
+  Type resultElementType = vecType.getElementType();
+  Attribute initValueAttr;
+  if (resultElementType.isa<FloatType>())
+    initValueAttr = FloatAttr::get(resultElementType, 0.0);
+  else
+    initValueAttr = IntegerAttr::get(resultElementType, 0);
   Value result = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(
-               vecType, IntegerAttr::get(vecType.getElementType(), 0)));
+      loc, DenseElementsAttr::get(vecType, initValueAttr));
   SmallVector<int64_t> strides = computeStrides(shape);
   for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
     SmallVector<int64_t> positions = delinearize(strides, linearIndex);
@@ -96,8 +117,21 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   return success();
 }
 
+static FunctionType getElementalFuncTypeForOp(Operation *op) {
+  SmallVector<Type, 1> resultTys(op->getNumResults());
+  SmallVector<Type, 2> inputTys(op->getNumOperands());
+  std::transform(op->result_type_begin(), op->result_type_end(),
+                 resultTys.begin(),
+                 [](Type ty) { return getElementTypeOrSelf(ty); });
+  std::transform(op->operand_type_begin(), op->operand_type_end(),
+                 inputTys.begin(),
+                 [](Type ty) { return getElementTypeOrSelf(ty); });
+  return FunctionType::get(op->getContext(), inputTys, resultTys);
+}
+
 /// Create linkonce_odr function to implement the power function with
-/// the given \p funcType type inside \p module. \p funcType must be
+/// the given \p elementType type inside \p module. The \p elementType
+/// must be IntegerType, and the created function has
 /// 'IntegerType (*)(IntegerType, IntegerType)' function type.
 ///
 /// template <typename T>
@@ -130,7 +164,6 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
   assert(elementType.isa<IntegerType>() &&
          "non-integer element type for IPowIOp");
 
-  //  IntegerType elementType = funcType.getInput(0).cast<IntegerType>();
   ImplicitLocOpBuilder builder =
       ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
 
@@ -321,14 +354,246 @@ IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
   return success();
 }
 
+/// Create linkonce_odr function to implement the power function with
+/// the given \p funcType type inside \p module. The \p funcType must be
+/// 'FloatType (*)(FloatType, IntegerType)' function type.
+///
+/// template <typename Tb, typename Tp>
+/// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
+///   if (p == Tp{0})
+///     return Tb{1};
+///   bool isNegativePower{p < Tp{0}};
+///   bool isMin{p == std::numeric_limits<Tp>::min()};
+///   if (isMin) {
+///     p = std::numeric_limits<Tp>::max();
+///   } else if (isNegativePower) {
+///     p = -p;
+///   }
+///   Tb result = Tb{1};
+///   Tb origBase = Tb{b};
+///   while (true) {
+///     if (p & Tp{1})
+///       result *= b;
+///     p >>= Tp{1};
+///     if (p == Tp{0})
+///       break;
+///     b *= b;
+///   }
+///   if (isMin) {
+///     result *= origBase;
+///   }
+///   if (isNegativePower) {
+///     result = Tb{1} / result;
+///   }
+///   return result;
+/// }
+static func::FuncOp createElementFPowIFunc(ModuleOp *module,
+                                           FunctionType funcType) {
+  auto baseType = funcType.getInput(0).cast<FloatType>();
+  auto powType = funcType.getInput(1).cast<IntegerType>();
+  ImplicitLocOpBuilder builder =
+      ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
+
+  std::string funcName("__mlir_math_fpowi");
+  llvm::raw_string_ostream nameOS(funcName);
+  nameOS << '_' << baseType;
+  nameOS << '_' << powType;
+  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
+  Attribute linkage =
+      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+  funcOp->setAttr("llvm.linkage", linkage);
+  funcOp.setPrivate();
+
+  Block *entryBlock = funcOp.addEntryBlock();
+  Region *funcBody = entryBlock->getParent();
+
+  Value bArg = funcOp.getArgument(0);
+  Value pArg = funcOp.getArgument(1);
+  builder.setInsertionPointToEnd(entryBlock);
+  Value oneBValue = builder.create<arith::ConstantOp>(
+      baseType, builder.getFloatAttr(baseType, 1.0));
+  Value zeroPValue = builder.create<arith::ConstantOp>(
+      powType, builder.getIntegerAttr(powType, 0));
+  Value onePValue = builder.create<arith::ConstantOp>(
+      powType, builder.getIntegerAttr(powType, 1));
+  Value minPValue = builder.create<arith::ConstantOp>(
+      powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
+                                                   powType.getWidth())));
+  Value maxPValue = builder.create<arith::ConstantOp>(
+      powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
+                                                   powType.getWidth())));
+
+  // if (p == Tp{0})
+  //   return Tb{1};
+  auto pIsZero =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
+  Block *thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(oneBValue);
+  Block *fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (p == Tp{0}).
+  builder.setInsertionPointToEnd(pIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  // bool isNegativePower{p < Tp{0}}; 'sle' acts as 'slt' here since p != 0.
+  auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
+                                              zeroPValue);
+  // bool isMin{p == std::numeric_limits<Tp>::min()};
+  auto pIsMin =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
+
+  // if (isMin) {
+  //   p = std::numeric_limits<Tp>::max();
+  // } else if (isNegativePower) {
+  //   p = -p;
+  // }
+  Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
+  auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
+  pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);
+
+  // Tb result = Tb{1};
+  // Tb origBase = Tb{b};  (bArg itself serves as origBase below)
+  // while (true) {
+  //   if (p & Tp{1})
+  //     result *= b;
+  //   p >>= Tp{1};
+  //   if (p == Tp{0})
+  //     break;
+  //   b *= b;
+  // }
+  Block *loopHeader = builder.createBlock(
+      funcBody, funcBody->end(), {baseType, baseType, powType},
+      {builder.getLoc(), builder.getLoc(), builder.getLoc()});
+  // Set initial values of 'result', 'b' and 'p' for the loop.
+  builder.setInsertionPointToEnd(pInit->getBlock());
+  builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
+
+  // Create loop body.
+  Value resultTmp = loopHeader->getArgument(0);
+  Value baseTmp = loopHeader->getArgument(1);
+  Value powerTmp = loopHeader->getArgument(2);
+  builder.setInsertionPointToEnd(loopHeader);
+
+  //   if (p & Tp{1})
+  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
+      arith::CmpIPredicate::ne,
+      builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
+  thenBlock = builder.createBlock(funcBody);
+  //     result *= b;
+  Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
+  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
+                                         builder.getLoc());
+  builder.setInsertionPointToEnd(thenBlock);
+  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+  // Set up conditional branch for (p & Tp{1}).
+  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
+  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
+                                   resultTmp);
+  // Merged 'result'.
+  newResultTmp = fallthroughBlock->getArgument(0);
+
+  //   p >>= Tp{1};  (logical shift is fine: p is non-negative here)
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
+
+  //   if (p == Tp{0})
+  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                      newPowerTmp, zeroPValue);
+  //     break;
+  //
+  // The conditional branch is finalized below with a jump to
+  // the loop exit block.
+  fallthroughBlock = builder.createBlock(funcBody);
+
+  //   b *= b;
+  // }
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
+  // Pass new values for 'result', 'b' and 'p' to the loop header.
+  builder.create<cf::BranchOp>(
+      ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+
+  // Set up conditional branch for early loop exit:
+  //   if (p == Tp{0})
+  //     break;
+  Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
+                                        builder.getLoc());
+  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
+                                   fallthroughBlock, ValueRange{});
+
+  // if (isMin) {
+  //   result *= origBase;
+  // }
+  newResultTmp = loopExit->getArgument(0);
+  thenBlock = builder.createBlock(funcBody);
+  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
+                                         builder.getLoc());
+  builder.setInsertionPointToEnd(loopExit);
+  builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
+                                   newResultTmp);
+  builder.setInsertionPointToEnd(thenBlock);
+  newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
+  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+
+  // if (isNegativePower) {
+  //   result = Tb{1} / result;
+  // }
+  newResultTmp = fallthroughBlock->getArgument(0);
+  thenBlock = builder.createBlock(funcBody);
+  Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
+                                           builder.getLoc());
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
+                                   newResultTmp);
+  builder.setInsertionPointToEnd(thenBlock);
+  newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
+  builder.create<cf::BranchOp>(newResultTmp, returnBlock);
+
+  // return result;
+  builder.setInsertionPointToEnd(returnBlock);
+  builder.create<func::ReturnOp>(returnBlock->getArgument(0));
+
+  return funcOp;
+}
+
+/// Convert FPowI into a call to a local function implementing
+/// the power operation. The local function computes a scalar result,
+/// so vector forms of FPowI are linearized.
+LogicalResult
+FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
+                                 PatternRewriter &rewriter) const {
+  if (op.getType().template dyn_cast<VectorType>())
+    return rewriter.notifyMatchFailure(op, "non-scalar operation");
+
+  FunctionType funcType = getElementalFuncTypeForOp(op);
+
+  // The outlined software implementation must have been already
+  // generated.
+  func::FuncOp elementFunc = getFuncOpCallback(funcType);
+  if (!elementFunc)
+    return rewriter.notifyMatchFailure(op, "missing software implementation");
+
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
+  return success();
+}
+
 namespace {
 struct ConvertMathToFuncsPass
     : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
   ConvertMathToFuncsPass() = default;
+  ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
+      : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
 
   void runOnOperation() override;
 
 private:
+  // Return true, if this FPowI operation must be converted
+  // because the width of its exponent's type is greater than
+  // or equal to minWidthOfFPowIExponent option value.
+  bool isFPowIConvertible(math::FPowIOp op);
+
   // Generate outlined implementations for power operations
   // and store them in powerFuncs map.
   void preprocessPowOperations();
@@ -340,19 +605,40 @@ struct ConvertMathToFuncsPass
 };
 } // namespace
 
+bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
+  auto expTy =
+      getElementTypeOrSelf(op.getRhs().getType()).dyn_cast<IntegerType>();
+  return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
+}
+
 void ConvertMathToFuncsPass::preprocessPowOperations() {
   ModuleOp module = getOperation();
 
   module.walk([&](Operation *op) {
-    TypeSwitch<Operation *>(op).Case<math::IPowIOp>([&](math::IPowIOp op) {
-      Type resultType = getElementTypeOrSelf(op.getResult().getType());
-
-      // Generate the software implementation of this operation,
-      // if it has not been generated yet.
-      auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
-      if (entry.second)
-        entry.first->second = createElementIPowIFunc(&module, resultType);
-    });
+    TypeSwitch<Operation *>(op)
+        .Case<math::IPowIOp>([&](math::IPowIOp op) {
+          Type resultType = getElementTypeOrSelf(op.getResult().getType());
+
+          // Generate the software implementation of this operation,
+          // if it has not been generated yet.
+          auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
+          if (entry.second)
+            entry.first->second = createElementIPowIFunc(&module, resultType);
+        })
+        .Case<math::FPowIOp>([&](math::FPowIOp op) {
+          if (!isFPowIConvertible(op))
+            return;
+
+          FunctionType funcType = getElementalFuncTypeForOp(op);
+
+          // Generate the software implementation of this operation,
+          // if it has not been generated yet.
+          // FPowI implementations are mapped via the FunctionType
+          // created from the operation's result and operands.
+          auto entry = powerFuncs.try_emplace(funcType, func::FuncOp{});
+          if (entry.second)
+            entry.first->second = createElementFPowIFunc(&module, funcType);
+        });
   });
 }
 
@@ -363,7 +649,8 @@ void ConvertMathToFuncsPass::runOnOperation() {
   preprocessPowOperations();
 
   RewritePatternSet patterns(&getContext());
-  patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext());
+  patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>>(
+      patterns.getContext());
 
   // For the given Type Returns FuncOp stored in powerFuncs map.
   auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp {
@@ -373,16 +660,15 @@ void ConvertMathToFuncsPass::runOnOperation() {
 
     return it->second;
   };
-  patterns.add<IPowIOpLowering>(patterns.getContext(), getPowerFuncOpByType);
+  patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
+                                                 getPowerFuncOpByType);
 
   ConversionTarget target(getContext());
   target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
                          func::FuncDialect, vector::VectorDialect>();
   target.addIllegalOp<math::IPowIOp>();
+  target.addDynamicallyLegalOp<math::FPowIOp>(
+      [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
-
-std::unique_ptr<Pass> mlir::createConvertMathToFuncsPass() {
-  return std::make_unique<ConvertMathToFuncsPass>();
-}

diff  --git a/mlir/test/Conversion/MathToFuncs/fpowi.mlir b/mlir/test/Conversion/MathToFuncs/fpowi.mlir
new file mode 100644
index 0000000000000..bae791707f2bb
--- /dev/null
+++ b/mlir/test/Conversion/MathToFuncs/fpowi.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(convert-math-to-funcs{min-width-of-fpowi-exponent=33})" | FileCheck %s
+
+// -----
+
+// Check that i32 exponent case is not converted
+// due to {min-width-of-fpowi-exponent=33}:
+
+// CHECK-LABEL:   func.func @fpowi32(
+// CHECK-SAME:                       %[[VAL_0:.*]]: f64,
+// CHECK-SAME:                       %[[VAL_1:.*]]: i32) {
+// CHECK:           %[[VAL_2:.*]] = math.fpowi %[[VAL_0]], %[[VAL_1]] : f64, i32
+// CHECK:           return
+// CHECK:         }
+func.func @fpowi32(%arg0: f64, %arg1: i32) {
+  %0 = math.fpowi %arg0, %arg1 : f64, i32
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fpowi64(
+// CHECK-SAME:                       %[[VAL_0:.*]]: f64,
+// CHECK-SAME:                       %[[VAL_1:.*]]: i64) {
+// CHECK:           %[[VAL_2:.*]] = call @__mlir_math_fpowi_f64_i64(%[[VAL_0]], %[[VAL_1]]) : (f64, i64) -> f64
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @__mlir_math_fpowi_f64_i64(
+// CHECK-SAME:                                                 %[[VAL_0:.*]]: f64,
+// CHECK-SAME:                                                 %[[VAL_1:.*]]: i64) -> f64 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f64
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : i64
+// CHECK:           %[[VAL_5:.*]] = arith.constant -9223372036854775808 : i64
+// CHECK:           %[[VAL_6:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK:           cf.cond_br %[[VAL_7]], ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           return %[[VAL_2]] : f64
+// CHECK:         ^bb2:
+// CHECK:           %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK:           %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_5]] : i64
+// CHECK:           %[[VAL_10:.*]] = arith.subi %[[VAL_3]], %[[VAL_1]] : i64
+// CHECK:           %[[VAL_11:.*]] = arith.select %[[VAL_8]], %[[VAL_10]], %[[VAL_1]] : i64
+// CHECK:           %[[VAL_12:.*]] = arith.select %[[VAL_9]], %[[VAL_6]], %[[VAL_11]] : i64
+// CHECK:           cf.br ^bb3(%[[VAL_2]], %[[VAL_0]], %[[VAL_12]] : f64, f64, i64)
+// CHECK:         ^bb3(%[[VAL_13:.*]]: f64, %[[VAL_14:.*]]: f64, %[[VAL_15:.*]]: i64):
+// CHECK:           %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_4]] : i64
+// CHECK:           %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_3]] : i64
+// CHECK:           cf.cond_br %[[VAL_17]], ^bb4, ^bb5(%[[VAL_13]] : f64)
+// CHECK:         ^bb4:
+// CHECK:           %[[VAL_18:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f64
+// CHECK:           cf.br ^bb5(%[[VAL_18]] : f64)
+// CHECK:         ^bb5(%[[VAL_19:.*]]: f64):
+// CHECK:           %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_4]] : i64
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_3]] : i64
+// CHECK:           cf.cond_br %[[VAL_21]], ^bb7(%[[VAL_19]] : f64), ^bb6
+// CHECK:         ^bb6:
+// CHECK:           %[[VAL_22:.*]] = arith.mulf %[[VAL_14]], %[[VAL_14]] : f64
+// CHECK:           cf.br ^bb3(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : f64, f64, i64)
+// CHECK:         ^bb7(%[[VAL_23:.*]]: f64):
+// CHECK:           cf.cond_br %[[VAL_9]], ^bb8, ^bb9(%[[VAL_23]] : f64)
+// CHECK:         ^bb8:
+// CHECK:           %[[VAL_24:.*]] = arith.mulf %[[VAL_23]], %[[VAL_0]] : f64
+// CHECK:           cf.br ^bb9(%[[VAL_24]] : f64)
+// CHECK:         ^bb9(%[[VAL_25:.*]]: f64):
+// CHECK:           cf.cond_br %[[VAL_8]], ^bb10, ^bb11(%[[VAL_25]] : f64)
+// CHECK:         ^bb10:
+// CHECK:           %[[VAL_26:.*]] = arith.divf %[[VAL_2]], %[[VAL_25]] : f64
+// CHECK:           cf.br ^bb11(%[[VAL_26]] : f64)
+// CHECK:         ^bb11(%[[VAL_27:.*]]: f64):
+// CHECK:           return %[[VAL_27]] : f64
+// CHECK:         }
+func.func @fpowi64(%arg0: f64, %arg1: i64) {
+  %0 = math.fpowi %arg0, %arg1 : f64, i64
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fpowi64_vec(
+// CHECK-SAME:                           %[[VAL_0:.*]]: vector<2x3xf32>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: vector<2x3xi64>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2x3xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<2x3xi64>
+// CHECK:           %[[VAL_5:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_3]], %[[VAL_4]]) : (f32, i64) -> f32
+// CHECK:           %[[VAL_6:.*]] = vector.insert %[[VAL_5]], %[[VAL_2]] [0, 0] : f32 into vector<2x3xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2x3xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<2x3xi64>
+// CHECK:           %[[VAL_9:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_7]], %[[VAL_8]]) : (f32, i64) -> f32
+// CHECK:           %[[VAL_10:.*]] = vector.insert %[[VAL_9]], %[[VAL_6]] [0, 1] : f32 into vector<2x3xf32>
+// CHECK:           %[[VAL_11:.*]] = vector.extract %[[VAL_0]][0, 2] : vector<2x3xf32>
+// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_1]][0, 2] : vector<2x3xi64>
+// CHECK:           %[[VAL_13:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_11]], %[[VAL_12]]) : (f32, i64) -> f32
+// CHECK:           %[[VAL_14:.*]] = vector.insert %[[VAL_13]], %[[VAL_10]] [0, 2] : f32 into vector<2x3xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.extract %[[VAL_0]][1, 0] : vector<2x3xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_1]][1, 0] : vector<2x3xi64>
+// CHECK:           %[[VAL_17:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_15]], %[[VAL_16]]) : (f32, i64) -> f32
+// CHECK:           %[[VAL_18:.*]] = vector.insert %[[VAL_17]], %[[VAL_14]] [1, 0] : f32 into vector<2x3xf32>
+// CHECK:           %[[VAL_19:.*]] = vector.extract %[[VAL_0]][1, 1] : vector<2x3xf32>
+// CHECK:           %[[VAL_20:.*]] = vector.extract %[[VAL_1]][1, 1] : vector<2x3xi64>
+// CHECK:           %[[VAL_21:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_19]], %[[VAL_20]]) : (f32, i64) -> f32
+// CHECK:           %[[VAL_22:.*]] = vector.insert %[[VAL_21]], %[[VAL_18]] [1, 1] : f32 into vector<2x3xf32>
+// CHECK:           %[[VAL_23:.*]] = vector.extract %[[VAL_0]][1, 2] : vector<2x3xf32>
+// CHECK:           %[[VAL_24:.*]] = vector.extract %[[VAL_1]][1, 2] : vector<2x3xi64>
+// CHECK:           %[[VAL_25:.*]] = call @__mlir_math_fpowi_f32_i64(%[[VAL_23]], %[[VAL_24]]) : (f32, i64) -> f32
+// CHECK:           %[[VAL_26:.*]] = vector.insert %[[VAL_25]], %[[VAL_22]] [1, 2] : f32 into vector<2x3xf32>
+// CHECK:           return
+// CHECK:         }
+
+// CHECK-LABEL:   func.func private @__mlir_math_fpowi_f32_i64(
+// CHECK-SAME:                                                 %[[VAL_0:.*]]: f32,
+// CHECK-SAME:                                                 %[[VAL_1:.*]]: i64) -> f32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : i64
+// CHECK:           %[[VAL_5:.*]] = arith.constant -9223372036854775808 : i64
+// CHECK:           %[[VAL_6:.*]] = arith.constant 9223372036854775807 : i64
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK:           cf.cond_br %[[VAL_7]], ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           return %[[VAL_2]] : f32
+// CHECK:         ^bb2:
+// CHECK:           %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK:           %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_5]] : i64
+// CHECK:           %[[VAL_10:.*]] = arith.subi %[[VAL_3]], %[[VAL_1]] : i64
+// CHECK:           %[[VAL_11:.*]] = arith.select %[[VAL_8]], %[[VAL_10]], %[[VAL_1]] : i64
+// CHECK:           %[[VAL_12:.*]] = arith.select %[[VAL_9]], %[[VAL_6]], %[[VAL_11]] : i64
+// CHECK:           cf.br ^bb3(%[[VAL_2]], %[[VAL_0]], %[[VAL_12]] : f32, f32, i64)
+// CHECK:         ^bb3(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: i64):
+// CHECK:           %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_4]] : i64
+// CHECK:           %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_3]] : i64
+// CHECK:           cf.cond_br %[[VAL_17]], ^bb4, ^bb5(%[[VAL_13]] : f32)
+// CHECK:         ^bb4:
+// CHECK:           %[[VAL_18:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK:           cf.br ^bb5(%[[VAL_18]] : f32)
+// CHECK:         ^bb5(%[[VAL_19:.*]]: f32):
+// CHECK:           %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_4]] : i64
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_3]] : i64
+// CHECK:           cf.cond_br %[[VAL_21]], ^bb7(%[[VAL_19]] : f32), ^bb6
+// CHECK:         ^bb6:
+// CHECK:           %[[VAL_22:.*]] = arith.mulf %[[VAL_14]], %[[VAL_14]] : f32
+// CHECK:           cf.br ^bb3(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : f32, f32, i64)
+// CHECK:         ^bb7(%[[VAL_23:.*]]: f32):
+// CHECK:           cf.cond_br %[[VAL_9]], ^bb8, ^bb9(%[[VAL_23]] : f32)
+// CHECK:         ^bb8:
+// CHECK:           %[[VAL_24:.*]] = arith.mulf %[[VAL_23]], %[[VAL_0]] : f32
+// CHECK:           cf.br ^bb9(%[[VAL_24]] : f32)
+// CHECK:         ^bb9(%[[VAL_25:.*]]: f32):
+// CHECK:           cf.cond_br %[[VAL_8]], ^bb10, ^bb11(%[[VAL_25]] : f32)
+// CHECK:         ^bb10:
+// CHECK:           %[[VAL_26:.*]] = arith.divf %[[VAL_2]], %[[VAL_25]] : f32
+// CHECK:           cf.br ^bb11(%[[VAL_26]] : f32)
+// CHECK:         ^bb11(%[[VAL_27:.*]]: f32):
+// CHECK:           return %[[VAL_27]] : f32
+// CHECK:         }
+func.func @fpowi64_vec(%arg0: vector<2x3xf32>, %arg1: vector<2x3xi64>) {
+  %0 = math.fpowi %arg0, %arg1 : vector<2x3xf32>, vector<2x3xi64>
+  func.return
+}

diff  --git a/mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir b/mlir/test/Conversion/MathToFuncs/ipowi.mlir
similarity index 100%
rename from mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
rename to mlir/test/Conversion/MathToFuncs/ipowi.mlir


        


More information about the flang-commits mailing list