[Mlir-commits] [mlir] 223c54c - [mlir][math] Added math::IPowI conversion to calls of outlined implementations.
Slava Zakharin
llvmlistbot at llvm.org
Thu Aug 25 20:18:20 PDT 2022
Author: Slava Zakharin
Date: 2022-08-25T20:11:41-07:00
New Revision: 223c54c4bef67de92afe4c44b3a4796330eb8e5c
URL: https://github.com/llvm/llvm-project/commit/223c54c4bef67de92afe4c44b3a4796330eb8e5c
DIFF: https://github.com/llvm/llvm-project/commit/223c54c4bef67de92afe4c44b3a4796330eb8e5c.diff
LOG: [mlir][math] Added math::IPowI conversion to calls of outlined implementations.
Power functions are implemented as linkonce_odr scalar functions,
one for each integer type used by IPowI operations found in a module.
The vector form of IPowI is linearized into a sequence of calls
to the scalar functions.
Differential Revision: https://reviews.llvm.org/D129810
Added:
mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h b/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
new file mode 100644
index 0000000000000..f7595002dd0a8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToFuncs/MathToFuncs.h
@@ -0,0 +1,22 @@
+//===- MathToFuncs.h - Math to outlined impl conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
+#define MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+// Creates a pass that converts supported Math operations into calls of
+// compiler-generated functions implementing these operations in software.
+std::unique_ptr<Pass> createConvertMathToFuncsPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOFUNCS_MATHTOFUNCS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 9f10e9459f450..5163214ffb3af 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -32,6 +32,7 @@
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bacbaeb511c82..87946c93bfc3a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -511,6 +511,27 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}
+//===----------------------------------------------------------------------===//
+// MathToFuncs
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
+ let summary = "Convert Math operations to calls of outlined implementations.";
+ let description = [{
+ This pass converts supported Math ops to calls of compiler generated
+ functions implementing these operations in software.
+ The LLVM dialect is used for LinkonceODR linkage of the generated functions.
+ }];
+ let constructor = "mlir::createConvertMathToFuncsPass()";
+ let dependentDialects = [
+ "arith::ArithmeticDialect",
+ "cf::ControlFlowDialect",
+ "func::FuncDialect",
+ "vector::VectorDialect",
+ "LLVM::LLVMDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// MemRefToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 34488e7af9af6..d87d0ec251ff5 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -21,6 +21,7 @@ add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
+add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToSPIRV)
diff --git a/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt b/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
new file mode 100644
index 0000000000000..72d828ade0681
--- /dev/null
+++ b/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToFuncs
+ MathToFuncs.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToFuncs
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithmeticDialect
+ MLIRControlFlowDialect
+ MLIRFuncDialect
+ MLIRLLVMDialect
+ MLIRMathDialect
+ MLIRPass
+ MLIRTransforms
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
new file mode 100644
index 0000000000000..25ee8cec62d85
--- /dev/null
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -0,0 +1,383 @@
+//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+namespace {
+// Pattern that unrolls a vector-typed Op into one scalar Op per element.
+template <typename Op>
+struct VecOpToScalarOp : public OpRewritePattern<Op> {
+public:
+ using OpRewritePattern<Op>::OpRewritePattern;
+ // Fails to match when the result type of `op` is not a VectorType.
+ LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+};
+
+// Callback type for getting pre-generated FuncOp implementing
+// a power operation of the given type.
+using GetPowerFuncCallbackTy = function_ref<func::FuncOp(Type)>;
+
+// Pattern to convert a scalar IPowIOp into a call of the outlined
+// software implementation returned by the callback for its integer type.
+struct IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
+ // NOTE(review): getFuncOpCallback is non-owning; `cb` must outlive the pattern.
+private:
+ GetPowerFuncCallbackTy getFuncOpCallback;
+
+public:
+ IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
+ : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
+
+ /// Convert IPowI into a call to a local function implementing
+ /// the power operation. The local function computes a scalar result,
+ /// so only IPowI ops with an integer (non-vector) base are matched.
+ LogicalResult matchAndRewrite(math::IPowIOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <typename Op>
+LogicalResult
+VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+ Type opType = op.getType();
+ Location loc = op.getLoc();
+ auto vecType = opType.template dyn_cast<VectorType>();
+ // Only operations producing a vector are unrolled by this pattern.
+ if (!vecType)
+ return rewriter.notifyMatchFailure(op, "not a vector operation");
+ if (!vecType.hasRank())
+ return rewriter.notifyMatchFailure(op, "unknown vector rank");
+ ArrayRef<int64_t> shape = vecType.getShape();
+ int64_t numElements = vecType.getNumElements();
+ // Start from an all-zero integer vector; each element is overwritten below.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(
+ vecType, IntegerAttr::get(vecType.getElementType(), 0)));
+ SmallVector<int64_t> ones(shape.size(), 1);
+ SmallVector<int64_t> strides = computeStrides(shape, ones);
+ for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+ SmallVector<int64_t> positions = delinearize(strides, linearIndex);
+ SmallVector<Value> operands;
+ for (Value input : op->getOperands())
+ operands.push_back(
+ rewriter.create<vector::ExtractOp>(loc, input, positions));
+ Value scalarOp =
+ rewriter.create<Op>(loc, vecType.getElementType(), operands);
+ result =
+ rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
+/// Create linkonce_odr function to implement the power function for
+/// the given \p elementType inside \p module. \p elementType must be an
+/// IntegerType 'T'; the created function has type 'T (*)(T, T)'.
+///
+/// template <typename T>
+/// T __mlir_math_ipowi_*(T b, T p) {
+/// if (p == T(0))
+/// return T(1);
+/// if (p < T(0)) {
+/// if (b == T(0))
+/// return T(1) / T(0); // trigger div-by-zero
+/// if (b == T(1))
+/// return T(1);
+/// if (b == T(-1)) {
+/// if (p & T(1))
+/// return T(-1);
+/// return T(1);
+/// }
+/// return T(0);
+/// }
+/// T result = T(1);
+/// while (true) {
+/// if (p & T(1))
+/// result *= b;
+/// p >>= T(1);
+/// if (p == T(0))
+/// return result;
+/// b *= b;
+/// }
+/// }
+static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
+ assert(elementType.isa<IntegerType>() &&
+ "non-integer element type for IPowIOp");
+
+ // The new function is appended at the end of the module body.
+ ImplicitLocOpBuilder builder =
+ ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
+
+ std::string funcName("__mlir_math_ipowi");
+ llvm::raw_string_ostream nameOS(funcName);
+ nameOS << '_' << elementType;
+
+ FunctionType funcType = FunctionType::get(
+ builder.getContext(), {elementType, elementType}, elementType);
+ auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+ LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
+ Attribute linkage =
+ LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+ funcOp->setAttr("llvm.linkage", linkage);
+ funcOp.setPrivate();
+
+ Block *entryBlock = funcOp.addEntryBlock();
+ Region *funcBody = entryBlock->getParent();
+
+ Value bArg = funcOp.getArgument(0);
+ Value pArg = funcOp.getArgument(1);
+ builder.setInsertionPointToEnd(entryBlock);
+ Value zeroValue = builder.create<arith::ConstantOp>(
+ elementType, builder.getIntegerAttr(elementType, 0));
+ Value oneValue = builder.create<arith::ConstantOp>(
+ elementType, builder.getIntegerAttr(elementType, 1));
+ Value minusOneValue = builder.create<arith::ConstantOp>(
+ elementType,
+ builder.getIntegerAttr(elementType,
+ APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
+ /*isSigned=*/true)));
+
+ // if (p == T(0))
+ // return T(1);
+ auto pIsZero =
+ builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
+ Block *thenBlock = builder.createBlock(funcBody);
+ builder.create<func::ReturnOp>(oneValue);
+ Block *fallthroughBlock = builder.createBlock(funcBody);
+ // Set up conditional branch for (p == T(0)).
+ builder.setInsertionPointToEnd(pIsZero->getBlock());
+ builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
+
+ // if (p < T(0)) { -- 'sle' is equivalent: p == T(0) was handled above.
+ builder.setInsertionPointToEnd(fallthroughBlock);
+ auto pIsNeg =
+ builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
+ // if (b == T(0))
+ builder.createBlock(funcBody);
+ auto bIsZero =
+ builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
+ // return T(1) / T(0);
+ thenBlock = builder.createBlock(funcBody);
+ builder.create<func::ReturnOp>(
+ builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
+ fallthroughBlock = builder.createBlock(funcBody);
+ // Set up conditional branch for (b == T(0)).
+ builder.setInsertionPointToEnd(bIsZero->getBlock());
+ builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
+
+ // if (b == T(1))
+ builder.setInsertionPointToEnd(fallthroughBlock);
+ auto bIsOne =
+ builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
+ // return T(1);
+ thenBlock = builder.createBlock(funcBody);
+ builder.create<func::ReturnOp>(oneValue);
+ fallthroughBlock = builder.createBlock(funcBody);
+ // Set up conditional branch for (b == T(1)).
+ builder.setInsertionPointToEnd(bIsOne->getBlock());
+ builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
+
+ // if (b == T(-1)) {
+ builder.setInsertionPointToEnd(fallthroughBlock);
+ auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+ bArg, minusOneValue);
+ // if (p & T(1))
+ builder.createBlock(funcBody);
+ auto pIsOdd = builder.create<arith::CmpIOp>(
+ arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
+ zeroValue);
+ // return T(-1);
+ thenBlock = builder.createBlock(funcBody);
+ builder.create<func::ReturnOp>(minusOneValue);
+ fallthroughBlock = builder.createBlock(funcBody);
+ // Set up conditional branch for (p & T(1)).
+ builder.setInsertionPointToEnd(pIsOdd->getBlock());
+ builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
+
+ // return T(1);
+ // } // b == T(-1)
+ builder.setInsertionPointToEnd(fallthroughBlock);
+ builder.create<func::ReturnOp>(oneValue);
+ fallthroughBlock = builder.createBlock(funcBody);
+ // Set up conditional branch for (b == T(-1)).
+ builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
+ builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
+ fallthroughBlock);
+
+ // return T(0);
+ // } // (p < T(0))
+ builder.setInsertionPointToEnd(fallthroughBlock);
+ builder.create<func::ReturnOp>(zeroValue);
+ Block *loopHeader = builder.createBlock(
+ funcBody, funcBody->end(), {elementType, elementType, elementType},
+ {builder.getLoc(), builder.getLoc(), builder.getLoc()});
+ // Set up conditional branch for (p < T(0)).
+ builder.setInsertionPointToEnd(pIsNeg->getBlock());
+ // Set initial values of 'result', 'b' and 'p' for the loop.
+ builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
+ ValueRange{oneValue, bArg, pArg});
+
+ // T result = T(1);
+ // while (true) {
+ // if (p & T(1))
+ // result *= b;
+ // p >>= T(1);
+ // if (p == T(0))
+ // return result;
+ // b *= b;
+ // }
+ Value resultTmp = loopHeader->getArgument(0);
+ Value baseTmp = loopHeader->getArgument(1);
+ Value powerTmp = loopHeader->getArgument(2);
+ builder.setInsertionPointToEnd(loopHeader);
+
+ // if (p & T(1))
+ auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
+ arith::CmpIPredicate::ne,
+ builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
+ thenBlock = builder.createBlock(funcBody);
+ // result *= b;
+ Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
+ fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
+ builder.getLoc());
+ builder.setInsertionPointToEnd(thenBlock);
+ builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
+ // Set up conditional branch for (p & T(1)).
+ builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
+ builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
+ resultTmp);
+ // Merged 'result'.
+ newResultTmp = fallthroughBlock->getArgument(0);
+
+ // p >>= T(1); -- logical shift; p > T(0) throughout the loop.
+ builder.setInsertionPointToEnd(fallthroughBlock);
+ Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
+
+ // if (p == T(0))
+ auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+ newPowerTmp, zeroValue);
+ // return result;
+ thenBlock = builder.createBlock(funcBody);
+ builder.create<func::ReturnOp>(newResultTmp);
+ fallthroughBlock = builder.createBlock(funcBody);
+ // Set up conditional branch for (p == T(0)).
+ builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
+ builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
+
+ // b *= b;
+ // }
+ builder.setInsertionPointToEnd(fallthroughBlock);
+ Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
+ // Pass new values for 'result', 'b' and 'p' to the loop header.
+ builder.create<cf::BranchOp>(
+ ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
+ return funcOp;
+}
+
+/// Convert a scalar integer IPowI into a call to a local function
+/// implementing the power operation. Vector forms fail to match here;
+/// VecOpToScalarOp unrolls those into scalar IPowI ops that match it.
+LogicalResult
+IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
+ PatternRewriter &rewriter) const {
+ auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>();
+
+ if (!baseType)
+ return rewriter.notifyMatchFailure(op, "non-integer base operand");
+
+ // The outlined software implementation must have been already
+ // generated.
+ func::FuncOp elementFunc = getFuncOpCallback(baseType);
+ if (!elementFunc)
+ return rewriter.notifyMatchFailure(op, "missing software implementation");
+
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
+ return success();
+}
+
+namespace {
+struct ConvertMathToFuncsPass
+ : public ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
+ ConvertMathToFuncsPass() = default;
+
+ void runOnOperation() override;
+
+private:
+ // Generate outlined implementations for power operations
+ // and store them in the powerFuncs map.
+ void preprocessPowOperations();
+
+ // A map between element types deduced from power operations
+ // and the corresponding outlined software implementations
+ // of these operations.
+ DenseMap<Type, func::FuncOp> powerFuncs;
+};
+} // namespace
+
+void ConvertMathToFuncsPass::preprocessPowOperations() {
+ ModuleOp module = getOperation();
+ // Visit every IPowIOp in the module; vector ops are keyed by element type.
+ module.walk([&](Operation *op) {
+ TypeSwitch<Operation *>(op).Case<math::IPowIOp>([&](math::IPowIOp op) {
+ Type resultType = getElementTypeOrSelf(op.getResult().getType());
+
+ // Generate the software implementation for this element type,
+ // if it has not been generated yet (entry.second: newly inserted).
+ auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
+ if (entry.second)
+ entry.first->second = createElementIPowIFunc(&module, resultType);
+ });
+ });
+}
+
+void ConvertMathToFuncsPass::runOnOperation() {
+ ModuleOp module = getOperation();
+
+ // Create outlined implementations before rewriting any power operation.
+ preprocessPowOperations();
+ // Vector IPowI ops are unrolled into scalar ones by VecOpToScalarOp.
+ RewritePatternSet patterns(&getContext());
+ patterns.add<VecOpToScalarOp<math::IPowIOp>>(patterns.getContext());
+
+ // Returns the FuncOp stored in powerFuncs for `type`, or null if absent.
+ auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp {
+ auto it = powerFuncs.find(type);
+ if (it == powerFuncs.end())
+ return {};
+
+ return it->second;
+ };
+ patterns.add<IPowIOpLowering>(patterns.getContext(), getPowerFuncOpByType);
+ // Every math.ipowi must be converted; ops of the listed dialects are legal.
+ ConversionTarget target(getContext());
+ target.addLegalDialect<arith::ArithmeticDialect, cf::ControlFlowDialect,
+ func::FuncDialect, vector::VectorDialect>();
+ target.addIllegalOp<math::IPowIOp>();
+ if (failed(applyPartialConversion(module, target, std::move(patterns))))
+ signalPassFailure();
+}
+// Factory for the pass; named as the pass `constructor` in Passes.td.
+std::unique_ptr<Pass> mlir::createConvertMathToFuncsPass() {
+ return std::make_unique<ConvertMathToFuncsPass>();
+}
diff --git a/mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir b/mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
new file mode 100644
index 0000000000000..af0f49254a055
--- /dev/null
+++ b/mlir/test/Conversion/MathToFuncs/math-to-funcs.mlir
@@ -0,0 +1,172 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="convert-math-to-funcs" | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @ipowi(
+// CHECK-SAME: %[[ARG0:.+]]: i64,
+// CHECK-SAME: %[[ARG1:.+]]: i64)
+func.func @ipowi(%arg0: i64, %arg1: i64) {
+ // CHECK: call @__mlir_math_ipowi_i64(%[[ARG0]], %[[ARG1]]) : (i64, i64) -> i64
+ %0 = math.ipowi %arg0, %arg1 : i64
+ func.return
+}
+
+// CHECK-LABEL: func.func private @__mlir_math_ipowi_i64(
+// CHECK-SAME: %[[VAL_0:.*]]: i64,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64
+// CHECK-SAME: attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_4:.*]] = arith.constant -1 : i64
+// CHECK: %[[VAL_5:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_5]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[VAL_3]] : i64
+// CHECK: ^bb2:
+// CHECK: %[[VAL_6:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i64, i64, i64)
+// CHECK: ^bb3:
+// CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_7]], ^bb4, ^bb5
+// CHECK: ^bb4:
+// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_3]], %[[VAL_2]] : i64
+// CHECK: return %[[VAL_8]] : i64
+// CHECK: ^bb5:
+// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_3]] : i64
+// CHECK: cf.cond_br %[[VAL_9]], ^bb6, ^bb7
+// CHECK: ^bb6:
+// CHECK: return %[[VAL_3]] : i64
+// CHECK: ^bb7:
+// CHECK: %[[VAL_10:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_4]] : i64
+// CHECK: cf.cond_br %[[VAL_10]], ^bb8, ^bb11
+// CHECK: ^bb8:
+// CHECK: %[[VAL_11:.*]] = arith.andi %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK: %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_12]], ^bb9, ^bb10
+// CHECK: ^bb9:
+// CHECK: return %[[VAL_4]] : i64
+// CHECK: ^bb10:
+// CHECK: return %[[VAL_3]] : i64
+// CHECK: ^bb11:
+// CHECK: return %[[VAL_2]] : i64
+// CHECK: ^bb12(%[[VAL_13:.*]]: i64, %[[VAL_14:.*]]: i64, %[[VAL_15:.*]]: i64):
+// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_3]] : i64
+// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i64)
+// CHECK: ^bb13:
+// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_14]] : i64
+// CHECK: cf.br ^bb14(%[[VAL_18]] : i64)
+// CHECK: ^bb14(%[[VAL_19:.*]]: i64):
+// CHECK: %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_3]] : i64
+// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_2]] : i64
+// CHECK: cf.cond_br %[[VAL_21]], ^bb15, ^bb16
+// CHECK: ^bb15:
+// CHECK: return %[[VAL_19]] : i64
+// CHECK: ^bb16:
+// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_14]], %[[VAL_14]] : i64
+// CHECK: cf.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i64, i64, i64)
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: func @ipowi(
+// CHECK-SAME: %[[ARG0:.+]]: i8,
+// CHECK-SAME: %[[ARG1:.+]]: i8)
+ // CHECK: call @__mlir_math_ipowi_i8(%[[ARG0]], %[[ARG1]]) : (i8, i8) -> i8
+func.func @ipowi(%arg0: i8, %arg1: i8) {
+ %0 = math.ipowi %arg0, %arg1 : i8
+ func.return
+}
+
+// CHECK-LABEL: func.func private @__mlir_math_ipowi_i8(
+// CHECK-SAME: %[[VAL_0:.*]]: i8,
+// CHECK-SAME: %[[VAL_1:.*]]: i8) -> i8
+// CHECK-SAME: attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i8
+// CHECK: %[[VAL_4:.*]] = arith.constant -1 : i8
+// CHECK: %[[VAL_5:.*]] = arith.cmpi eq, %[[VAL_1]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_5]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: return %[[VAL_3]] : i8
+// CHECK: ^bb2:
+// CHECK: %[[VAL_6:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_6]], ^bb3, ^bb12(%[[VAL_3]], %[[VAL_0]], %[[VAL_1]] : i8, i8, i8)
+// CHECK: ^bb3:
+// CHECK: %[[VAL_7:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_7]], ^bb4, ^bb5
+// CHECK: ^bb4:
+// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_3]], %[[VAL_2]] : i8
+// CHECK: return %[[VAL_8]] : i8
+// CHECK: ^bb5:
+// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_3]] : i8
+// CHECK: cf.cond_br %[[VAL_9]], ^bb6, ^bb7
+// CHECK: ^bb6:
+// CHECK: return %[[VAL_3]] : i8
+// CHECK: ^bb7:
+// CHECK: %[[VAL_10:.*]] = arith.cmpi eq, %[[VAL_0]], %[[VAL_4]] : i8
+// CHECK: cf.cond_br %[[VAL_10]], ^bb8, ^bb11
+// CHECK: ^bb8:
+// CHECK: %[[VAL_11:.*]] = arith.andi %[[VAL_1]], %[[VAL_3]] : i8
+// CHECK: %[[VAL_12:.*]] = arith.cmpi ne, %[[VAL_11]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_12]], ^bb9, ^bb10
+// CHECK: ^bb9:
+// CHECK: return %[[VAL_4]] : i8
+// CHECK: ^bb10:
+// CHECK: return %[[VAL_3]] : i8
+// CHECK: ^bb11:
+// CHECK: return %[[VAL_2]] : i8
+// CHECK: ^bb12(%[[VAL_13:.*]]: i8, %[[VAL_14:.*]]: i8, %[[VAL_15:.*]]: i8):
+// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_15]], %[[VAL_3]] : i8
+// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_16]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_17]], ^bb13, ^bb14(%[[VAL_13]] : i8)
+// CHECK: ^bb13:
+// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_14]] : i8
+// CHECK: cf.br ^bb14(%[[VAL_18]] : i8)
+// CHECK: ^bb14(%[[VAL_19:.*]]: i8):
+// CHECK: %[[VAL_20:.*]] = arith.shrui %[[VAL_15]], %[[VAL_3]] : i8
+// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_2]] : i8
+// CHECK: cf.cond_br %[[VAL_21]], ^bb15, ^bb16
+// CHECK: ^bb15:
+// CHECK: return %[[VAL_19]] : i8
+// CHECK: ^bb16:
+// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_14]], %[[VAL_14]] : i8
+// CHECK: cf.br ^bb12(%[[VAL_19]], %[[VAL_22]], %[[VAL_20]] : i8, i8, i8)
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: func.func @ipowi_vec(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xi64>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x3xi64>) {
+func.func @ipowi_vec(%arg0: vector<2x3xi64>, %arg1: vector<2x3xi64>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x3xi64>
+// CHECK: %[[B00:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<2x3xi64>
+// CHECK: %[[E00:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<2x3xi64>
+// CHECK: %[[R00:.*]] = call @__mlir_math_ipowi_i64(%[[B00]], %[[E00]]) : (i64, i64) -> i64
+// CHECK: %[[TMP00:.*]] = vector.insert %[[R00]], %[[CST]] [0, 0] : i64 into vector<2x3xi64>
+// CHECK: %[[B01:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<2x3xi64>
+// CHECK: %[[E01:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<2x3xi64>
+// CHECK: %[[R01:.*]] = call @__mlir_math_ipowi_i64(%[[B01]], %[[E01]]) : (i64, i64) -> i64
+// CHECK: %[[TMP01:.*]] = vector.insert %[[R01]], %[[TMP00]] [0, 1] : i64 into vector<2x3xi64>
+// CHECK: %[[B02:.*]] = vector.extract %[[VAL_0]][0, 2] : vector<2x3xi64>
+// CHECK: %[[E02:.*]] = vector.extract %[[VAL_1]][0, 2] : vector<2x3xi64>
+// CHECK: %[[R02:.*]] = call @__mlir_math_ipowi_i64(%[[B02]], %[[E02]]) : (i64, i64) -> i64
+// CHECK: %[[TMP02:.*]] = vector.insert %[[R02]], %[[TMP01]] [0, 2] : i64 into vector<2x3xi64>
+// CHECK: %[[B10:.*]] = vector.extract %[[VAL_0]][1, 0] : vector<2x3xi64>
+// CHECK: %[[E10:.*]] = vector.extract %[[VAL_1]][1, 0] : vector<2x3xi64>
+// CHECK: %[[R10:.*]] = call @__mlir_math_ipowi_i64(%[[B10]], %[[E10]]) : (i64, i64) -> i64
+// CHECK: %[[TMP10:.*]] = vector.insert %[[R10]], %[[TMP02]] [1, 0] : i64 into vector<2x3xi64>
+// CHECK: %[[B11:.*]] = vector.extract %[[VAL_0]][1, 1] : vector<2x3xi64>
+// CHECK: %[[E11:.*]] = vector.extract %[[VAL_1]][1, 1] : vector<2x3xi64>
+// CHECK: %[[R11:.*]] = call @__mlir_math_ipowi_i64(%[[B11]], %[[E11]]) : (i64, i64) -> i64
+// CHECK: %[[TMP11:.*]] = vector.insert %[[R11]], %[[TMP10]] [1, 1] : i64 into vector<2x3xi64>
+// CHECK: %[[B12:.*]] = vector.extract %[[VAL_0]][1, 2] : vector<2x3xi64>
+// CHECK: %[[E12:.*]] = vector.extract %[[VAL_1]][1, 2] : vector<2x3xi64>
+// CHECK: %[[R12:.*]] = call @__mlir_math_ipowi_i64(%[[B12]], %[[E12]]) : (i64, i64) -> i64
+// CHECK: %[[TMP12:.*]] = vector.insert %[[R12]], %[[TMP11]] [1, 2] : i64 into vector<2x3xi64>
+// CHECK: return
+// CHECK: }
+ %0 = math.ipowi %arg0, %arg1 : vector<2x3xi64>
+ func.return
+}
More information about the Mlir-commits
mailing list