[flang-commits] [flang] [mlir] [mlir][GPU] Add NVVM-specific `cf.assert` lowering (PR #120431)
Thomas Raoux via flang-commits
flang-commits at lists.llvm.org
Fri Dec 20 07:54:14 PST 2024
================
@@ -236,6 +237,101 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
}
};
+/// Lowering of cf.assert into a conditional __assertfail.
+struct AssertOpToAssertfailLowering
+ : public ConvertOpToLLVMPattern<cf::AssertOp> {
+ using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MLIRContext *ctx = rewriter.getContext();
+ Location loc = assertOp.getLoc();
+ Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+ Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
+ Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
+ Type ptrType = LLVM::LLVMPointerType::get(ctx);
+ Type voidType = LLVM::LLVMVoidType::get(ctx);
+
+ // Find or create __assertfail function declaration.
+ auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
+ auto assertfailType = LLVM::LLVMFunctionType::get(
+ voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
+ LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
+ moduleOp, loc, rewriter, "__assertfail", assertfailType);
+ assertfailDecl.setPassthroughAttr(
+ ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
+
+ // Split blocks and insert conditional branch.
+ // ^before:
+ // ...
+ // cf.cond_br %condition, ^after, ^assert
+ // ^assert:
+ // cf.assert
+ // cf.br ^after
+ // ^after:
+ // ...
+ Block *beforeBlock = assertOp->getBlock();
+ Block *assertBlock =
+ rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+ Block *afterBlock =
+ rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+ rewriter.setInsertionPointToEnd(beforeBlock);
+ rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
+ assertBlock);
+ rewriter.setInsertionPointToEnd(assertBlock);
+ rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+ // Continue cf.assert lowering.
+ rewriter.setInsertionPoint(assertOp);
+
+ // Populate file name, file number and function name from the location of
+ // the AssertOp.
+ StringRef fileName = "(unknown)";
+ StringRef funcName = "(unknown)";
+ int32_t fileLine = 0;
+ if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
----------------
ThomasRaoux wrote:
You may also want to support `CallSiteLoc` for cases where the assert comes from an inlined function.
https://github.com/llvm/llvm-project/pull/120431
More information about the flang-commits
mailing list