[Mlir-commits] [mlir] bfbccfa - Add support for math.ctlz in convert-math-to-funcs
Slava Zakharin
llvmlistbot at llvm.org
Mon Apr 10 10:02:15 PDT 2023
Author: Jeremy Kun
Date: 2023-04-10T10:02:00-07:00
New Revision: bfbccfa17c97f29993555fcc4d9f191ba49d9606
URL: https://github.com/llvm/llvm-project/commit/bfbccfa17c97f29993555fcc4d9f191ba49d9606
DIFF: https://github.com/llvm/llvm-project/commit/bfbccfa17c97f29993555fcc4d9f191ba49d9606.diff
LOG: Add support for math.ctlz in convert-math-to-funcs
This change adds a software implementation of the `math.ctlz` operation
and includes it in `--convert-math-to-funcs`.
This is my first change to MLIR, so please bear with me as I'm still learning
the idioms of the codebase.
The context for this change is that I have some larger scale project in which
I'd like to lower from a mix of MLIR dialects to CIRCT, but many of the CIRCT
passes don't support the `math` dialect.
I noticed the content of `convert-math-to-funcs` was limited entirely to
the `pow` functions, but otherwise provided the needed structure to implement
this feature with minimal changes.
Highlight of the changes:
- Add a dependency on the SCF dialect for this lowering. I could have lowered
directly to cf, following the pow lowerings in the same pass, but I felt it
was not necessary given the existing support for lowering scf to cf.
- Generalize the DenseMap storing op implementations: modify the callback
function hashmap to be keyed by both the OperationName (the name of the op
being implemented in software) and the type of the resulting function (its
element type, or its full function type for `math.fpowi`).
- Implement the ctlz function as a loop. I had researched a variety of
implementations that claimed to be more efficient (such as those based on a
de Bruijn sequence), but it seems to me that the simplest approach would make
it easier for later compiler optimizations to do a better (platform-aware)
job optimizing this than I could do by hand.
Questions I had for the reviewer:
- [edit: found mlir-cpu-runner and added two tests] What would I add to the FileCheck invocation to actually run the resulting MLIR on a value and assert the output is correct? I have done this manually with the C implementation but I'm not confident my port to MLIR is correct.
- Should I add a test for a vectorized version of this lowering? I followed suit with `VecOpToScalarOp`, but I admit I don't fully understand what it's doing.
Reviewed By: vzakhari
Differential Revision: https://reviews.llvm.org/D146261
Added:
mlir/test/Conversion/MathToFuncs/ctlz.mlir
mlir/test/Integration/Dialect/Math/CPU/mathtofuncs_ctlz.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cf8bbf13d4b67..31489a2245eea 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -627,6 +627,7 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
"arith::ArithDialect",
"cf::ControlFlowDialect",
"func::FuncDialect",
+ "scf::SCFDialect",
"vector::VectorDialect",
"LLVM::LLVMDialect",
];
@@ -634,7 +635,12 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
Option<"minWidthOfFPowIExponent", "min-width-of-fpowi-exponent", "unsigned",
/*default=*/"1",
"Convert FPowI only if the width of its exponent's integer type "
- "is greater than or equal to this value">
+ "is greater than or equal to this value">,
+    // Most backend targets support a native ctlz operation, so by default
+    // ctlz conversion is disabled.
+    Option<"convertCtlz", "convert-ctlz", "bool", /*default=*/"false",
+           "Convert math.ctlz to a software implementation. Enable "
+           "for targets that do not natively support ctlz.">,
];
}
@@ -646,10 +652,10 @@ def FinalizeMemRefToLLVMConversionPass :
Pass<"finalize-memref-to-llvm", "ModuleOp"> {
let summary = "Finalize MemRef dialect to LLVM dialect conversion";
let description = [{
- Finalize the conversion of the operations from the MemRef
+ Finalize the conversion of the operations from the MemRef
dialect to the LLVM dialect.
- This conversion will not convert some complex MemRef
- operations. Make sure to run `expand-strided-metadata`
+ This conversion will not convert some complex MemRef
+ operations. Make sure to run `expand-strided-metadata`
beforehand for these.
}];
let dependentDialects = ["LLVM::LLVMDialect"];
diff --git a/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt b/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
index d33df7b03ad73..2ac191cac22a0 100644
--- a/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRMathToFuncs
MLIRLLVMDialect
MLIRMathDialect
MLIRPass
+ MLIRSCFDialect
MLIRTransforms
MLIRVectorDialect
MLIRVectorUtils
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index c177e50e2c3c9..6eac3b85e25f0 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -22,6 +23,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
@@ -30,6 +32,9 @@ namespace mlir {
using namespace mlir;
+#define DEBUG_TYPE "math-to-funcs" // Enable with -debug-only=math-to-funcs.
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") // Tagged dbgs().
+
namespace {
// Pattern to convert vector operations to scalar operations.
template <typename Op>
@@ -41,14 +46,14 @@ struct VecOpToScalarOp : public OpRewritePattern<Op> {
};
// Callback type for getting pre-generated FuncOp implementing
-// a power operation of the given type.
-using GetPowerFuncCallbackTy = function_ref<func::FuncOp(Type)>;
+// the given operation for the given type, or a null FuncOp if none exists.
+using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
// Pattern to convert scalar IPowIOp into a call of outlined
// software implementation.
class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
public:
- IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
+ IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
: OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
/// Convert IPowI into a call to a local function implementing
@@ -58,14 +63,14 @@ class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
PatternRewriter &rewriter) const final;
private:
- GetPowerFuncCallbackTy getFuncOpCallback;
+ GetFuncCallbackTy getFuncOpCallback;
};
// Pattern to convert scalar FPowIOp into a call of outlined
// software implementation.
class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
public:
- FPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb)
+ FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
: OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
/// Convert FPowI into a call to a local function implementing
@@ -75,7 +80,24 @@ class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
PatternRewriter &rewriter) const final;
private:
- GetPowerFuncCallbackTy getFuncOpCallback;
+ GetFuncCallbackTy getFuncOpCallback;
+};
+
+// Rewrites a scalar math.ctlz into a func.call to its outlined software
+// implementation, which is looked up through the supplied callback.
+class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
+public:
+  CtlzOpLowering(MLIRContext *ctx, GetFuncCallbackTy lookupImpl)
+      : OpRewritePattern<math::CountLeadingZerosOp>(ctx),
+        getFuncOpCallback(lookupImpl) {}
+
+  /// Replace \p op with a call to the local function that counts leading
+  /// zeros for its type.
+  LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
+                                PatternRewriter &rewriter) const final;
+
+private:
+  GetFuncCallbackTy getFuncOpCallback;
+};
} // namespace
@@ -346,7 +368,7 @@ IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
// The outlined software implementation must have been already
// generated.
- func::FuncOp elementFunc = getFuncOpCallback(baseType);
+ func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
if (!elementFunc)
return rewriter.notifyMatchFailure(op, "missing software implementation");
@@ -571,7 +593,7 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
// The outlined software implementation must have been already
// generated.
- func::FuncOp elementFunc = getFuncOpCallback(funcType);
+ func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
if (!elementFunc)
return rewriter.notifyMatchFailure(op, "missing software implementation");
@@ -579,6 +601,171 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
return success();
}
+/// Create a function implementing ctlz for the given \p elementType inside
+/// \p module. \p elementType must be an IntegerType; the created function
+/// has type `(iN) -> iN`, private visibility and linkonce_odr linkage. It
+/// follows this reference algorithm, where `x < 0` is a signed (MSB) test:
+///
+/// template <typename T>
+/// T __mlir_math_ctlz_*(T x) {
+///   int bits = sizeof(x) * 8;
+///   if (x == 0)
+///     return bits;
+///   T n = 0;
+///   for (int i = 1; i < bits; ++i) {
+///     if (x < 0) continue; // MSB reached: all leading zeros counted.
+///     n++;
+///     x <<= 1;
+///   }
+///   return n;
+/// }
+///
+/// Converts to (for i32):
+///
+/// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
+///   %c_32 = arith.constant 32 : i32
+///   %c_0 = arith.constant 0 : i32
+///   %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i32
+///   %out = scf.if %arg_eq_zero -> (i32) {
+///     scf.yield %c_32 : i32
+///   } else {
+///     %c_1index = arith.constant 1 : index
+///     %c_1i32 = arith.constant 1 : i32
+///     %c_32index = arith.constant 32 : index
+///     %n = arith.constant 0 : i32
+///     %for:2 = scf.for %i = %c_1index to %c_32index step %c_1index
+///         iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
+///       %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
+///       %yield_val:2 = scf.if %cond -> (i32, i32) {
+///         scf.yield %arg_iter, %n_iter : i32, i32
+///       } else {
+///         %n_next = arith.addi %n_iter, %c_1i32 : i32
+///         %arg_next = arith.shli %arg_iter, %c_1i32 : i32
+///         scf.yield %arg_next, %n_next : i32, i32
+///       }
+///       scf.yield %yield_val#0, %yield_val#1 : i32, i32
+///     }
+///     scf.yield %for#1 : i32
+///   }
+///   return %out : i32
+/// }
+static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
+  if (!elementType.isa<IntegerType>()) {
+    LLVM_DEBUG({
+      DBGS() << "non-integer element type for CtlzFunc; type was: ";
+      elementType.print(llvm::dbgs());
+    });
+    llvm_unreachable("non-integer element type");
+  }
+  int64_t bitWidth = elementType.getIntOrFloatBitWidth();
+
+  Location loc = module->getLoc();
+  ImplicitLocOpBuilder builder =
+      ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody());
+
+  std::string funcName("__mlir_math_ctlz");
+  llvm::raw_string_ostream nameOS(funcName);
+  nameOS << '_' << elementType;
+  FunctionType funcType =
+      FunctionType::get(builder.getContext(), {elementType}, elementType);
+  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
+
+  // LinkonceODR lets the linker keep a single copy of this function when
+  // several modules each carry the same outlined ctlz implementation.
+  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
+  Attribute linkage =
+      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
+  funcOp->setAttr("llvm.linkage", linkage);
+  funcOp.setPrivate();
+
+  // Emit the body at the start of the new function's entry block.
+  Block *funcBody = funcOp.addEntryBlock();
+  builder.setInsertionPointToStart(funcBody);
+
+  Value arg = funcOp.getArgument(0);
+  Type indexType = builder.getIndexType();
+  Value bitWidthValue = builder.create<arith::ConstantOp>(
+      elementType, builder.getIntegerAttr(elementType, bitWidth));
+  Value zeroValue = builder.create<arith::ConstantOp>(
+      elementType, builder.getIntegerAttr(elementType, 0));
+
+  Value inputEqZero =
+      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
+
+  // if input == 0, return bit width, else enter loop.
+  scf::IfOp ifOp = builder.create<scf::IfOp>(
+      elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
+  ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
+
+  auto elseBuilder =
+      ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
+
+  Value oneIndex = elseBuilder.create<arith::ConstantOp>(
+      indexType, elseBuilder.getIndexAttr(1));
+  Value oneValue = elseBuilder.create<arith::ConstantOp>(
+      elementType, elseBuilder.getIntegerAttr(elementType, 1));
+  Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
+      indexType, elseBuilder.getIndexAttr(bitWidth));
+  Value nValue = elseBuilder.create<arith::ConstantOp>(
+      elementType, elseBuilder.getIntegerAttr(elementType, 0));
+
+  auto loop = elseBuilder.create<scf::ForOp>(
+      oneIndex, bitWidthIndex, oneIndex,
+      // Initial values of the two loop-carried values: arg, shifted left each
+      // iteration, and n, the running count of leading zeros.
+      ValueRange{arg, nValue},
+      // Callback to build the body of the for loop
+      //   if (arg < 0) {
+      //     continue;
+      //   } else {
+      //     n++;
+      //     arg <<= 1;
+      //   }
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+        Value argIter = args[0];
+        Value nIter = args[1];
+
+        Value argIsNegative = b.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::slt, argIter, zeroValue);
+        scf::IfOp bodyIfOp = b.create<scf::IfOp>(
+            loc, argIsNegative,
+            [&](OpBuilder &b, Location loc) {
+              // MSB already set: every leading zero was counted; keep values.
+              b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
+            },
+            [&](OpBuilder &b, Location loc) {
+              // Otherwise, increment n and shift arg left.
+              Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
+              Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
+              b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
+            });
+        b.create<scf::YieldOp>(loc, bodyIfOp.getResults());
+      });
+  elseBuilder.create<scf::YieldOp>(loop.getResult(1));
+
+  builder.create<func::ReturnOp>(ifOp.getResult(0));
+  return funcOp;
+}
+
+/// Convert a scalar integer ctlz into a call to the outlined implementation.
+/// Vector ops are left to VecOpToScalarOp; tensor and index ops are rejected.
+LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
+                                              PatternRewriter &rewriter) const {
+  // Only a plain integer matches the `(iN) -> iN` outlined function signature.
+  Type type = op.getType();
+  if (!type.isa<IntegerType>())
+    return rewriter.notifyMatchFailure(op, "non-scalar or non-integer operation");
+  func::FuncOp elementFunc = getFuncOpCallback(op, type);
+  if (!elementFunc)
+    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
+      diag << "Missing software implementation for op " << op->getName()
+           << " and type " << type;
+    });
+
+  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
+  return success();
+}
+
namespace {
struct ConvertMathToFuncsPass
: public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
@@ -595,13 +782,13 @@ struct ConvertMathToFuncsPass
bool isFPowIConvertible(math::FPowIOp op);
- // Generate outlined implementations for power operations
- // and store them in powerFuncs map.
- void preprocessPowOperations();
+  // Generate outlined implementations for the operations this pass converts
+  // and store them in the funcImpls map.
+  void generateOpImplementations();
- // A map between function types deduced from power operations
- // and the corresponding outlined software implementations
- // of these operations.
- DenseMap<Type, func::FuncOp> powerFuncs;
+  // Outlined software implementations keyed by (operation name, type). The
+  // type is the result element type for IPowI and ctlz, and the function type
+  // built from the operation's result and operands for FPowI.
+  DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
};
} // namespace
@@ -611,17 +798,40 @@ bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
}
-void ConvertMathToFuncsPass::preprocessPowOperations() {
+void ConvertMathToFuncsPass::generateOpImplementations() {
ModuleOp module = getOperation();
module.walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
+      .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
+        // Nothing to generate unless ctlz conversion was requested;
+        // otherwise the outlined function would be dead code.
+        if (!convertCtlz)
+          return;
+
+        // Only integer scalars and vectors of integers are implemented: an
+        // index type would hit llvm_unreachable in createCtlzFunc, and a
+        // tensor operand cannot be passed to the scalar function.
+        Type type = op.getType();
+        if (auto vecType = type.dyn_cast<VectorType>())
+          type = vecType.getElementType();
+        if (!type.isa<IntegerType>())
+          return;
+
+        // Generate the software implementation of this operation,
+        // if it has not been generated yet.
+        auto key = std::pair(op->getName(), type);
+        auto entry = funcImpls.try_emplace(key, func::FuncOp{});
+        if (entry.second)
+          entry.first->second = createCtlzFunc(&module, type);
+      })
.Case<math::IPowIOp>([&](math::IPowIOp op) {
Type resultType = getElementTypeOrSelf(op.getResult().getType());
// Generate the software implementation of this operation,
// if it has not been generated yet.
- auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{});
+ auto key = std::pair(op->getName(), resultType);
+ auto entry = funcImpls.try_emplace(key, func::FuncOp{});
if (entry.second)
entry.first->second = createElementIPowIFunc(&module, resultType);
})
@@ -635,7 +833,8 @@ void ConvertMathToFuncsPass::preprocessPowOperations() {
// if it has not been generated yet.
// FPowI implementations are mapped via the FunctionType
// created from the operation's result and operands.
- auto entry = powerFuncs.try_emplace(funcType, func::FuncOp{});
+ auto key = std::pair(op->getName(), funcType);
+ auto entry = funcImpls.try_emplace(key, func::FuncOp{});
if (entry.second)
entry.first->second = createElementFPowIFunc(&module, funcType);
});
@@ -646,27 +845,49 @@ void ConvertMathToFuncsPass::runOnOperation() {
ModuleOp module = getOperation();
- // Create outlined implementations for power operations.
- preprocessPowOperations();
+  // Create outlined implementations for the operations this pass converts.
+  generateOpImplementations();
RewritePatternSet patterns(&getContext());
patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>>(
patterns.getContext());
- // For the given Type Returns FuncOp stored in powerFuncs map.
- auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp {
- auto it = powerFuncs.find(type);
- if (it == powerFuncs.end())
+  // Returns the FuncOp stored in funcImpls for the given operation's name and
+  // type, or a null FuncOp if no implementation was generated.
+  auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
+    auto it = funcImpls.find(std::pair(op->getName(), type));
+    if (it == funcImpls.end())
return {};
return it->second;
};
patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
- getPowerFuncOpByType);
+                                                 getFuncOpByType);
+
+  if (convertCtlz) {
+    // Vector ctlz ops are unrolled into scalar ones, which CtlzOpLowering
+    // then rewrites into calls to the outlined implementation.
+    patterns.add<VecOpToScalarOp<math::CountLeadingZerosOp>>(
+        patterns.getContext());
+    patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
+  }
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
- func::FuncDialect, vector::VectorDialect>();
+                         func::FuncDialect, scf::SCFDialect,
+                         vector::VectorDialect>();
+
target.addIllegalOp<math::IPowIOp>();
+  // math.ctlz is illegal only when its conversion was requested, and then only
+  // for integer scalars and vectors of integers; index/tensor ops stay as is.
+  if (convertCtlz) {
+    target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
+        [](math::CountLeadingZerosOp op) {
+          Type type = op.getType();
+          if (auto vecType = type.dyn_cast<VectorType>())
+            type = vecType.getElementType();
+          return !type.isa<IntegerType>();
+        });
+  }
target.addDynamicallyLegalOp<math::FPowIOp>(
[this](math::FPowIOp op) { return !isFPowIConvertible(op); });
if (failed(applyPartialConversion(module, target, std::move(patterns))))
diff --git a/mlir/test/Conversion/MathToFuncs/ctlz.mlir b/mlir/test/Conversion/MathToFuncs/ctlz.mlir
new file mode 100644
index 0000000000000..8678c22a2b6f3
--- /dev/null
+++ b/mlir/test/Conversion/MathToFuncs/ctlz.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(convert-math-to-funcs{convert-ctlz})" | FileCheck %s
+
+// i32, the common case: check both the call site and the outlined function.
+
+// CHECK-LABEL: func.func @main(
+// CHECK-SAME: %[[INPUT:.*]]: i32
+// CHECK-SAME: ) {
+// CHECK: %[[CALL:.*]] = call @__mlir_math_ctlz_i32(%[[INPUT]]) : (i32) -> i32
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func private @__mlir_math_ctlz_i32(
+// CHECK-SAME: %[[X:.*]]: i32
+// CHECK-SAME: ) -> i32 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[BITS:.*]] = arith.constant 32 : i32
+// CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+// CHECK: %[[IS_ZERO:.*]] = arith.cmpi eq, %[[X]], %[[ZERO]] : i32
+// CHECK: %[[RESULT:.*]] = scf.if %[[IS_ZERO]] -> (i32) {
+// CHECK: scf.yield %[[BITS]] : i32
+// CHECK: } else {
+// CHECK: %[[ONE_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[ONE:.*]] = arith.constant 1 : i32
+// CHECK: %[[BITS_IDX:.*]] = arith.constant 32 : index
+// CHECK: %[[N_INIT:.*]] = arith.constant 0 : i32
+// CHECK: %[[LOOP:.*]]:2 = scf.for %[[IV:.*]] = %[[ONE_IDX]] to %[[BITS_IDX]] step %[[ONE_IDX]]
+// CHECK: iter_args(%[[X_ITER:.*]] = %[[X]], %[[N_ITER:.*]] = %[[N_INIT]]) -> (i32, i32) {
+// CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[X_ITER]], %[[ZERO]] : i32
+// CHECK: %[[STEP:.*]]:2 = scf.if %[[IS_NEG]] -> (i32, i32) {
+// CHECK: scf.yield %[[X_ITER]], %[[N_ITER]] : i32, i32
+// CHECK: } else {
+// CHECK: %[[N_NEXT:.*]] = arith.addi %[[N_ITER]], %[[ONE]] : i32
+// CHECK: %[[X_NEXT:.*]] = arith.shli %[[X_ITER]], %[[ONE]] : i32
+// CHECK: scf.yield %[[X_NEXT]], %[[N_NEXT]] : i32, i32
+// CHECK: }
+// CHECK: scf.yield %[[STEP]]#0, %[[STEP]]#1 : i32, i32
+// CHECK: }
+// CHECK: scf.yield %[[LOOP]]#1 : i32
+// CHECK: }
+// CHECK: return %[[RESULT]] : i32
+// CHECK: }
+func.func @main(%input: i32) {
+  %lz = math.ctlz %input : i32
+  func.return
+}
+
+// -----
+
+// Check that an i8 operand gets its own implementation in which every value is i8.
+
+// CHECK-LABEL: func.func @main(
+// CHECK-SAME: %[[VAL_0:.*]]: i8
+// CHECK-SAME: ) {
+// CHECK: %[[VAL_1:.*]] = call @__mlir_math_ctlz_i8(%[[VAL_0]]) : (i8) -> i8
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func private @__mlir_math_ctlz_i8(
+// CHECK-SAME: %[[ARG:.*]]: i8
+// CHECK-SAME: ) -> i8 attributes {llvm.linkage = #llvm.linkage<linkonce_odr>} {
+// CHECK: %[[C_8:.*]] = arith.constant 8 : i8
+// CHECK: %[[C_0:.*]] = arith.constant 0 : i8
+// CHECK: %[[ARGCMP:.*]] = arith.cmpi eq, %[[ARG]], %[[C_0]] : i8
+// CHECK: %[[OUT:.*]] = scf.if %[[ARGCMP]] -> (i8) {
+// CHECK: scf.yield %[[C_8]] : i8
+// CHECK: } else {
+// CHECK: %[[C_1INDEX:.*]] = arith.constant 1 : index
+// CHECK: %[[C_1I8:.*]] = arith.constant 1 : i8
+// CHECK: %[[C_8INDEX:.*]] = arith.constant 8 : index
+// CHECK: %[[N:.*]] = arith.constant 0 : i8
+// CHECK: %[[FOR_RET:.*]]:2 = scf.for %[[I:.*]] = %[[C_1INDEX]] to %[[C_8INDEX]] step %[[C_1INDEX]]
+// CHECK: iter_args(%[[ARG_ITER:.*]] = %[[ARG]], %[[N_ITER:.*]] = %[[N]]) -> (i8, i8) {
+// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[ARG_ITER]], %[[C_0]] : i8
+// CHECK: %[[IF_RET:.*]]:2 = scf.if %[[COND]] -> (i8, i8) {
+// CHECK: scf.yield %[[ARG_ITER]], %[[N_ITER]] : i8, i8
+// CHECK: } else {
+// CHECK: %[[N_NEXT:.*]] = arith.addi %[[N_ITER]], %[[C_1I8]] : i8
+// CHECK: %[[ARG_NEXT:.*]] = arith.shli %[[ARG_ITER]], %[[C_1I8]] : i8
+// CHECK: scf.yield %[[ARG_NEXT]], %[[N_NEXT]] : i8, i8
+// CHECK: }
+// CHECK: scf.yield %[[IF_RET]]#0, %[[IF_RET]]#1 : i8, i8
+// CHECK: }
+// CHECK: scf.yield %[[FOR_RET]]#1 : i8
+// CHECK: }
+// CHECK: return %[[OUT]] : i8
+// CHECK: }
+func.func @main(%arg0: i8) {
+ %0 = math.ctlz %arg0 : i8
+ func.return
+}
+
diff --git a/mlir/test/Integration/Dialect/Math/CPU/mathtofuncs_ctlz.mlir b/mlir/test/Integration/Dialect/Math/CPU/mathtofuncs_ctlz.mlir
new file mode 100644
index 0000000000000..37f94c4e1fcff
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Math/CPU/mathtofuncs_ctlz.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s \
+// RUN: -pass-pipeline="builtin.module( \
+// RUN: convert-math-to-funcs{convert-ctlz}, \
+// RUN: func.func(convert-scf-to-cf,convert-arith-to-llvm), \
+// RUN: convert-func-to-llvm, \
+// RUN: convert-cf-to-llvm, \
+// RUN: reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner -e test_7i32_to_29 -entry-point-result=i32 | FileCheck %s --check-prefix=CHECK_TEST_7i32_TO_29
+// 7 = 0b111 occupies the low 3 of 32 bits, leaving 29 leading zeros.
+func.func @test_7i32_to_29() -> i32 {
+  %seven = arith.constant 7 : i32
+  %lz = math.ctlz %seven : i32
+  func.return %lz : i32
+}
+// CHECK_TEST_7i32_TO_29: 29
+
+// RUN: mlir-opt %s \
+// RUN: -pass-pipeline="builtin.module( \
+// RUN: convert-math-to-funcs{convert-ctlz}, \
+// RUN: func.func(convert-scf-to-cf,convert-arith-to-llvm), \
+// RUN: convert-func-to-llvm, \
+// RUN: convert-cf-to-llvm, \
+// RUN: reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner -e test_zero -entry-point-result=i32 | FileCheck %s --check-prefix=CHECK_TEST_ZERO
+// A zero input takes the early-exit path and yields the full bit width.
+func.func @test_zero() -> i32 {
+  %zero = arith.constant 0 : i32
+  %lz = math.ctlz %zero : i32
+  func.return %lz : i32
+}
+// CHECK_TEST_ZERO: 32
+
+// mlir-cpu-runner cannot return i8 results, so the wider i64 case is tested.
+// RUN: mlir-opt %s \
+// RUN: -pass-pipeline="builtin.module( \
+// RUN: convert-math-to-funcs{convert-ctlz}, \
+// RUN: func.func(convert-scf-to-cf,convert-arith-to-llvm), \
+// RUN: convert-func-to-llvm, \
+// RUN: convert-cf-to-llvm, \
+// RUN: reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner -e test_7i64_to_61 -entry-point-result=i64 | FileCheck %s --check-prefix=CHECK_TEST_7i64_TO_61
+// 7 = 0b111 occupies the low 3 of 64 bits, leaving 61 leading zeros.
+func.func @test_7i64_to_61() -> i64 {
+ %arg = arith.constant 7 : i64
+ %0 = math.ctlz %arg : i64
+ func.return %0 : i64
+}
+// CHECK_TEST_7i64_TO_61: 61
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 001b371d5b4f5..97a9b9a8ffbb2 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6529,6 +6529,7 @@ cc_library(
":LLVMDialect",
":MathDialect",
":Pass",
+ ":SCFDialect",
":Transforms",
":VectorDialect",
":VectorUtils",
More information about the Mlir-commits
mailing list