[llvm-branch-commits] [mlir] [mlir][rocdl] Add AMDGPU-specific `cf.assert` lowering (PR #121067)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Dec 24 12:32:13 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

This commit adds an AMD-specific lowering of `cf.assert` to `llvm.intr.trap`. It depends on (and is therefore merging into) https://github.com/llvm/llvm-project/pull/120431.

Note: the reason for lowering to `llvm.intr.trap` instead of `__assertfail` / `__assert_fail` is that these are [header-only functions in HIP](https://github.com/ROCm/clr/blob/dff8197b1dc06913080c54f798eb1379914192b8/hipamd/include/hip/amd_detail/hip_assert.h#L44-L83).



---
Full diff: https://github.com/llvm/llvm-project/pull/121067.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+71-1) 
- (added) mlir/test/Integration/GPU/ROCM/assert.mlir (+37) 


``````````diff
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index aaf00e51f49416..2b0d16a5defb1f 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -195,6 +196,75 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+/// Lowers cf.assert to a conditional branch into a block that emits a
+/// gpu.printf with the assert metadata (file name, line, message) followed by
+/// llvm.intr.trap. Modeled on AssertOpToAssertfailLowering in
+/// mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp.
+struct AssertOpToBuiltinTrapLowering
+    : public ConvertOpToLLVMPattern<cf::AssertOp> {
+  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = assertOp.getLoc();
+
+    // Split blocks and insert conditional branch.
+    // ^before:
+    //   ...
+    //   cf.cond_br %condition, ^after, ^assert
+    // ^assert:
+    //   cf.assert
+    //   cf.br ^after
+    // ^after:
+    //   ...
+    Block *beforeBlock = assertOp->getBlock();
+    Block *assertBlock =
+        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+    Block *afterBlock =
+        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+    rewriter.setInsertionPointToEnd(beforeBlock);
+    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
+                                      assertBlock);
+    rewriter.setInsertionPointToEnd(assertBlock);
+    rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+    // Continue cf.assert lowering.
+    rewriter.setInsertionPoint(assertOp);
+
+    // Populate file name, line number and function name from the location of
+    // the AssertOp; anything the location does not provide stays "(unknown)".
+    StringRef fileName = "(unknown)";
+    StringRef funcName = "(unknown)";
+    int32_t fileLine = 0;
+    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+      fileName = fileLineColLoc.getFilename().strref();
+      fileLine = fileLineColLoc.getStartLine();
+    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+      funcName = nameLoc.getName().strref();
+      if (auto fileLineColLoc =
+              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+        fileName = fileLineColLoc.getFilename().strref();
+        fileLine = fileLineColLoc.getStartLine();
+      }
+    }
+
+    Value assertLine =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), fileLine);
+    // Build the format string up front since the gpu.printf lowering doesn't
+    // handle %s; pass the Twine inline, as a named Twine local would dangle.
+    rewriter.create<gpu::PrintfOp>(
+        loc,
+        rewriter.getStringAttr(fileName + ":%u: " + funcName +
+                               " Device-side assertion `" +
+                               assertOp.getMsg() + "' failed.\n"),
+        ValueRange{assertLine});
+    rewriter.replaceOpWithNewOp<LLVM::Trap>(assertOp);
+
+    return success();
+  }
+};
+
 /// Import the GPU Ops to ROCDL Patterns.
 #include "GPUToROCDL.cpp.inc"
 
@@ -297,7 +367,7 @@ struct LowerGpuOpsToROCDLOpsPass
     populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
     populateMathToLLVMConversionPatterns(converter, llvmPatterns);
     cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
-    cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
+    llvmPatterns.add<AssertOpToBuiltinTrapLowering>(converter);
     populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
     populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
diff --git a/mlir/test/Integration/GPU/ROCM/assert.mlir b/mlir/test/Integration/GPU/ROCM/assert.mlir
new file mode 100644
index 00000000000000..0c292d1b02473e
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/assert.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void 2>&1 \
+// RUN: | FileCheck %s
+// NOTE(review): this ROCM test invokes the NVVM pipeline and the CUDA runtime library; presumably it should use the ROCDL lowering and the ROCm runtime instead - confirm.
+// CHECK-DAG: thread 0: print after passing assertion
+// CHECK-DAG: thread 1: print after passing assertion
+// CHECK-DAG: mlir/test/Integration/GPU/ROCM/assert.mlir:{{.*}}: (unknown) Device-side assertion `failing assertion' failed.
+// CHECK-DAG: mlir/test/Integration/GPU/ROCM/assert.mlir:{{.*}}: (unknown) Device-side assertion `failing assertion' failed.
+// CHECK-NOT: print after failing assertion
+
+module attributes {gpu.container_module} {
+gpu.module @kernels {
+gpu.func @test_assert(%c0: i1, %c1: i1) kernel {  // main passes %c0 = 0 and %c1 = 1
+  %0 = gpu.thread_id x
+  cf.assert %c1, "passing assertion"  // condition is 1: the printf below is expected to run
+  gpu.printf "thread %lld: print after passing assertion\n" %0 : index
+  cf.assert %c0, "failing assertion"  // condition is 0: the printf below must never be reached
+  gpu.printf "thread %lld: print after failing assertion\n" %0 : index
+  gpu.return
+}
+}
+
+func.func @main() {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c0_i1 = arith.constant 0 : i1
+  %c1_i1 = arith.constant 1 : i1
+  gpu.launch_func @kernels::@test_assert
+      blocks in (%c1, %c1, %c1)
+      threads in (%c2, %c1, %c1)  // two threads along x, hence two copies of each expected message
+      args(%c0_i1 : i1, %c1_i1 : i1)
+  return
+}
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/121067


More information about the llvm-branch-commits mailing list