[Mlir-commits] [mlir] 223c54c - [mlir][math] Added math::IPowI conversion to calls of outlined implementations.

Slava Zakharin llvmlistbot at llvm.org
Thu Aug 25 20:18:20 PDT 2022


Author: Slava Zakharin
Date: 2022-08-25T20:11:41-07:00
New Revision: 223c54c4bef67de92afe4c44b3a4796330eb8e5c

URL: https://github.com/llvm/llvm-project/commit/223c54c4bef67de92afe4c44b3a4796330eb8e5c
DIFF: https://github.com/llvm/llvm-project/commit/223c54c4bef67de92afe4c44b3a4796330eb8e5c.diff

LOG: [mlir][math] Added math::IPowI conversion to calls of outlined implementations.

Power functions are implemented as linkonce_odr scalar functions
for integer types used by IPowI operations met in a module.
Vector form of IPowI is linearized into a sequence of calls
of the scalar functions.

Differential Revision: https://reviews.llvm.org/D129810

Added: 
    mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
    mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
    mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
    mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h b/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
new file mode 100644
index 0000000000000..f7595002dd0a8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
@@ -0,0 +1,22 @@
+//===- MathToFuncs.h - Math to outlined impl conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
+#define MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+/// Creates a pass that replaces supported Math ops (currently math.ipowi)
+/// with calls to generated functions implementing them in software.
+std::unique_ptr<Pass> createConvertMathToFuncsPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 9f10e9459f450..5163214ffb3af 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bacbaeb511c82..87946c93bfc3a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -511,6 +511,27 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToFuncs
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
+  let summary = "Convert Math operations to calls of outlined implementations.";
+  let description = [{
+    Converts supported Math ops (currently math.ipowi; vector forms are
+    unrolled per element) to calls of compiler-generated software functions.
+    The LLVM dialect is used for LinkonceODR linkage of the generated functions.
+  }];
+  let constructor = "mlir::createConvertMathToFuncsPass()";
+  let dependentDialects = [
+    "arith::ArithmeticDialect",
+    "cf::ControlFlowDialect",
+    "func::FuncDialect",
+    "vector::VectorDialect",
+    "LLVM::LLVMDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 34488e7af9af6..d87d0ec251ff5 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -21,6 +21,7 @@ add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
+add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToSPIRV)

diff  --git a/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt b/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
new file mode 100644
index 0000000000000..72d828ade0681
--- /dev/null
+++ b/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToFuncs
+  MathToFuncs.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToFuncs
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
+  MLIRControlFlowDialect
+  MLIRFuncDialect
+  MLIRLLVMDialect
+  MLIRMathDialect
+  MLIRPass
+  MLIRTransforms
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )

diff  --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
new file mode 100644
index 0000000000000..25ee8cec62d85
--- /dev/null
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -0,0 +1,383 @@
+//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+namespace {
+// Pattern unrolling a vector-typed op into one scalar op per element.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+public:
+  using OpRewritePattern<Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+
+// Callback returning the pre-generated FuncOp implementing a power
+// operation for the given scalar type, or a null FuncOp if there is none.
+using GetPowerFuncCallbackTy = function_ref<func::FuncOp(Type)>;
+
+// Pattern to convert scalar IPowIOp into a call of outlined
+// software implementation.
+struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
+  // Non-owning: the callable behind this function_ref must outlive the pattern.
+private:
+  GetPowerFuncCallbackTy getFuncOpCallback;
+
+public:
+  IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
+      : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
+
+  /// Convert a scalar IPowI into a call to the outlined function for its
+  /// type. Vector forms do not match here; VecOpToScalarOp unrolls them
+  /// into scalar IPowI ops, which this pattern then converts.
+  LogicalResult matchAndRewrite(math::IPowIOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+  // Only fixed-size vector results are unrolled; scalar ops do not match.
+  Location loc = op.getLoc();
+  auto vecType = op.getType().template dyn_cast<VectorType>();
+  if (!vecType)
+    return rewriter.notifyMatchFailure(op, "not a vector operation");
+  // Scalable vectors have no compile-time element count to unroll over.
+  if (vecType.isScalable())
+    return rewriter.notifyMatchFailure(op, "scalable vectors are unsupported");
+  ArrayRef<int64_t> shape = vecType.getShape();
+  Type elementType = vecType.getElementType();
+  int64_t numElements = vecType.getNumElements();
+  // Start from an all-zero vector and insert every scalar result into it.
+  Value result =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
+  SmallVector<int64_t> ones(shape.size(), 1);
+  SmallVector<int64_t> strides = computeStrides(shape, ones);
+  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+    // Extract the operands at this position and apply the scalar op.
+    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
+    SmallVector<Value> operands;
+    for (Value input : op->getOperands())
+      operands.push_back(
+          rewriter.create<vector::ExtractOp>(loc, input, positions));
+    Value scalarOp = rewriter.create<Op>(loc, elementType, operands);
+    result =
+        rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+/// Create a linkonce_odr function inside \p module implementing the integer
+/// power operation for \p elementType, which must be an IntegerType. The
+/// generated function has type (elementType, elementType) -> elementType and
+/// behaves like this C++ template (name suffix is the type, e.g. _i64):
+/// template <typename T>
+/// T __mlir_math_ipowi_*(T b, T p) {
+///   if (p == T(0))
+///     return T(1);
+///   if (p < T(0)) {
+///     if (b == T(0))
+///       return T(1) / T(0); // trigger div-by-zero
+///     if (b == T(1))
+///       return T(1);
+///     if (b == T(-1)) {
+///       if (p & T(1))
+///         return T(-1);
+///       return T(1);
+///     }
+///     return T(0);
+///   }
+///   T result = T(1);
+///   while (true) {
+///     if (p & T(1))
+///       result *= b;
+///     p >>= T(1);
+///     if (p == T(0))
+///       return result;
+///     b *= b;
+///   }
+/// }
+static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
+  assert(elementType.isa<IntegerType>() &&
+         "non-integer element type for IPowIOp");
+
+  // Append the generated function at the end of the module body.
+  ImplicitLocOpBuilder builder =
+      ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
+
+  std::string funcName("__mlir_math_ipowi");
+  llvm::raw_string_ostream nameOS(funcName);
+  nameOS << '_' << elementType;
+  // NOTE(review): an existing symbol with this name is not checked for.
+  FunctionType funcType = FunctionType::get(
+      builder.getContext(), {elementType, elementType}, elementType);
+  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
+  Attribute linkage =
+      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+  funcOp->setAttr("llvm.linkage", linkage);
+  funcOp.setPrivate();
+
+  Block *entryBlock = funcOp.addEntryBlock();
+  Region *funcBody = entryBlock->getParent();
+
+  Value bArg = funcOp.getArgument(0);
+  Value pArg = funcOp.getArgument(1);
+  builder.setInsertionPointToEnd(entryBlock);
+  Value zeroValue = builder.create<arith::ConstantOp>(
+      elementType, builder.getIntegerAttr(elementType, 0));
+  Value oneValue = builder.create<arith::ConstantOp>(
+      elementType, builder.getIntegerAttr(elementType, 1));
+  Value minusOneValue = builder.create<arith::ConstantOp>(
+      elementType,
+      builder.getIntegerAttr(elementType,
+                             APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
+                                   /*isSigned=*/true)));
+
+  // if (p == T(0))
+  //   return T(1);
+  auto pIsZero =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
+  Block *thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(oneValue);
+  Block *fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (p == T(0)).
+  builder.setInsertionPointToEnd(pIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+
+  // if (p < T(0)) {  ('sle' is equivalent: p == 0 was handled above)
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  auto pIsNeg =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
+  //   if (b == T(0))
+  builder.createBlock(funcBody);
+  auto bIsZero =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
+  //     return T(1) / T(0);
+  thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(
+      builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (b == T(0)).
+  builder.setInsertionPointToEnd(bIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
+
+  //   if (b == T(1))
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  auto bIsOne =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
+  //    return T(1);
+  thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(oneValue);
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (b == T(1)).
+  builder.setInsertionPointToEnd(bIsOne->getBlock());
+  builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
+
+  //   if (b == T(-1)) {
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                   bArg, minusOneValue);
+  //     if (p & T(1))
+  builder.createBlock(funcBody);
+  auto pIsOdd = builder.create<arith::CmpIOp>(
+      arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
+      zeroValue);
+  //       return T(-1);
+  thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(minusOneValue);
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (p & T(1)).
+  builder.setInsertionPointToEnd(pIsOdd->getBlock());
+  builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
+
+  //     return T(1);
+  //   } // b == T(-1)
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  builder.create<func::ReturnOp>(oneValue);
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (b == T(-1)).
+  builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
+  builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
+                                   fallthroughBlock);
+
+  //   return T(0);
+  // } // (p < T(0))
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  builder.create<func::ReturnOp>(zeroValue);
+  Block *loopHeader = builder.createBlock(
+      funcBody, funcBody->end(), {elementType, elementType, elementType},
+      {builder.getLoc(), builder.getLoc(), builder.getLoc()});
+  // Set up conditional branch for (p < T(0)).
+  builder.setInsertionPointToEnd(pIsNeg->getBlock());
+  // Set initial values of 'result', 'b' and 'p' for the loop.
+  builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
+                                   ValueRange{oneValue, bArg, pArg});
+
+  // T result = T(1);
+  // while (true) {
+  //   if (p & T(1))
+  //     result *= b;
+  //   p >>= T(1);
+  //   if (p == T(0))
+  //     return result;
+  //   b *= b;
+  // }
+  Value resultTmp = loopHeader->getArgument(0);
+  Value baseTmp = loopHeader->getArgument(1);
+  Value powerTmp = loopHeader->getArgument(2);
+  builder.setInsertionPointToEnd(loopHeader);
+
+  //   if (p & T(1))
+  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
+      arith::CmpIPredicate::ne,
+      builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
+  thenBlock = builder.createBlock(funcBody);
+  //     result *= b;
+  Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
+  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
+                                         builder.getLoc());
+  builder.setInsertionPointToEnd(thenBlock);
+  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+  // Set up conditional branch for (p & T(1)).
+  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
+  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
+                                   resultTmp);
+  // Merged 'result' (block argument of the join block).
+  newResultTmp = fallthroughBlock->getArgument(0);
+
+  //   p >>= T(1);  (shrui is fine: p is positive inside the loop)
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
+
+  //   if (p == T(0))
+  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                      newPowerTmp, zeroValue);
+  //     return result;
+  thenBlock = builder.createBlock(funcBody);
+  builder.create<func::ReturnOp>(newResultTmp);
+  fallthroughBlock = builder.createBlock(funcBody);
+  // Set up conditional branch for (p == T(0)).
+  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
+  builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
+
+  //   b *= b;
+  // }
+  builder.setInsertionPointToEnd(fallthroughBlock);
+  Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
+  // Pass new values for 'result', 'b' and 'p' to the loop header.
+  builder.create<cf::BranchOp>(
+      ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+  return funcOp;
+}
+
+/// Replace a scalar IPowI with a call to the outlined function generated for
+/// its integer type. Vector forms fail to match here (their type is not an
+/// IntegerType); VecOpToScalarOp unrolls them into scalar IPowI ops instead.
+LogicalResult
+IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
+                                 PatternRewriter &rewriter) const {
+  auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>();
+
+  if (!baseType)
+    return rewriter.notifyMatchFailure(op, "non-integer base operand");
+
+  // Implementations are generated before rewriting starts; a null FuncOp
+  // means none exists for this type.
+  func::FuncOp elementFunc = getFuncOpCallback(baseType);
+  if (!elementFunc)
+    return rewriter.notifyMatchFailure(op, "missing software implementation");
+
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
+  return success();
+}
+
+namespace {
+struct ConvertMathToFuncsPass
+    : public ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
+  ConvertMathToFuncsPass() = default;
+
+  void runOnOperation() override;
+
+private:
+  // Generate outlined implementations for power operations
+  // and store them in powerFuncs map.
+  void preprocessPowOperations();
+
+  // Maps the scalar (element) type of each power operation to the
+  // outlined software implementation generated for that type. Filled by
+  // preprocessPowOperations() before any rewriting happens.
+  DenseMap<Type, func::FuncOp> powerFuncs;
+};
+} // namespace
+
+void ConvertMathToFuncsPass::preprocessPowOperations() {
+  ModuleOp module = getOperation();
+
+  // Generate one outlined implementation per distinct element type used by
+  // an IPowI operation. Functions are appended to the module body; if this
+  // walk visits them, that is harmless since they contain no IPowI ops.
+  module.walk([&](math::IPowIOp ipowiOp) {
+    Type resultType = getElementTypeOrSelf(ipowiOp.getResult().getType());
+
+    // Only create the function the first time this type is seen.
+    auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
+    if (entry.second)
+      entry.first->second = createElementIPowIFunc(&module, resultType);
+  });
+}
+
+void ConvertMathToFuncsPass::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  // Create outlined implementations for power operations.
+  preprocessPowOperations();
+
+  RewritePatternSet patterns(&getContext());
+  patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext());
+
+  // Returns the FuncOp stored in powerFuncs for the given type, or a null
+  // FuncOp if none was generated. IPowIOpLowering holds only a non-owning
+  // function_ref to this lambda; that is safe because the patterns do not
+  // outlive this function.
+  auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp {
+    return powerFuncs.lookup(type);
+  };
+  patterns.add<IPowIOpLowering>(patterns.getContext(), getPowerFuncOpByType);
+
+  // Only IPowI must be converted; other Math ops are left as they are.
+  ConversionTarget target(getContext());
+  target.addLegalDialect<arith::ArithmeticDialect, cf::ControlFlowDialect,
+                         func::FuncDialect, vector::VectorDialect>();
+  target.addIllegalOp<math::IPowIOp>();
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createConvertMathToFuncsPass() {
+  return std::make_unique<ConvertMathToFuncsPass>();
+}

diff  --git a/mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir b/mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
new file mode 100644
index 0000000000000..af0f49254a055
--- /dev/null
+++ b/mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
@@ -0,0 +1,172 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="convert-math-to-funcs" | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @ipowi(
+// CHECK-SAME: %[[ARG0:.+]]: i64,
+// CHECK-SAME: %[[ARG1:.+]]: i64)
+func.func @ipowi(%arg0: i64, %arg1: i64) {
+  // CHECK: call @__mlir_math_ipowi_i64(%[[ARG0]], %[[ARG1]]) : (i64, i64) -> i64
+  %0 = math.ipowi %arg0, %arg1 : i64
+  func.return
+}
+
+// CHECK-LABEL:   func.func private @__mlir_math_ipowi_i64(
+// CHECK-SAME:      %[[VAL_0:.*]]: i64,
+// CHECK-SAME:      %[[VAL_1:.*]]: i64) -> i64
+// CHECK-SAME:        attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : i64
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i64
+// CHECK:           %[[VAL_4:.*]] = arith.constant -1 : i64
+// CHECK:           %[[VAL_5:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK:           cf.cond_br %[[VAL_5]], ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           return %[[VAL_3]] : i64
+// CHECK:         ^bb2:
+// CHECK:           %[[VAL_6:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK:           cf.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i64, i64, i64)
+// CHECK:         ^bb3:
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_2]] : i64
+// CHECK:           cf.cond_br %[[VAL_7]], ^bb4, ^bb5
+// CHECK:         ^bb4:
+// CHECK:           %[[VAL_8:.*]] = arith.divsi %[[VAL_3]], %[[VAL_2]]  : i64
+// CHECK:           return %[[VAL_8]] : i64
+// CHECK:         ^bb5:
+// CHECK:           %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_3]] : i64
+// CHECK:           cf.cond_br %[[VAL_9]], ^bb6, ^bb7
+// CHECK:         ^bb6:
+// CHECK:           return %[[VAL_3]] : i64
+// CHECK:         ^bb7:
+// CHECK:           %[[VAL_10:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_4]] : i64
+// CHECK:           cf.cond_br %[[VAL_10]], ^bb8, ^bb11
+// CHECK:         ^bb8:
+// CHECK:           %[[VAL_11:.*]] = arith.andi %[[VAL_1]], %[[VAL_3]]  : i64
+// CHECK:           %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_2]] : i64
+// CHECK:           cf.cond_br %[[VAL_12]], ^bb9, ^bb10
+// CHECK:         ^bb9:
+// CHECK:           return %[[VAL_4]] : i64
+// CHECK:         ^bb10:
+// CHECK:           return %[[VAL_3]] : i64
+// CHECK:         ^bb11:
+// CHECK:           return %[[VAL_2]] : i64
+// CHECK:         ^bb12(%[[VAL_13:.*]]: i64, %[[VAL_14:.*]]: i64, %[[VAL_15:.*]]: i64):
+// CHECK:           %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_3]]  : i64
+// CHECK:           %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_2]] : i64
+// CHECK:           cf.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i64)
+// CHECK:         ^bb13:
+// CHECK:           %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_14]]  : i64
+// CHECK:           cf.br ^bb14(%[[VAL_18]] : i64)
+// CHECK:         ^bb14(%[[VAL_19:.*]]: i64):
+// CHECK:           %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_3]]  : i64
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_2]] : i64
+// CHECK:           cf.cond_br %[[VAL_21]], ^bb15, ^bb16
+// CHECK:         ^bb15:
+// CHECK:           return %[[VAL_19]] : i64
+// CHECK:         ^bb16:
+// CHECK:           %[[VAL_22:.*]] = arith.muli %[[VAL_14]], %[[VAL_14]]  : i64
+// CHECK:           cf.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i64, i64, i64)
+// CHECK:         }
+
+// -----
+
+// CHECK-LABEL: func @ipowi(
+// CHECK-SAME: %[[ARG0:.+]]: i8,
+// CHECK-SAME: %[[ARG1:.+]]: i8)
+  // CHECK: call @__mlir_math_ipowi_i8(%[[ARG0]], %[[ARG1]]) : (i8, i8) -> i8
+func.func @ipowi(%arg0: i8, %arg1: i8) {
+  %0 = math.ipowi %arg0, %arg1 : i8
+  func.return
+}
+
+// CHECK-LABEL:   func.func private @__mlir_math_ipowi_i8(
+// CHECK-SAME:      %[[VAL_0:.*]]: i8,
+// CHECK-SAME:      %[[VAL_1:.*]]: i8) -> i8
+// CHECK-SAME:        attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i8
+// CHECK:           %[[VAL_4:.*]] = arith.constant -1 : i8
+// CHECK:           %[[VAL_5:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : i8
+// CHECK:           cf.cond_br %[[VAL_5]], ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           return %[[VAL_3]] : i8
+// CHECK:         ^bb2:
+// CHECK:           %[[VAL_6:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_2]] : i8
+// CHECK:           cf.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i8, i8, i8)
+// CHECK:         ^bb3:
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_2]] : i8
+// CHECK:           cf.cond_br %[[VAL_7]], ^bb4, ^bb5
+// CHECK:         ^bb4:
+// CHECK:           %[[VAL_8:.*]] = arith.divsi %[[VAL_3]], %[[VAL_2]]  : i8
+// CHECK:           return %[[VAL_8]] : i8
+// CHECK:         ^bb5:
+// CHECK:           %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_3]] : i8
+// CHECK:           cf.cond_br %[[VAL_9]], ^bb6, ^bb7
+// CHECK:         ^bb6:
+// CHECK:           return %[[VAL_3]] : i8
+// CHECK:         ^bb7:
+// CHECK:           %[[VAL_10:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_4]] : i8
+// CHECK:           cf.cond_br %[[VAL_10]], ^bb8, ^bb11
+// CHECK:         ^bb8:
+// CHECK:           %[[VAL_11:.*]] = arith.andi %[[VAL_1]], %[[VAL_3]]  : i8
+// CHECK:           %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_2]] : i8
+// CHECK:           cf.cond_br %[[VAL_12]], ^bb9, ^bb10
+// CHECK:         ^bb9:
+// CHECK:           return %[[VAL_4]] : i8
+// CHECK:         ^bb10:
+// CHECK:           return %[[VAL_3]] : i8
+// CHECK:         ^bb11:
+// CHECK:           return %[[VAL_2]] : i8
+// CHECK:         ^bb12(%[[VAL_13:.*]]: i8, %[[VAL_14:.*]]: i8, %[[VAL_15:.*]]: i8):
+// CHECK:           %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_3]]  : i8
+// CHECK:           %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_2]] : i8
+// CHECK:           cf.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i8)
+// CHECK:         ^bb13:
+// CHECK:           %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_14]]  : i8
+// CHECK:           cf.br ^bb14(%[[VAL_18]] : i8)
+// CHECK:         ^bb14(%[[VAL_19:.*]]: i8):
+// CHECK:           %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_3]]  : i8
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_2]] : i8
+// CHECK:           cf.cond_br %[[VAL_21]], ^bb15, ^bb16
+// CHECK:         ^bb15:
+// CHECK:           return %[[VAL_19]] : i8
+// CHECK:         ^bb16:
+// CHECK:           %[[VAL_22:.*]] = arith.muli %[[VAL_14]], %[[VAL_14]]  : i8
+// CHECK:           cf.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i8, i8, i8)
+// CHECK:         }
+
+// -----
+
+// CHECK-LABEL:   func.func @ipowi_vec(
+// CHECK-SAME:                          %[[VAL_0:.*]]: vector<2x3xi64>,
+// CHECK-SAME:                          %[[VAL_1:.*]]: vector<2x3xi64>) {
+func.func @ipowi_vec(%arg0: vector<2x3xi64>, %arg1: vector<2x3xi64>) {
+// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<2x3xi64>
+// CHECK:   %[[B00:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2x3xi64>
+// CHECK:   %[[E00:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<2x3xi64>
+// CHECK:   %[[R00:.*]] = call @__mlir_math_ipowi_i64(%[[B00]], %[[E00]]) : (i64, i64) -> i64
+// CHECK:   %[[TMP00:.*]] = vector.insert %[[R00]], %[[CST]] [0, 0] : i64 into vector<2x3xi64>
+// CHECK:   %[[B01:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2x3xi64>
+// CHECK:   %[[E01:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<2x3xi64>
+// CHECK:   %[[R01:.*]] = call @__mlir_math_ipowi_i64(%[[B01]], %[[E01]]) : (i64, i64) -> i64
+// CHECK:   %[[TMP01:.*]] = vector.insert %[[R01]], %[[TMP00]] [0, 1] : i64 into vector<2x3xi64>
+// CHECK:   %[[B02:.*]] = vector.extract %[[VAL_0]][0, 2] : vector<2x3xi64>
+// CHECK:   %[[E02:.*]] = vector.extract %[[VAL_1]][0, 2] : vector<2x3xi64>
+// CHECK:   %[[R02:.*]] = call @__mlir_math_ipowi_i64(%[[B02]], %[[E02]]) : (i64, i64) -> i64
+// CHECK:   %[[TMP02:.*]] = vector.insert %[[R02]], %[[TMP01]] [0, 2] : i64 into vector<2x3xi64>
+// CHECK:   %[[B10:.*]] = vector.extract %[[VAL_0]][1, 0] : vector<2x3xi64>
+// CHECK:   %[[E10:.*]] = vector.extract %[[VAL_1]][1, 0] : vector<2x3xi64>
+// CHECK:   %[[R10:.*]] = call @__mlir_math_ipowi_i64(%[[B10]], %[[E10]]) : (i64, i64) -> i64
+// CHECK:   %[[TMP10:.*]] = vector.insert %[[R10]], %[[TMP02]] [1, 0] : i64 into vector<2x3xi64>
+// CHECK:   %[[B11:.*]] = vector.extract %[[VAL_0]][1, 1] : vector<2x3xi64>
+// CHECK:   %[[E11:.*]] = vector.extract %[[VAL_1]][1, 1] : vector<2x3xi64>
+// CHECK:   %[[R11:.*]] = call @__mlir_math_ipowi_i64(%[[B11]], %[[E11]]) : (i64, i64) -> i64
+// CHECK:   %[[TMP11:.*]] = vector.insert %[[R11]], %[[TMP10]] [1, 1] : i64 into vector<2x3xi64>
+// CHECK:   %[[B12:.*]] = vector.extract %[[VAL_0]][1, 2] : vector<2x3xi64>
+// CHECK:   %[[E12:.*]] = vector.extract %[[VAL_1]][1, 2] : vector<2x3xi64>
+// CHECK:   %[[R12:.*]] = call @__mlir_math_ipowi_i64(%[[B12]], %[[E12]]) : (i64, i64) -> i64
+// CHECK:   %[[TMP12:.*]] = vector.insert %[[R12]], %[[TMP11]] [1, 2] : i64 into vector<2x3xi64>
+// CHECK:   return
+// CHECK: }
+  %0 = math.ipowi %arg0, %arg1 : vector<2x3xi64>
+  func.return
+}


        


More information about the Mlir-commits mailing list