[flang-commits] [flang] [mlir] [mlir][GPU] Add NVVM-specific `cf.assert` lowering (PR #120431)
Matthias Springer via flang-commits
flang-commits at lists.llvm.org
Fri Dec 20 05:32:25 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/120431
>From d75a6ed0265555614219e09bed8f0fae6f0f52c9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 18 Dec 2024 15:42:59 +0100
Subject: [PATCH] [mlir][GPU] Add `gpu.assert` op
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 1 +
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 1 +
.../ControlFlowToLLVM/ControlFlowToLLVM.h | 4 +
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 3 +-
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 122 +++++++++---------
.../lib/Conversion/GPUCommon/GPUOpsLowering.h | 21 +++
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 99 +++++++++++++-
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 1 +
.../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 1 +
.../Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 29 +++++
mlir/test/Integration/GPU/CUDA/assert.mlir | 37 ++++++
11 files changed, 255 insertions(+), 64 deletions(-)
create mode 100644 mlir/test/Integration/GPU/CUDA/assert.mlir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 2f4cd84dda7dec..036432aea675e3 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3917,6 +3917,7 @@ class FIRToLLVMLowering
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
pattern);
+ mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern);
// Math operations that have not been converted yet must be converted
// to Libm.
if (!isAMDGCN)
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 3ad70e7279692b..123d114ae16359 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
+ cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index b88c1e8b20f32b..88f18022da9bb1 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -29,6 +29,10 @@ namespace cf {
/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter by reference meaning the
/// references have to remain alive during the entire pattern lifetime.
+///
+/// Note: This function does not populate the default cf.assert lowering. That
+/// is because some platforms have a custom cf.assert lowering. The default
+/// lowering can be populated with `populateAssertToLLVMConversionPattern`.
void populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 8672e7b849d9de..d0ffb94f3f96a9 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
- AssertOpLowering,
BranchOpLowering,
CondBranchOpLowering,
SwitchOpLowering>(converter);
@@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM
LLVMTypeConverter converter(ctx, options);
RewritePatternSet patterns(ctx);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+ mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface
RewritePatternSet &patterns) const final {
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
+ mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
}
};
} // namespace
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index b3c3fd4956d0bb..544fc57949e24d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -19,6 +19,59 @@
using namespace mlir;
+/// Return the `llvm.func` called `name` from `moduleOp`. When the module has
+/// no such function, an external declaration of type `type` is inserted at the
+/// start of the module body. The insertion point of `b` is left unchanged.
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
+                                           Location loc, OpBuilder &b,
+                                           StringRef name,
+                                           LLVM::LLVMFunctionType type) {
+  // Fast path: reuse an already present function with this symbol name.
+  if (auto existing = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
+    return existing;
+
+  // Slow path: declare it at the top of the module; the guard restores the
+  // caller's insertion point when we return.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(moduleOp.getBody());
+  return b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+}
+
+/// Return a symbol name of the form `<prefix><N>`, where N is the smallest
+/// non-negative integer for which `moduleOp` has no symbol of that name.
+static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+                                           StringRef prefix) {
+  SmallString<16> candidate;
+  for (unsigned suffix = 0;; ++suffix) {
+    candidate.clear();
+    // Twine::toVector always writes the rendered twine into `candidate`.
+    // Twine::toStringRef may instead return a reference without touching its
+    // output buffer (e.g. for a single-StringRef twine), so its result must not
+    // be discarded.
+    (prefix + Twine(suffix)).toVector(candidate);
+    if (!moduleOp.lookupSymbol(candidate))
+      return candidate;
+  }
+}
+
+LLVM::GlobalOp
+mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+                                gpu::GPUModuleOp moduleOp, Type llvmI8,
+                                StringRef namePrefix, StringRef str,
+                                uint64_t alignment, unsigned addrSpace) {
+  // Append the terminating NUL so consumers can treat the global as a C string.
+  llvm::SmallString<20> cString(str);
+  cString.push_back('\0');
+  auto arrayType = LLVM::LLVMArrayType::get(llvmI8, cString.size_in_bytes());
+  StringAttr valueAttr = b.getStringAttr(cString);
+
+  // An existing constant global with identical array type, contents, alignment
+  // and address space is reused instead of emitting a duplicate.
+  auto isEquivalent = [&](LLVM::GlobalOp candidate) {
+    return candidate.getGlobalType() == arrayType &&
+           candidate.getConstant() && candidate.getValueAttr() == valueAttr &&
+           candidate.getAlignment().value_or(0) == alignment &&
+           candidate.getAddrSpace() == addrSpace;
+  };
+  for (LLVM::GlobalOp candidate : moduleOp.getOps<LLVM::GlobalOp>()) {
+    if (isEquivalent(candidate))
+      return candidate;
+  }
+
+  // No match: emit a fresh internal constant at the top of the module under a
+  // collision-free name. The guard restores the builder's insertion point.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(moduleOp.getBody());
+  SmallString<16> symName = getUniqueSymbolName(moduleOp, namePrefix);
+  return b.create<LLVM::GlobalOp>(loc, arrayType, /*isConstant=*/true,
+                                  LLVM::Linkage::Internal, symName, valueAttr,
+                                  alignment, addrSpace);
+}
+
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return success();
}
-static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
- const char formatStringPrefix[] = "printfFormat_";
- // Get a unique global name.
- unsigned stringNumber = 0;
- SmallString<16> stringConstName;
- do {
- stringConstName.clear();
- (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
- } while (moduleOp.lookupSymbol(stringConstName));
- return stringConstName;
-}
-
-/// Create an global that contains the given format string. If a global with
-/// the same format string exists already in the module, return that global.
-static LLVM::GlobalOp getOrCreateFormatStringConstant(
- OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
- StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
- llvm::SmallString<20> formatString(str);
- formatString.push_back('\0'); // Null terminate for C
- auto globalType =
- LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
- StringAttr attr = b.getStringAttr(formatString);
-
- // Try to find existing global.
- for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
- if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
- globalOp.getValueAttr() == attr &&
- globalOp.getAlignment().value_or(0) == alignment &&
- globalOp.getAddrSpace() == addrSpace)
- return globalOp;
-
- // Not found: create new global.
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(moduleOp.getBody());
- SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
- return b.create<LLVM::GlobalOp>(loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal,
- name, attr, alignment, addrSpace);
-}
-
-template <typename T>
-static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
- ConversionPatternRewriter &rewriter,
- StringRef name,
- LLVM::LLVMFunctionType type) {
- LLVM::LLVMFuncOp ret;
- if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
- LLVM::Linkage::External);
- }
- return ret;
-}
-
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
Value printfDesc = printfBeginCall.getResult();
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
- addressSpace);
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
+ /*alignment=*/0, addressSpace);
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
// Create the global op or find an existing one.
- LLVM::GlobalOp global = getOrCreateFormatStringConstant(
- rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+ LLVM::GlobalOp global = getOrCreateStringConstant(
+ rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 444a07a93ca36e..e73a74845d2b66 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
namespace mlir {
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Find or create an external function declaration in the given module.
+LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
+ OpBuilder &b, StringRef name,
+ LLVM::LLVMFunctionType type);
+
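+/// Return a constant global holding `str` followed by a terminating NUL
+/// character. If the module already contains an equivalent constant global
+/// (same array type, contents, alignment and address space), that global is
+/// returned; otherwise a new internal global named `<namePrefix><N>` is created
+/// at the start of the module body.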
+LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
+ gpu::GPUModuleOp moduleOp, Type llvmI8,
+ StringRef namePrefix, StringRef str,
+ uint64_t alignment = 0,
+ unsigned addrSpace = 0);
+
+//===----------------------------------------------------------------------===//
+// Lowering Patterns
+//===----------------------------------------------------------------------===//
+
/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
/// create a 0-sized global array symbol similar as LLVM expects. It constructs
/// a memref descriptor with these values and return it.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index b343cf71e3a2e7..44a36a0502bee4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -25,6 +25,7 @@
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -236,6 +237,101 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
}
};
+/// Lowering of cf.assert into a conditional __assertfail.
+struct AssertOpToAssertfailLowering
+ : public ConvertOpToLLVMPattern<cf::AssertOp> {
+ using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MLIRContext *ctx = rewriter.getContext();
+ Location loc = assertOp.getLoc();
+ Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+ Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
+ Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
+ Type ptrType = LLVM::LLVMPointerType::get(ctx);
+ Type voidType = LLVM::LLVMVoidType::get(ctx);
+
+ // Find or create __assertfail function declaration.
+ auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
+ auto assertfailType = LLVM::LLVMFunctionType::get(
+ voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
+ LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
+ moduleOp, loc, rewriter, "__assertfail", assertfailType);
+ assertfailDecl.setPassthroughAttr(
+ ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
+
+ // Split blocks and insert conditional branch.
+ // ^before:
+ // ...
+ // cf.cond_br %condition, ^after, ^assert
+ // ^assert:
+ // cf.assert
+ // cf.br ^after
+ // ^after:
+ // ...
+ Block *beforeBlock = assertOp->getBlock();
+ Block *assertBlock =
+ rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+ Block *afterBlock =
+ rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+ rewriter.setInsertionPointToEnd(beforeBlock);
+ rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
+ assertBlock);
+ rewriter.setInsertionPointToEnd(assertBlock);
+ rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+ // Continue cf.assert lowering.
+ rewriter.setInsertionPoint(assertOp);
+
+ // Populate file name, file number and function name from the location of
+ // the AssertOp.
+ StringRef fileName = "(unknown)";
+ StringRef funcName = "(unknown)";
+ int32_t fileLine = 0;
+ if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+ fileName = fileLineColLoc.getFilename().strref();
+ fileLine = fileLineColLoc.getStartLine();
+ } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+ funcName = nameLoc.getName().strref();
+ if (auto fileLineColLoc =
+ dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+ fileName = fileLineColLoc.getFilename().strref();
+ fileLine = fileLineColLoc.getStartLine();
+ }
+ }
+
+ // Create constants.
+ auto getGlobal = [&](LLVM::GlobalOp global) {
+ // Get a pointer to the format string's first element.
+ Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
+ loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+ global.getSymNameAttr());
+ Value start =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ return start;
+ };
+ Value assertMessage = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
+ Value assertFile = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
+ Value assertFunc = getGlobal(getOrCreateStringConstant(
+ rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
+ Value assertLine =
+ rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
+ Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
+
+ // Insert function call to __assertfail.
+ SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
+ assertFunc, c1};
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
+ arguments);
+ return success();
+ }
+};
+
/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"
@@ -358,7 +454,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
- patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
+ patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
+ converter);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index aa4d3b70329fba..aaf00e51f49416 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -297,6 +297,7 @@ struct LowerGpuOpsToROCDLOpsPass
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
+ cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 58fd3d565fce50..5d0003911bca87 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
arith::populateArithToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+ cf::populateAssertToLLVMConversionPattern(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 748dfe8c68fc7e..a10f9fb51a3b8c 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -969,6 +969,35 @@ gpu.module @test_module_50 {
}
}
+// CHECK-LABEL: gpu.module @test_module_51
+// CHECK: llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
+// CHECK: llvm.mlir.global internal constant @[[file_name:.*]]("within split at {{.*}}gpu-to-nvvm.mlir:1 offset \00") {addr_space = 0 : i32}
+// CHECK: llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
+// CHECK: llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
+// CHECK: llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
+// CHECK: llvm.cond_br %[[cond]], ^[[after_block:.*]], ^[[assert_block:.*]]
+// CHECK: ^[[assert_block]]:
+// CHECK: %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
+// CHECK: %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
+// CHECK: %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
+// CHECK: %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
+// CHECK: %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
+// CHECK: %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
+// CHECK: %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
+// CHECK: %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
+// CHECK: llvm.br ^[[after_block]]
+// CHECK: ^[[after_block]]:
+// CHECK: llvm.return
+// CHECK: }
+
+// A kernel whose cf.assert is lowered to a conditional branch into a block
+// that calls @__assertfail with the message, file, line and function strings.
+gpu.module @test_module_51 {
+ gpu.func @test_assert(%arg0: i1) kernel {
+ cf.assert %arg0, "assert message"
+ gpu.return
+ }
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir
new file mode 100644
index 00000000000000..c8e8e2cbb60576
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/assert.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void 2>&1 \
+// RUN: | FileCheck %s
+
+// CHECK-DAG: thread 0: print after passing assertion
+// CHECK-DAG: thread 1: print after passing assertion
+// CHECK-DAG: mlir/test/Integration/GPU/CUDA/assert.mlir:{{.*}}: (unknown): block: [0,0,0], thread: [0,0,0] Assertion `failing assertion` failed.
+// CHECK-DAG: mlir/test/Integration/GPU/CUDA/assert.mlir:{{.*}}: (unknown): block: [0,0,0], thread: [1,0,0] Assertion `failing assertion` failed.
+// CHECK-NOT: print after failing assertion
+
+// The kernel runs with one block of two threads. Each thread passes the first
+// assertion (%c1 is true at the launch site) and prints. Each then fails the
+// second assertion (%c0 is false); @__assertfail is declared noreturn, so the
+// second printf is not expected to run.
+module attributes {gpu.container_module} {
+gpu.module @kernels {
+gpu.func @test_assert(%c0: i1, %c1: i1) kernel {
+ %0 = gpu.thread_id x
+ cf.assert %c1, "passing assertion"
+ gpu.printf "thread %lld: print after passing assertion\n" %0 : index
+ cf.assert %c0, "failing assertion"
+ gpu.printf "thread %lld: print after failing assertion\n" %0 : index
+ gpu.return
+}
+}
+
+func.func @main() {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %c0_i1 = arith.constant 0 : i1
+ %c1_i1 = arith.constant 1 : i1
+ // Grid of 1 block, 2 threads in x; kernel args: %c0 = false, %c1 = true.
+ gpu.launch_func @kernels::@test_assert
+ blocks in (%c1, %c1, %c1)
+ threads in (%c2, %c1, %c1)
+ args(%c0_i1 : i1, %c1_i1 : i1)
+ return
+}
+}
More information about the flang-commits
mailing list