[Mlir-commits] [mlir] fb1e571 - [MLIR][Standard] Add default lowering for `assert`

Frederik Gossen llvmlistbot at llvm.org
Fri Jul 24 01:31:40 PDT 2020


Author: Frederik Gossen
Date: 2020-07-24T08:31:12Z
New Revision: fb1e5716877c751d561dad3483af8a4a4559a0fe

URL: https://github.com/llvm/llvm-project/commit/fb1e5716877c751d561dad3483af8a4a4559a0fe
DIFF: https://github.com/llvm/llvm-project/commit/fb1e5716877c751d561dad3483af8a4a4559a0fe.diff

LOG: [MLIR][Standard] Add default lowering for `assert`

The default lowering of `assert` calls `abort` if the assertion is
violated. The failure message is ignored by this lowering but should be used
by custom lowerings that can assume more about their environment.

Differential Revision: https://reviews.llvm.org/D83886

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 080264e666cf..1e6fa6a8754b 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1422,6 +1422,50 @@ using UnsignedShiftRightOpLowering =
     OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
 using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;
 
+/// Lowers `std.assert` to a conditional branch: when the condition holds,
+/// control continues with the ops following the assert; otherwise it jumps to
+/// a block that calls `abort` and ends in `llvm.unreachable`. The failure
+/// message is ignored here; custom lowerings should propagate it.
+struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
+  using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    AssertOp::Adaptor transformed(operands);
+
+    // Declare `abort` at module start if absent (insertion point is restored).
+    auto module = op->getParentOfType<ModuleOp>();
+    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
+    if (!abortFunc) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(module.getBody());
+      auto abortFuncTy =
+          LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false);
+      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
+                                                    "abort", abortFuncTy);
+    }
+
+    // Split the block at the `assert` op; the tail becomes the continuation.
+    Block *opBlock = rewriter.getInsertionBlock();
+    auto opPosition = rewriter.getInsertionPoint();
+    Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
+
+    // Failure block (appended to the region): call `abort`, then unreachable.
+    Block *failureBlock = rewriter.createBlock(opBlock->getParent());
+    rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
+    rewriter.create<LLVM::UnreachableOp>(loc);
+
+    // Branch on the condition from the end of opBlock, replacing the assert.
+    rewriter.setInsertionPointToEnd(opBlock);
+    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+        op, transformed.arg(), continuationBlock, failureBlock);
+
+    return success();
+  }
+};
+
 // Lowerings for operations on complex numbers.
 
 struct CreateComplexOpLowering
@@ -3169,6 +3213,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       AddIOpLowering,
       AllocaOpLowering,
       AndOpLowering,
+      AssertOpLowering,
       AtomicRMWOpLowering,
       BranchOpLowering,
       CallIndirectOpLowering,

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 3b0a17be640b..a6006023a509 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -75,3 +75,21 @@ func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
   %0 = rsqrt %arg0 : vector<4x3xf32>
   std.return
 }
+
+// -----
+
+// `assert` becomes a cond_br; its failure block calls `abort` (message dropped).
+// CHECK: llvm.func @abort()
+// CHECK-LABEL: @assert_test_function
+// CHECK-SAME:  (%[[ARG:.*]]: !llvm.i1)
+func @assert_test_function(%arg : i1) {
+  // CHECK: llvm.cond_br %[[ARG]], ^[[CONTINUATION_BLOCK:.*]], ^[[FAILURE_BLOCK:.*]]
+  // CHECK: ^[[CONTINUATION_BLOCK]]:
+  // CHECK: llvm.return
+  // CHECK: ^[[FAILURE_BLOCK]]:
+  // CHECK: llvm.call @abort() : () -> ()
+  // CHECK: llvm.unreachable
+  assert %arg, "Computer says no"
+  return
+}
+


        


More information about the Mlir-commits mailing list