[Mlir-commits] [mlir] [mlir][GPU] Add `gpu.assert` op (PR #120431)

Matthias Springer llvmlistbot at llvm.org
Wed Dec 18 06:54:26 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/120431

>From 09c1daa54c7d3ac3c17a96f9d75e54de553b3bbf Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 18 Dec 2024 15:52:13 +0100
Subject: [PATCH 1/2] [mlir][GPU] Do not strip debug info when lowering to NVVM

---
 mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
index fb440756e0c1d5..20d7372eef85d5 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToNVVMPipeline.cpp
@@ -73,7 +73,6 @@ void buildCommonPassPipeline(
 //===----------------------------------------------------------------------===//
 void buildGpuPassPipeline(OpPassManager &pm,
                           const mlir::gpu::GPUToNVVMPipelineOptions &options) {
-  pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
   ConvertGpuOpsToNVVMOpsOptions opt;
   opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
   opt.indexBitwidth = options.indexBitWidth;

>From f9f58669aa76c9c8d37ac88de611b85fcf6ec2b9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 18 Dec 2024 15:42:59 +0100
Subject: [PATCH 2/2] [mlir][GPU] Add `gpu.assert` op

---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  11 ++
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 122 +++++++++---------
 .../lib/Conversion/GPUCommon/GPUOpsLowering.h |  21 +++
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 103 ++++++++++++++-
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     |  29 +++++
 mlir/test/Dialect/GPU/ops.mlir                |   9 ++
 mlir/test/Integration/GPU/CUDA/assert.mlir    |  38 ++++++
 7 files changed, 270 insertions(+), 63 deletions(-)
 create mode 100644 mlir/test/Integration/GPU/CUDA/assert.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 42a017db300af6..793d663c0322a5 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -38,6 +38,17 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 class GPU_Op<string mnemonic, list<Trait> traits = []> :
     Op<GPU_Dialect, mnemonic, traits>;
 
+def GPU_AssertOp : GPU_Op<"assert"> {
+  let summary = "Device-side assertion";
+  let description = [{
+    The `gpu.assert` op is a device-side assertion. If the given `condition`
+    is 0, the kernel execution is aborted, optionally with the given error
+    message. This op is useful for debugging and verifying invariants.
+  }];
+  let arguments = (ins I1:$condition, OptionalAttr<StrAttr>:$message);
+  let assemblyFormat = "$condition (`,` $message^)? attr-dict";
+}
+
 def GPU_Dimension : I32EnumAttr<"Dimension",
     "a dimension, either 'x', 'y', or 'z'",
     [
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index b3c3fd4956d0bb..544fc57949e24d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -19,6 +19,59 @@
 
 using namespace mlir;
 
+LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
+                                           Location loc, OpBuilder &b,
+                                           StringRef name,
+                                           LLVM::LLVMFunctionType type) {
+  LLVM::LLVMFuncOp ret;
+  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointToStart(moduleOp.getBody());
+    ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External);
+  }
+  return ret;
+}
+
+static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
+                                           StringRef prefix) {
+  // Append 0, 1, 2, ... to `prefix` until the symbol is unused in the module.
+  unsigned stringNumber = 0;
+  SmallString<16> stringConstName;
+  do {
+    stringConstName.clear();
+    (prefix + Twine(stringNumber++)).toStringRef(stringConstName);
+  } while (moduleOp.lookupSymbol(stringConstName));
+  return stringConstName;
+}
+
+LLVM::GlobalOp
+mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
+                                gpu::GPUModuleOp moduleOp, Type llvmI8,
+                                StringRef namePrefix, StringRef str,
+                                uint64_t alignment, unsigned addrSpace) {
+  llvm::SmallString<20> nullTermStr(str);
+  nullTermStr.push_back('\0'); // Null terminate for C
+  auto globalType =
+      LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
+  StringAttr attr = b.getStringAttr(nullTermStr);
+
+  // Reuse a matching constant global (same type/value/alignment/addrspace).
+  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+        globalOp.getValueAttr() == attr &&
+        globalOp.getAlignment().value_or(0) == alignment &&
+        globalOp.getAddrSpace() == addrSpace)
+      return globalOp;
+
+  // Not found: create a new global at the start of the module body.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(moduleOp.getBody());
+  SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
+  return b.create<LLVM::GlobalOp>(loc, globalType,
+                                  /*isConstant=*/true, LLVM::Linkage::Internal,
+                                  name, attr, alignment, addrSpace);
+}
+
 LogicalResult
 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
@@ -328,61 +381,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   return success();
 }
 
-static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
-  const char formatStringPrefix[] = "printfFormat_";
-  // Get a unique global name.
-  unsigned stringNumber = 0;
-  SmallString<16> stringConstName;
-  do {
-    stringConstName.clear();
-    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
-  } while (moduleOp.lookupSymbol(stringConstName));
-  return stringConstName;
-}
-
-/// Create an global that contains the given format string. If a global with
-/// the same format string exists already in the module, return that global.
-static LLVM::GlobalOp getOrCreateFormatStringConstant(
-    OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
-    StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
-  llvm::SmallString<20> formatString(str);
-  formatString.push_back('\0'); // Null terminate for C
-  auto globalType =
-      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
-  StringAttr attr = b.getStringAttr(formatString);
-
-  // Try to find existing global.
-  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
-    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
-        globalOp.getValueAttr() == attr &&
-        globalOp.getAlignment().value_or(0) == alignment &&
-        globalOp.getAddrSpace() == addrSpace)
-      return globalOp;
-
-  // Not found: create new global.
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPointToStart(moduleOp.getBody());
-  SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
-  return b.create<LLVM::GlobalOp>(loc, globalType,
-                                  /*isConstant=*/true, LLVM::Linkage::Internal,
-                                  name, attr, alignment, addrSpace);
-}
-
-template <typename T>
-static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
-                                            ConversionPatternRewriter &rewriter,
-                                            StringRef name,
-                                            LLVM::LLVMFunctionType type) {
-  LLVM::LLVMFuncOp ret;
-  if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
-                                            LLVM::Linkage::External);
-  }
-  return ret;
-}
-
 LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -420,8 +418,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
   Value printfDesc = printfBeginCall.getResult();
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
 
   // Get a pointer to the format string's first element and pass it to printf()
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -502,9 +500,9 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
-      addressSpace);
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat(),
+      /*alignment=*/0, addressSpace);
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
@@ -546,8 +544,8 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
 
   // Create the global op or find an existing one.
-  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
-      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
+  LLVM::GlobalOp global = getOrCreateStringConstant(
+      rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat());
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 444a07a93ca36e..e73a74845d2b66 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
 
 namespace mlir {
 
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Find or create an external function declaration in the given module.
+LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
+                                     OpBuilder &b, StringRef name,
+                                     LLVM::LLVMFunctionType type);
+
+/// Create a global that contains the given string. If a global with the same
+/// string already exists in the module, return that global.
+LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
+                                         gpu::GPUModuleOp moduleOp, Type llvmI8,
+                                         StringRef namePrefix, StringRef str,
+                                         uint64_t alignment = 0,
+                                         unsigned addrSpace = 0);
+
+//===----------------------------------------------------------------------===//
+// Lowering Patterns
+//===----------------------------------------------------------------------===//
+
 /// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
 /// create a 0-sized global array symbol similar as LLVM expects. It constructs
 /// a memref descriptor with these values and return it.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index b343cf71e3a2e7..f95f6ea0ff41c4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -236,6 +237,105 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+/// Lowering of gpu.assert into a conditional __assertfail.
+struct GPUAssertOpToAssertfailLowering
+    : public ConvertOpToLLVMPattern<gpu::AssertOp> {
+  using ConvertOpToLLVMPattern<gpu::AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::AssertOp assertOp, gpu::AssertOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = assertOp.getLoc();
+    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
+    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
+    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
+    Type ptrType = LLVM::LLVMPointerType::get(ctx);
+    Type voidType = LLVM::LLVMVoidType::get(ctx);
+
+    // Find or create __assertfail function declaration.
+    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
+    auto assertfailType = LLVM::LLVMFunctionType::get(
+        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
+    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "__assertfail", assertfailType);
+    assertfailDecl.setPassthroughAttr(
+        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));
+
+    // Split blocks and insert conditional branch.
+    // ^before:
+    //   ...
+    //   cf.cond_br %condition, ^after, ^assert
+    // ^assert:
+    //   gpu.assert
+    //   cf.br ^after
+    // ^after:
+    //   ...
+    Block *beforeBlock = assertOp->getBlock();
+    Block *assertBlock =
+        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+    Block *afterBlock =
+        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+    rewriter.setInsertionPointToEnd(beforeBlock);
+    rewriter.create<cf::CondBranchOp>(loc, adaptor.getCondition(), afterBlock,
+                                      assertBlock);
+    rewriter.setInsertionPointToEnd(assertBlock);
+    rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+    // Continue gpu.assert lowering.
+    rewriter.setInsertionPoint(assertOp);
+
+    // Populate file name, line number and function name from the location of
+    // the AssertOp. Unknown pieces keep their "(unknown)" / 0 defaults.
+    StringRef fileName = "(unknown)";
+    StringRef funcName = "(unknown)";
+    int32_t fileLine = 0;
+    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+      fileName = fileLineColLoc.getFilename().strref();
+      fileLine = fileLineColLoc.getStartLine();
+    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+      funcName = nameLoc.getName().strref();
+      if (auto fileLineColLoc =
+              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+        fileName = fileLineColLoc.getFilename().strref();
+        fileLine = fileLineColLoc.getStartLine();
+      }
+    }
+    // Extract message string.
+    StringRef message = "";
+    if (assertOp.getMessage().has_value())
+      message = *assertOp.getMessage();
+
+    // Create constants.
+    auto getGlobal = [&](LLVM::GlobalOp global) {
+      // Get a pointer to the string global's first element.
+      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
+          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
+          global.getSymNameAttr());
+      Value start =
+          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+      return start;
+    };
+    Value assertMessage = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_message_", message));
+    Value assertFile = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
+    Value assertFunc = getGlobal(getOrCreateStringConstant(
+        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
+    Value assertLine =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
+    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);
+
+    // Insert function call to __assertfail.
+    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
+                                 assertFunc, c1};
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
+                                              arguments);
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -358,7 +458,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
   using gpu::index_lowering::IndexKind;
   using gpu::index_lowering::IntrType;
   populateWithGenerated(patterns);
-  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
+  patterns.add<GPUPrintfOpToVPrintfLowering, GPUAssertOpToAssertfailLowering>(
+      converter);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 748dfe8c68fc7e..7e50734350be2a 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -969,6 +969,35 @@ gpu.module @test_module_50 {
   }
 }
 
+// CHECK-LABEL: gpu.module @test_module_51
+//       CHECK:   llvm.mlir.global internal constant @[[func_name:.*]]("(unknown)\00") {addr_space = 0 : i32}
+//       CHECK:   llvm.mlir.global internal constant @[[file_name:.*]]("{{.*}}gpu-to-nvvm.mlir{{.*}}") {addr_space = 0 : i32}
+//       CHECK:   llvm.mlir.global internal constant @[[message:.*]]("assert message\00") {addr_space = 0 : i32}
+//       CHECK:   llvm.func @__assertfail(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) attributes {passthrough = ["noreturn"]}
+//       CHECK:   llvm.func @test_assert(%[[cond:.*]]: i1) attributes {gpu.kernel, nvvm.kernel} {
+//       CHECK:     llvm.cond_br %[[cond]], ^[[after_block:.*]], ^[[assert_block:.*]]
+//       CHECK:   ^[[assert_block]]:  // pred: ^bb0
+//       CHECK:     %[[message_ptr:.*]] = llvm.mlir.addressof @[[message]] : !llvm.ptr
+//       CHECK:     %[[message_start:.*]] = llvm.getelementptr %[[message_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<15 x i8>
+//       CHECK:     %[[file_ptr:.*]] = llvm.mlir.addressof @[[file_name]] : !llvm.ptr
+//       CHECK:     %[[file_start:.*]] = llvm.getelementptr %[[file_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{[0-9]+}} x i8>
+//       CHECK:     %[[func_ptr:.*]] = llvm.mlir.addressof @[[func_name]] : !llvm.ptr
+//       CHECK:     %[[func_start:.*]] = llvm.getelementptr %[[func_ptr]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i8>
+//       CHECK:     %[[line_num:.*]] = llvm.mlir.constant({{.*}} : i32) : i32
+//       CHECK:     %[[ptr:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:     llvm.call @__assertfail(%[[message_start]], %[[file_start]], %[[line_num]], %[[func_start]], %[[ptr]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, i64) -> ()
+//       CHECK:     llvm.br ^[[after_block]]
+//       CHECK:   ^[[after_block]]:
+//       CHECK:     llvm.return
+//       CHECK:   }
+
+gpu.module @test_module_51 {
+  gpu.func @test_assert(%arg0: i1) kernel {
+    gpu.assert %arg0, "assert message"
+    gpu.return
+  }
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
     %gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c0ff2044b76c40..d74ff9482b911d 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -500,3 +500,12 @@ func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4
   }
   return %2 : vector<4xi32>
 }
+
+// CHECK-LABEL: func @test_assert(
+func.func @test_assert(%cond : i1) {
+  // CHECK: gpu.assert %{{.*}}, "message"
+  gpu.assert %cond, "message"
+  // CHECK: gpu.assert %{{.*}}
+  gpu.assert %cond
+  return
+}
diff --git a/mlir/test/Integration/GPU/CUDA/assert.mlir b/mlir/test/Integration/GPU/CUDA/assert.mlir
new file mode 100644
index 00000000000000..40745f26af16cf
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/assert.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void 2>&1 \
+// RUN: | FileCheck %s
+
+// CHECK-DAG: thread 0: print after passing assertion
+// CHECK-DAG: thread 1: print after passing assertion
+// CHECK-DAG: mlir/test/Integration/GPU/CUDA/assert.mlir:{{.*}}: (unknown): block: [0,0,0], thread: [0,0,0] Assertion `failing assertion` failed.
+// CHECK-DAG: mlir/test/Integration/GPU/CUDA/assert.mlir:{{.*}}: (unknown): block: [0,0,0], thread: [1,0,0] Assertion `failing assertion` failed.
+// CHECK-NOT: print after failing assertion
+
+module attributes {gpu.container_module} {
+gpu.module @kernels {
+gpu.func @hello(%c0: i1, %c1: i1) kernel {
+  %0 = gpu.thread_id x
+  gpu.assert %c1, "passing assertion"
+  gpu.printf "thread %lld: print after passing assertion\n" %0 : index
+  gpu.assert %c0, "failing assertion"
+  gpu.printf "thread %lld: print after failing assertion\n" %0 : index
+  gpu.return
+}
+}
+
+func.func @main() {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c0_i1 = arith.constant 0 : i1
+  %c1_i1 = arith.constant 1 : i1
+  gpu.launch_func @kernels::@hello
+      blocks in (%c1, %c1, %c1)
+      threads in (%c2, %c1, %c1)
+      args(%c0_i1 : i1, %c1_i1 : i1)
+  return
+}
+}



More information about the Mlir-commits mailing list