[flang-commits] [flang] [mlir] [mlir][GPU] Add NVVM-specific `cf.assert` lowering (PR #120431)

Matthias Springer via flang-commits flang-commits at lists.llvm.org
Fri Dec 20 05:32:25 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/120431

>From d75a6ed0265555614219e09bed8f0fae6f0f52c9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 18 Dec 2024 15:42:59 +0100
Subject: [PATCH] [mlir][GPU] Add `gpu.assert` op

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       |   1 +
 mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp    |   1 +
 .../ControlFlowToLLVM/ControlFlowToLLVM.h     |   4 +
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   |   3 +-
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 122 +++++++++---------
 .../lib/Conversion/GPUCommon/GPUOpsLowering.h |  21 +++
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  99 +++++++++++++-
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      |   1 +
 .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp  |   1 +
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     |  29 +++++
 mlir/test/Integration/GPU/CUDA/assert.mlir    |  37 ++++++
 11 files changed, 255 insertions(+), 64 deletions(-)
 create mode 100644 mlir/test/Integration/GPU/CUDA/assert.mlir

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 2f4cd84dda7dec..036432aea675e3 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3917,6 +3917,7 @@ class FIRToLLVMLowering
     mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, pattern);
     mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                           pattern);
+    mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, pattern);
     // Math operations that have not been converted yet must be converted
     // to Libm.
     if (!isAMDGCN)
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 3ad70e7279692b..123d114ae16359 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -220,6 +220,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
   mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
   populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
   cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
+  cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
   populateFuncToLLVMConversionPatterns(typeConverter, patterns);
 
   // The only remaining operation to lower from the `toy` dialect, is the
diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index b88c1e8b20f32b..88f18022da9bb1 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -29,6 +29,10 @@ namespace cf {
 /// Collect the patterns to convert from the ControlFlow dialect to LLVM. The
 /// conversion patterns capture the LLVMTypeConverter by reference meaning the
 /// references have to remain alive during the entire pattern lifetime.
+///
+/// Note: This function does not populate the default cf.assert lowering. That
+/// is because some platforms have a custom cf.assert lowering. The default
+/// lowering can be populated with `populateAssertToLLVMConversionPattern`.
 void populateControlFlowToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 8672e7b849d9de..d0ffb94f3f96a9 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -215,7 +215,6 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   // clang-format off
   patterns.add<
-      AssertOpLowering,
       BranchOpLowering,
       CondBranchOpLowering,
       SwitchOpLowering>(converter);
@@ -258,6 +257,7 @@ struct ConvertControlFlowToLLVM
     LLVMTypeConverter converter(ctx, options);
     RewritePatternSet patterns(ctx);
     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+    mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -286,6 +286,7 @@ struct ControlFlowToLLVMDialectInterface
       RewritePatternSet &patterns) const final {
     mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                           patterns);
+    mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
   }
 };
 } // namespace
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index b3c3fd4956d0bb..544fc57949e24d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -19,6 +19,59 @@
 
 using namespace mlir;
 
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
+                                           Location loc, OpBuilder &b,
+                                           StringRef name,
+                                           LLVM::LLVMFunctionType type) {
+  LLVM::LLVMFuncOp ret; // Existing symbol if found, else a new declaration.
+  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointToStart(moduleOp.getBody()); // Declare at module start.
+    ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+  }
+  return ret;
+}
+
+static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+                                           StringRef prefix) {
+  // Try `prefix` + 0, 1, 2, ... until the name is unused in the module.
+  unsigned stringNumber = 0;
+  SmallString<16> stringConstName;
+  do {
+    stringConstName.clear();
+    (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
+  } while (moduleOp.lookupSymbol(stringConstName));
+  return stringConstName;
+}
+
+LLVM::GlobalOp
+mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+                                gpu::GPUModuleOp moduleOp, Type llvmI8,
+                                StringRef namePrefix, StringRef str,
+                                uint64_t alignment, unsigned addrSpace) {
+  llvm::SmallString<20> nullTermStr(str);
+  nullTermStr.push_back('\0'); // Stored with a trailing NUL for C.
+  auto globalType =
+      LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
+  StringAttr attr = b.getStringAttr(nullTermStr);
+
+  // Reuse a constant global matching in type, value, alignment, addr space.
+  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+        globalOp.getValueAttr() == attr &&
+        globalOp.getAlignment().value_or(0) == alignment &&
+        globalOp.getAddrSpace() == addrSpace)
+      return globalOp;
+
+  // No match: create an internal constant global at the start of the module.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(moduleOp.getBody());
+  SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
+  return b.create<LLVM::GlobalOp>(loc, globalType,
+                                  /*isConstant=*/true, LLVM::Linkage::Internal,
+                                  name, attr, alignment, addrSpace);
+}
+
 LogicalResult
 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   return success();
 }
 
-static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
-  const char formatStringPrefix[] = "printfFormat_";
-  // Get a unique global name.
-  unsigned stringNumber = 0;
-  SmallString<16> stringConstName;
-  do {
-    stringConstName.clear();
-    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
-  } while (moduleOp.lookupSymbol(stringConstName));
-  return stringConstName;
-}
-
-/// Create an global that contains the given format string. If a global with
-/// the same format string exists already in the module, return that global.
-static LLVM::GlobalOp getOrCreateFormatStringConstant(
-    OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
-    StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
-  llvm::SmallString<20> formatString(str);
-  formatString.push_back('\0'); // Null terminate for C
-  auto globalType =
-      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
-  StringAttr attr = b.getStringAttr(formatString);
-
-  // Try to find existing global.
-  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
-    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
-        globalOp.getValueAttr() == attr &&
-        globalOp.getAlignment().value_or(0) == alignment &&
-        globalOp.getAddrSpace() == addrSpace)
-      return globalOp;
-
-  // Not found: create new global.
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPointToStart(moduleOp.getBody());
-  SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
-  return b.create<LLVM::GlobalOp>(loc, globalType,
-                                  /*isConstant=*/true, LLVM::Linkage::Internal,
-                                  name, attr, alignment, addrSpace);
-}
-
-template <typename T>
-static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
-                                            ConversionPatternRewriter &rewriter,
-                                            StringRef name,
-                                            LLVM::LLVMFunctionType type) {
-  LLVM::LLVMFuncOp ret;
-  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
-                                            LLVM::Linkage::External);
-  }
-  return ret;
-}
-
 LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
   Value printfDesc = printfBeginCall.getResult();
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
 
   // Get a pointer to the format string's first element and pass it to printf()
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
-      addressSpace);
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
+      /*alignment=*/0, addressSpace);
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 444a07a93ca36e..e73a74845d2b66 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
 
 namespace mlir {
 
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Find or create an external function declaration in the given module.
+LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
+                                     OpBuilder &b, StringRef name,
+                                     LLVM::LLVMFunctionType type);
+
+/// Create a global that contains the given string. If a global with the same
+/// string already exists in the module, return that global.
+LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
+                                         gpu::GPUModuleOp moduleOp, Type llvmI8,
+                                         StringRef namePrefix, StringRef str,
+                                         uint64_t alignment = 0,
+                                         unsigned addrSpace = 0);
+
+//===----------------------------------------------------------------------===//
+// Lowering Patterns
+//===----------------------------------------------------------------------===//
+
 /// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
 /// create a 0-sized global array symbol similar as LLVM expects. It constructs
 /// a memref descriptor with these values and return it.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index b343cf71e3a2e7..44a36a0502bee4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -236,6 +237,101 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+/// Lowers cf.assert to a conditional call to the `__assertfail` function.
+struct AssertOpToAssertfailLowering
+    : public ConvertOpToLLVMPattern<cf::AssertOp> {
+  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = assertOp.getLoc();
+    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
+    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
+    Type ptrType = LLVM::LLVMPointerType::get(ctx);
+    Type voidType = LLVM::LLVMVoidType::get(ctx);
+
+    // Find or create the __assertfail declaration and mark it "noreturn".
+    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
+    auto assertfailType = LLVM::LLVMFunctionType::get(
+        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
+    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "__assertfail", assertfailType);
+    assertfailDecl.setPassthroughAttr(
+        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
+
+    // Split blocks; ^assert is reached only if the condition is false.
+    // ^before:
+    //   ...
+    //   cf.cond_br %condition, ^after, ^assert
+    // ^assert:
+    //   cf.assert
+    //   cf.br ^after
+    // ^after:
+    //   ...
+    Block *beforeBlock = assertOp->getBlock();
+    Block *assertBlock =
+        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+    Block *afterBlock =
+        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+    rewriter.setInsertionPointToEnd(beforeBlock);
+    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
+                                      assertBlock);
+    rewriter.setInsertionPointToEnd(assertBlock);
+    rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+    // Emit the remaining ops in ^assert, right before cf.assert.
+    rewriter.setInsertionPoint(assertOp);
+
+    // Populate file name, line number and function name from the location of
+    // the AssertOp; anything not found stays "(unknown)" or 0.
+    StringRef fileName = "(unknown)";
+    StringRef funcName = "(unknown)";
+    int32_t fileLine = 0;
+    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+      fileName = fileLineColLoc.getFilename().strref();
+      fileLine = fileLineColLoc.getStartLine();
+    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+      funcName = nameLoc.getName().strref();
+      if (auto fileLineColLoc =
+              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+        fileName = fileLineColLoc.getFilename().strref();
+        fileLine = fileLineColLoc.getStartLine();
+      }
+    }
+
+    // Create the string globals and integer constants for the call.
+    auto getGlobal = [&](LLVM::GlobalOp global) {
+      // Get a pointer to the global string's first element.
+      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
+          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+          global.getSymNameAttr());
+      Value start =
+          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+      return start;
+    };
+    Value assertMessage = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
+    Value assertFile = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
+    Value assertFunc = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
+    Value assertLine =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
+    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
+
+    // Replace cf.assert with __assertfail(message, file, line, func, 1).
+    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
+                                 assertFunc, c1};
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
+                                              arguments);
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -358,7 +454,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
   using gpu::index_lowering::IndexKind;
   using gpu::index_lowering::IntrType;
   populateWithGenerated(patterns);
-  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
+  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
+      converter);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index aa4d3b70329fba..aaf00e51f49416 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -297,6 +297,7 @@ struct LowerGpuOpsToROCDLOpsPass
     populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
     populateMathToLLVMConversionPatterns(converter, llvmPatterns);
     cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
+    cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
     populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
     populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 58fd3d565fce50..5d0003911bca87 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -304,6 +304,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext());
   arith::populateArithToLLVMConversionPatterns(converter, patterns);
   cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+  cf::populateAssertToLLVMConversionPattern(converter, patterns);
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateFuncToLLVMConversionPatterns(converter, patterns);
   populateOpenMPToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 748dfe8c68fc7e..a10f9fb51a3b8c 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -969,6 +969,35 @@ gpu.module @test_module_50 {
   }
 }
 
+// CHECK-LABEL: gpu.module @test_module_51
+//       CHECK:   llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
+//       CHECK:   llvm.mlir.global internal constant @[[file_name:.*]]("within split at {{.*}}gpu-to-nvvm.mlir:1 offset \00") {addr_space = 0 : i32}
+//       CHECK:   llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
+//       CHECK:   llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
+//       CHECK:   llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
+//       CHECK:     llvm.cond_br %[[cond]], ^[[after_block:.*]], ^[[assert_block:.*]]
+//       CHECK:   ^[[assert_block]]:
+//       CHECK:     %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
+//       CHECK:     %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
+//       CHECK:     %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
+//       CHECK:     %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
+//       CHECK:     %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
+//       CHECK:     %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{.*}} x i8>
+//       CHECK:     %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
+//       CHECK:     %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:     llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
+//       CHECK:     llvm.br ^[[after_block]]
+//       CHECK:   ^[[after_block]]:
+//       CHECK:     llvm.return
+//       CHECK:   }
+
+gpu.module @test_module_51 {
+  gpu.func @test_assert(%arg0: i1) kernel {
+    cf.assert %arg0, "assert message"
+    gpu.return
+  }
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
     %gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir
new file mode 100644
index 00000000000000..c8e8e2cbb60576
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/assert.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void 2>&1 \
+// RUN: | FileCheck %s
+
+// CHECK-DAG: thread 0: print after passing assertion
+// CHECK-DAG: thread 1: print after passing assertion
+// CHECK-DAG: mlir/test/Integration/GPU/CUDA/assert.mlir:{{.*}}: (unknown): block: [0,0,0], thread: [0,0,0] Assertion `failing assertion` failed.
+// CHECK-DAG: mlir/test/Integration/GPU/CUDA/assert.mlir:{{.*}}: (unknown): block: [0,0,0], thread: [1,0,0] Assertion `failing assertion` failed.
+// CHECK-NOT: print after failing assertion
+
+module attributes {gpu.container_module} {
+gpu.module @kernels {
+gpu.func @test_assert(%c0: i1, %c1: i1) kernel {
+  %0 = gpu.thread_id x
+  cf.assert %c1, "passing assertion"
+  gpu.printf "thread %lld: print after passing assertion\n" %0 : index
+  cf.assert %c0, "failing assertion"
+  gpu.printf "thread %lld: print after failing assertion\n" %0 : index
+  gpu.return
+}
+}
+
+func.func @main() {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c0_i1 = arith.constant 0 : i1
+  %c1_i1 = arith.constant 1 : i1
+  gpu.launch_func @kernels::@test_assert
+      blocks in (%c1, %c1, %c1)
+      threads in (%c2, %c1, %c1)
+      args(%c0_i1 : i1, %c1_i1 : i1)
+  return
+}
+}



More information about the flang-commits mailing list