[Mlir-commits] [mlir] 325b58d - [mlir][cf] Print message in cf.assert to LLVM lowering
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 15 08:45:43 PST 2022
Author: Matthias Springer
Date: 2022-12-15T17:45:34+01:00
New Revision: 325b58d59f0073c364b29d6e125b809c76484a16
URL: https://github.com/llvm/llvm-project/commit/325b58d59f0073c364b29d6e125b809c76484a16
DIFF: https://github.com/llvm/llvm-project/commit/325b58d59f0073c364b29d6e125b809c76484a16.diff
LOG: [mlir][cf] Print message in cf.assert to LLVM lowering
The assert message was previously ignored. The lowered IR now prints it with `puts` in case of a failed assertion.
Differential Revision: https://reviews.llvm.org/D138647
Added:
mlir/test/Integration/Dialect/ControlFlow/assert.mlir
mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt
mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp
Modified:
mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
mlir/test/lib/Dialect/CMakeLists.txt
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index de00b939a225a..b9dfaafcf5a58 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -30,6 +30,13 @@ namespace cf {
void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
+/// Populate the cf.assert to LLVM conversion pattern. When the asserted
+/// condition is unsatisfied, the lowered IR prints the assertion message (via
+/// `puts`). If `abortOnFailure` is true (the default), it then calls `abort`;
+/// if it is set to false, the program execution continues after the message
+/// has been printed.
+void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ bool abortOnFailure = true);
+
/// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
std::unique_ptr<Pass> createConvertControlFlowToLLVMPass();
} // namespace cf
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index ac86e8461d277..3421e9bb67b6b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -36,6 +36,7 @@ LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 2c84e0b25c550..d0c9105f4b26a 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -35,39 +35,88 @@ using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
+/// Returns the first name of the form `assert_msg_<N>` (N = 0, 1, 2, ...) that
+/// is not already a symbol in `moduleOp`.
+static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
+  const std::string prefix = "assert_msg_";
+  for (int counter = 0;; ++counter) {
+    std::string candidate = prefix + std::to_string(counter);
+    if (!moduleOp.lookupSymbol(candidate))
+      return candidate;
+  }
+}
+
+/// Generate IR that prints the given string followed by a newline. The string
+/// is stored (zero-terminated) in a new private constant global at the start
+/// of `moduleOp`, and its address is passed to the C library function `puts`,
+/// which writes to stdout. The IR for the call is emitted at the builder's
+/// current insertion point, which is left unchanged on return.
+static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+                           StringRef msg) {
+  MLIRContext *ctx = builder.getContext();
+
+  // Create a zero-terminated byte representation of the message.
+  SmallVector<uint8_t> elementVals;
+  elementVals.reserve(msg.size() + 1);
+  elementVals.append(msg.begin(), msg.end());
+  elementVals.push_back(0);
+  auto dataAttrType = RankedTensorType::get(
+      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+  auto dataAttr =
+      DenseElementsAttr::get(dataAttrType, llvm::makeArrayRef(elementVals));
+  auto arrayTy =
+      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+
+  // Allocate the global symbol at the start of the module. The guard restores
+  // the caller's insertion point when this scope ends.
+  LLVM::GlobalOp globalOp;
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(moduleOp.getBody());
+    std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
+    globalOp = builder.create<LLVM::GlobalOp>(
+        loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
+        dataAttr);
+  }
+
+  // Compute the address of the first character. The first GEP index steps
+  // over the pointer to the global, the second selects element 0 of the
+  // array, so the result really is an `i8*` (a single index would yield a
+  // pointer to the whole array).
+  auto msgAddr = builder.create<LLVM::AddressOfOp>(
+      loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName());
+  SmallVector<LLVM::GEPArg> indices(2, 0);
+  Value gep = builder.create<LLVM::GEPOp>(
+      loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices);
+
+  // Emit the call to `puts`. Building the call from the callee op derives the
+  // result types from its declaration, so this also verifies when the module
+  // already declares `puts` with its C signature `i32 (i8*)`.
+  LLVM::LLVMFuncOp printer = LLVM::lookupOrCreatePrintStrFn(moduleOp);
+  builder.create<LLVM::CallOp>(loc, printer, gep);
+}
+
namespace {
/// Lower `cf.assert`. If the assertion is violated, the lowered IR prints the
/// failure message (via `puts`) and then, by default, calls the `abort`
/// function; when constructed with `abortOnFailedAssert = false`, execution
/// continues after the message is printed. The assertion has no effect if the
/// condition holds.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
- using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+ explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
+ bool abortOnFailedAssert = true)
+ : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
+ abortOnFailedAssert(abortOnFailedAssert) {}
LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
-
- // Insert the `abort` declaration if necessary.
auto module = op->getParentOfType<ModuleOp>();
- auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
- if (!abortFunc) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(module.getBody());
- auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
- abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
- "abort", abortFuncTy);
- }
// Split block at `assert` operation.
Block *opBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
- // Generate IR to call `abort`.
+ // Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
- rewriter.create<LLVM::UnreachableOp>(loc);
+ createPrintMsg(rewriter, loc, module, op.getMsg());
+ if (abortOnFailedAssert) {
+ // Insert the `abort` declaration if necessary.
+ auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
+ if (!abortFunc) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(module.getBody());
+ auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
+ abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
+ "abort", abortFuncTy);
+ }
+ rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
+ rewriter.create<LLVM::UnreachableOp>(loc);
+ } else {
+ rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
+ }
// Generate assertion test.
rewriter.setInsertionPointToEnd(opBlock);
@@ -76,6 +125,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
return success();
}
+
+private:
+ /// If set to `false`, messages are printed but program execution continues.
+ /// This is useful for testing asserts.
+ bool abortOnFailedAssert = true;
};
/// The cf->LLVM lowerings for branching ops require that the blocks they jump
@@ -195,6 +249,12 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
// clang-format on
}
+void mlir::cf::populateAssertToLLVMConversionPattern(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool abortOnFailure) {
+  // A single pattern handles cf.assert; the flag decides whether a violated
+  // assertion aborts the program or lets execution continue.
+  patterns.add<AssertOpLowering>(converter,
+                                 /*abortOnFailedAssert=*/abortOnFailure);
+}
+
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 409c513f69ddb..ff590d4e7070e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -28,6 +28,7 @@ static constexpr llvm::StringRef kPrintI64 = "printI64";
static constexpr llvm::StringRef kPrintU64 = "printU64";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintStr = "puts";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
@@ -78,6 +79,13 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
+/// Returns the function named by `kPrintStr` (`puts`) in `moduleOp`, creating
+/// a declaration that takes an `i8*` and returns void if none exists yet.
+// NOTE(review): C declares `int puts(const char *)`; the void result used
+// here differs from the C signature.
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp) {
+  MLIRContext *ctx = moduleOp->getContext();
+  Type charPtrTy = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  Type voidTy = LLVM::LLVMVoidType::get(ctx);
+  return lookupOrCreateFn(moduleOp, kPrintStr, charPtrTy, voidTy);
+}
+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
diff --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
new file mode 100644
index 0000000000000..42130250daf1b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -test-cf-assert \
+// RUN: -convert-func-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void | \
+// RUN: FileCheck %s
+
+// `%a` is false, so the first assertion fails. `-test-cf-assert` lowers it
+// without aborting: its message is printed and execution continues.
+// `%b` is true, so the second assertion passes and prints nothing.
+func.func @main() {
+ %a = arith.constant 0 : i1
+ %b = arith.constant 1 : i1
+ // CHECK: assertion foo
+ cf.assert %a, "assertion foo"
+ // CHECK-NOT: assertion bar
+ cf.assert %b, "assertion bar"
+ return
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 11c223620b58c..48bde69e01700 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,6 +1,7 @@
add_subdirectory(Affine)
add_subdirectory(Arith)
add_subdirectory(Bufferization)
+add_subdirectory(ControlFlow)
add_subdirectory(DLTI)
add_subdirectory(Func)
add_subdirectory(GPU)
diff --git a/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt b/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt
new file mode 100644
index 0000000000000..39d9555c7405e
--- /dev/null
+++ b/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Test passes for the ControlFlow dialect (provides `-test-cf-assert`).
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRControlFlowTestPasses
+ TestAssert.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRControlFlowToLLVM
+ # NOTE(review): TestAssert.cpp does not appear to use the Func dialect;
+ # confirm whether this dependency is needed.
+ MLIRFuncDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRTransforms
+)
diff --git a/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp b/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp
new file mode 100644
index 0000000000000..db769741117aa
--- /dev/null
+++ b/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp
@@ -0,0 +1,55 @@
+//===- TestAssert.cpp - Test cf.assert Lowering ----------------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass that lowers `cf.assert` to LLVM with
+// aborting on failure disabled, so that integration tests can check the
+// printed assertion messages while the program keeps running.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Module pass that lowers every `cf.assert` to LLVM with aborting disabled:
+/// a violated assertion prints its message and execution continues.
+struct TestAssertPass
+    : public PassWrapper<TestAssertPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAssertPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<cf::ControlFlowDialect, LLVM::LLVMDialect>();
+  }
+  StringRef getArgument() const final { return "test-cf-assert"; }
+  StringRef getDescription() const final {
+    return "Module pass to test cf.assert lowering to LLVM without abort";
+  }
+
+  void runOnOperation() override {
+    LLVMConversionTarget target(getContext());
+    // Mark cf.assert explicitly illegal so the pass fails loudly instead of
+    // silently leaving an assertion unconverted.
+    target.addIllegalOp<cf::AssertOp>();
+    RewritePatternSet patterns(&getContext());
+
+    LLVMTypeConverter converter(&getContext());
+    mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns,
+                                                    /*abortOnFailure=*/false);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the `test-cf-assert` pass with the global pass registry.
+void registerTestCfAssertPass() {
+  PassRegistration<TestAssertPass> registration;
+  (void)registration;
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index ca33fae2d4e48..0c81891c80113 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -16,6 +16,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRAffineTransformsTestPasses
MLIRArithTestPasses
MLIRBufferizationTestPasses
+ MLIRControlFlowTestPasses
MLIRDLTITestPasses
MLIRFuncTestPasses
MLIRGPUTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 622b9c945b6d5..6efbeb33d2e24 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTestArithEmulateWideIntPass();
void registerTestAliasAnalysisPass();
void registerTestBuiltinAttributeInterfaces();
void registerTestCallGraphPass();
+void registerTestCfAssertPass();
void registerTestConstantFold();
void registerTestControlFlowSink();
void registerTestGpuSerializeToCubinPass();
@@ -172,6 +173,7 @@ void registerTestPasses() {
mlir::test::registerTestArithEmulateWideIntPass();
mlir::test::registerTestBuiltinAttributeInterfaces();
mlir::test::registerTestCallGraphPass();
+ mlir::test::registerTestCfAssertPass();
mlir::test::registerTestConstantFold();
mlir::test::registerTestControlFlowSink();
mlir::test::registerTestDiagnosticsPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 304424a88a473..329390fc1379e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6898,6 +6898,7 @@ cc_binary(
"//mlir/test:TestAnalysis",
"//mlir/test:TestArith",
"//mlir/test:TestBufferization",
+ "//mlir/test:TestControlFlow",
"//mlir/test:TestDLTI",
"//mlir/test:TestDialect",
"//mlir/test:TestFunc",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index aa262f2a1aa79..8be772dc0bf0e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -715,6 +715,21 @@ cc_library(
],
)
+# Test passes for the ControlFlow dialect (provides `-test-cf-assert`).
+cc_library(
+ name = "TestControlFlow",
+ srcs = glob(["lib/Dialect/ControlFlow/*.cpp"]),
+ # NOTE(review): the current sources do not seem to include headers from
+ # lib/Dialect/Test; this include path may be a copy-paste leftover.
+ includes = ["lib/Dialect/Test"],
+ deps = [
+ "//mlir:ControlFlowDialect",
+ "//mlir:ControlFlowToLLVM",
+ "//mlir:FuncDialect",
+ "//mlir:LLVMCommonConversion",
+ "//mlir:LLVMDialect",
+ "//mlir:Pass",
+ "//mlir:Transforms",
+ ],
+)
+
cc_library(
name = "TestShapeDialect",
srcs = [
More information about the Mlir-commits
mailing list