[flang-commits] [flang] [mlir] [mlir][GPU] Add NVVM-specific `cf.assert` lowering (PR #120431)

Thomas Raoux via flang-commits flang-commits at lists.llvm.org
Fri Dec 20 07:54:14 PST 2024


================
@@ -236,6 +237,101 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+/// Lowering of cf.assert into a conditional __assertfail.
+struct AssertOpToAssertfailLowering
+    : public ConvertOpToLLVMPattern<cf::AssertOp> {
+  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = assertOp.getLoc();
+    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
+    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
+    Type ptrType = LLVM::LLVMPointerType::get(ctx);
+    Type voidType = LLVM::LLVMVoidType::get(ctx);
+
+    // Find or create __assertfail function declaration.
+    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
+    auto assertfailType = LLVM::LLVMFunctionType::get(
+        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
+    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "__assertfail", assertfailType);
+    assertfailDecl.setPassthroughAttr(
+        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
+
+    // Split blocks and insert conditional branch.
+    // ^before:
+    //   ...
+    //   cf.cond_br %condition, ^after, ^assert
+    // ^assert:
+    //   cf.assert
+    //   cf.br ^after
+    // ^after:
+    //   ...
+    Block *beforeBlock = assertOp->getBlock();
+    Block *assertBlock =
+        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+    Block *afterBlock =
+        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+    rewriter.setInsertionPointToEnd(beforeBlock);
+    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
+                                      assertBlock);
+    rewriter.setInsertionPointToEnd(assertBlock);
+    rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+    // Continue cf.assert lowering.
+    rewriter.setInsertionPoint(assertOp);
+
+    // Populate file name, file number and function name from the location of
+    // the AssertOp.
+    StringRef fileName = "(unknown)";
+    StringRef funcName = "(unknown)";
+    int32_t fileLine = 0;
+    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
----------------
ThomasRaoux wrote:

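You may also want to support `CallSiteLoc` for cases where the assert comes from an inlined function (after inlining, the op's location is a call-site location, so a plain `dyn_cast<FileLineColRange>(loc)` will not match it).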

https://github.com/llvm/llvm-project/pull/120431


More information about the flang-commits mailing list