[Mlir-commits] [mlir] 325b58d - [mlir][cf] Print message in cf.assert to LLVM lowering

Matthias Springer llvmlistbot at llvm.org
Thu Dec 15 08:45:43 PST 2022


Author: Matthias Springer
Date: 2022-12-15T17:45:34+01:00
New Revision: 325b58d59f0073c364b29d6e125b809c76484a16

URL: https://github.com/llvm/llvm-project/commit/325b58d59f0073c364b29d6e125b809c76484a16
DIFF: https://github.com/llvm/llvm-project/commit/325b58d59f0073c364b29d6e125b809c76484a16.diff

LOG: [mlir][cf] Print message in cf.assert to LLVM lowering

The assert message was previously ignored. The lowered IR now calls `puts` to print it in case of a failed assertion.

Differential Revision: https://reviews.llvm.org/D138647

Added: 
    mlir/test/Integration/Dialect/ControlFlow/assert.mlir
    mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt
    mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp

Modified: 
    mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
    mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index de00b939a225a..b9dfaafcf5a58 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -30,6 +30,13 @@ namespace cf {
 void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns);
 
+/// Populate the cf.assert to LLVM conversion pattern. A failed assertion
+/// prints its message; `abort` is then called unless `abortOnFailure` is set
+/// to false, in which case program execution continues.
+void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter,
+                                           RewritePatternSet &patterns,
+                                           bool abortOnFailure = true);
+
 /// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
 std::unique_ptr<Pass> createConvertControlFlowToLLVMPass();
 } // namespace cf

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index ac86e8461d277..3421e9bb67b6b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -36,6 +36,7 @@ LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);

diff  --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 2c84e0b25c550..d0c9105f4b26a 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -35,39 +35,88 @@ using namespace mlir;
 
 #define PASS_NAME "convert-cf-to-llvm"
 
+static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
+  std::string prefix = "assert_msg_";
+  int counter = 0; // Try assert_msg_0, assert_msg_1, ... until one is free.
+  while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
+    ++counter;
+  return prefix + std::to_string(counter);
+}
+
+/// Generate IR that prints `msg` (followed by a newline) to stdout via `puts`.
+static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+                           StringRef msg) {
+  auto ip = builder.saveInsertionPoint(); // Restored before emitting the call.
+  builder.setInsertionPointToStart(moduleOp.getBody());
+  MLIRContext *ctx = builder.getContext();
+
+  // Create a zero-terminated byte representation and allocate global symbol.
+  SmallVector<uint8_t> elementVals;
+  elementVals.append(msg.begin(), msg.end());
+  elementVals.push_back(0);
+  auto dataAttrType = RankedTensorType::get(
+      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+  auto dataAttr =
+      DenseElementsAttr::get(dataAttrType, llvm::makeArrayRef(elementVals));
+  auto arrayTy =
+      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+  std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
+  auto globalOp = builder.create<LLVM::GlobalOp>(
+      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
+      dataAttr);
+
+  // Emit a call to `puts` with a pointer to the start of the message.
+  builder.restoreInsertionPoint(ip);
+  auto msgAddr = builder.create<LLVM::AddressOfOp>(
+      loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName());
+  SmallVector<LLVM::GEPArg> indices(1, 0);
+  Value gep = builder.create<LLVM::GEPOp>(
+      loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices);
+  Operation *printer = LLVM::lookupOrCreatePrintStrFn(moduleOp);
+  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+                               gep);
+}
+
 namespace {
-/// Lower `cf.assert`. The default lowering calls the `abort` function if the
-/// assertion is violated and has no effect otherwise. The failure message is
-/// ignored by the default lowering but should be propagated by any custom
-/// lowering.
+/// Lower `cf.assert`. If the assertion is violated, the failure message is
+/// printed via `puts` and, by default, the `abort` function is then called.
+/// When `abortOnFailedAssert` is false, execution continues after printing.
+/// The lowering has no effect when the assertion holds.
 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
-  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+  explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
+                            bool abortOnFailedAssert = true)
+      : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
+        abortOnFailedAssert(abortOnFailedAssert) {}
 
   LogicalResult
   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-
-    // Insert the `abort` declaration if necessary.
     auto module = op->getParentOfType<ModuleOp>();
-    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
-    if (!abortFunc) {
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(module.getBody());
-      auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
-      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
-                                                    "abort", abortFuncTy);
-    }
 
     // Split block at `assert` operation.
     Block *opBlock = rewriter.getInsertionBlock();
     auto opPosition = rewriter.getInsertionPoint();
     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
 
-    // Generate IR to call `abort`.
+    // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
-    rewriter.create<LLVM::UnreachableOp>(loc);
+    createPrintMsg(rewriter, loc, module, op.getMsg());
+    if (abortOnFailedAssert) {
+      // Insert the `abort` declaration if necessary.
+      auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
+      if (!abortFunc) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(module.getBody());
+        auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
+        abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
+                                                      "abort", abortFuncTy);
+      }
+      rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
+      rewriter.create<LLVM::UnreachableOp>(loc);
+    } else {
+      rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
+    }
 
     // Generate assertion test.
     rewriter.setInsertionPointToEnd(opBlock);
@@ -76,6 +125,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     return success();
   }
+
+private:
+  /// If set to `false`, messages are printed but program execution continues.
+  /// This is useful for testing asserts.
+  bool abortOnFailedAssert = true;
 };
 
 /// The cf->LLVM lowerings for branching ops require that the blocks they jump
@@ -195,6 +249,12 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
   // clang-format on
 }
 
+void mlir::cf::populateAssertToLLVMConversionPattern(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool abortOnFailure) {
+  patterns.add<AssertOpLowering>(converter, abortOnFailure); // Only cf.assert.
+}
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 409c513f69ddb..ff590d4e7070e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -28,6 +28,7 @@ static constexpr llvm::StringRef kPrintI64 = "printI64";
 static constexpr llvm::StringRef kPrintU64 = "printU64";
 static constexpr llvm::StringRef kPrintF32 = "printF32";
 static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintStr = "puts";
 static constexpr llvm::StringRef kPrintOpen = "printOpen";
 static constexpr llvm::StringRef kPrintClose = "printClose";
 static constexpr llvm::StringRef kPrintComma = "printComma";
@@ -78,6 +79,13 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp) {
+  return lookupOrCreateFn(
+      moduleOp, kPrintStr, // C `puts`; its int result is modeled as void.
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));

diff  --git a/mlir/test/Integration/Dialect/ControlFlow/assert.mlir b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
new file mode 100644
index 0000000000000..42130250daf1b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/ControlFlow/assert.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -test-cf-assert \
+// RUN:     -convert-func-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void | \
+// RUN: FileCheck %s
+
+func.func @main() {
+  %a = arith.constant 0 : i1
+  %b = arith.constant 1 : i1
+  // CHECK: assertion foo
+  cf.assert %a, "assertion foo"
+  // CHECK-NOT: assertion bar
+  cf.assert %b, "assertion bar"
+  return
+}

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 11c223620b58c..48bde69e01700 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(Affine)
 add_subdirectory(Arith)
 add_subdirectory(Bufferization)
+add_subdirectory(ControlFlow)
 add_subdirectory(DLTI)
 add_subdirectory(Func)
 add_subdirectory(GPU)

diff  --git a/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt b/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt
new file mode 100644
index 0000000000000..39d9555c7405e
--- /dev/null
+++ b/mlir/test/lib/Dialect/ControlFlow/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRControlFlowTestPasses
+  TestAssert.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRControlFlowToLLVM
+  MLIRFuncDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRPass
+  MLIRTransforms
+)

diff  --git a/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp b/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp
new file mode 100644
index 0000000000000..db769741117aa
--- /dev/null
+++ b/mlir/test/lib/Dialect/ControlFlow/TestAssert.cpp
@@ -0,0 +1,55 @@
+//===- TestAssert.cpp - Test cf.assert Lowering  ----------------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass that lowers `cf.assert` to LLVM with
+// `abortOnFailure` disabled: failed assertions print their message, but
+// program execution continues. It is used by integration tests.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct TestAssertPass // Lowers cf.assert with abortOnFailure=false.
+    : public PassWrapper<TestAssertPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAssertPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<cf::ControlFlowDialect, LLVM::LLVMDialect>();
+  }
+  StringRef getArgument() const final { return "test-cf-assert"; }
+  StringRef getDescription() const final {
+    return "Module pass to test cf.assert lowering to LLVM without abort";
+  }
+
+  void runOnOperation() override {
+    LLVMConversionTarget target(getContext());
+    RewritePatternSet patterns(&getContext());
+
+    LLVMTypeConverter converter(&getContext());
+    mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns,
+                                                    /*abortOnFailure=*/false);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestCfAssertPass() { PassRegistration<TestAssertPass>(); }
+} // namespace mlir::test

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index ca33fae2d4e48..0c81891c80113 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -16,6 +16,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRAffineTransformsTestPasses
     MLIRArithTestPasses
     MLIRBufferizationTestPasses
+    MLIRControlFlowTestPasses
     MLIRDLTITestPasses
     MLIRFuncTestPasses
     MLIRGPUTestPasses

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 622b9c945b6d5..6efbeb33d2e24 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@ void registerTestArithEmulateWideIntPass();
 void registerTestAliasAnalysisPass();
 void registerTestBuiltinAttributeInterfaces();
 void registerTestCallGraphPass();
+void registerTestCfAssertPass();
 void registerTestConstantFold();
 void registerTestControlFlowSink();
 void registerTestGpuSerializeToCubinPass();
@@ -172,6 +173,7 @@ void registerTestPasses() {
   mlir::test::registerTestArithEmulateWideIntPass();
   mlir::test::registerTestBuiltinAttributeInterfaces();
   mlir::test::registerTestCallGraphPass();
+  mlir::test::registerTestCfAssertPass();
   mlir::test::registerTestConstantFold();
   mlir::test::registerTestControlFlowSink();
   mlir::test::registerTestDiagnosticsPass();

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 304424a88a473..329390fc1379e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6898,6 +6898,7 @@ cc_binary(
         "//mlir/test:TestAnalysis",
         "//mlir/test:TestArith",
         "//mlir/test:TestBufferization",
+        "//mlir/test:TestControlFlow",
         "//mlir/test:TestDLTI",
         "//mlir/test:TestDialect",
         "//mlir/test:TestFunc",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index aa262f2a1aa79..8be772dc0bf0e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -715,6 +715,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TestControlFlow",
+    srcs = glob(["lib/Dialect/ControlFlow/*.cpp"]),
+    includes = ["lib/Dialect/Test"],
+    deps = [
+        "//mlir:ControlFlowDialect",
+        "//mlir:ControlFlowToLLVM",
+        "//mlir:FuncDialect",
+        "//mlir:LLVMCommonConversion",
+        "//mlir:LLVMDialect",
+        "//mlir:Pass",
+        "//mlir:Transforms",
+    ],
+)
+
 cc_library(
     name = "TestShapeDialect",
     srcs = [


        


More information about the Mlir-commits mailing list