[llvm-branch-commits] [mlir] [mlir][rocdl] Add AMDGPU-specific `cf.assert` lowering (PR #121067)

Maksim Levental via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Dec 24 12:50:14 PST 2024


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/121067

>From 79c96cb4d396e2d43ce3c9f9188a18fff4b4d227 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Tue, 24 Dec 2024 15:27:31 -0500
Subject: [PATCH] [mlir][rocdl] Add AMDGPU-specific `cf.assert` lowering

---
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      | 72 ++++++++++++++++++-
 mlir/test/Integration/GPU/ROCM/assert.mlir    | 39 ++++++++++
 2 files changed, 110 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Integration/GPU/ROCM/assert.mlir

diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index aaf00e51f49416..b68c58937cb2b9 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -195,6 +196,92 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+/// Based on
+/// mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp#AssertOpToAssertfailLowering
+/// Lowering of cf.assert into a conditional llvm.intr.trap plus gpu.printf with
+/// the metadata (filename, fileline, assert msg).
+///
+/// The failure message is baked into the gpu.printf format string, so any '%'
+/// in the file name, function name or assertion message is escaped as "%%".
+struct AssertOpToBuiltinTrapLowering
+    : public ConvertOpToLLVMPattern<cf::AssertOp> {
+  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = assertOp.getLoc();
+
+    // Split blocks and insert conditional branch.
+    // ^before:
+    //   ...
+    //   cf.cond_br %condition, ^after, ^assert
+    // ^assert:
+    //   cf.assert
+    //   cf.br ^after
+    // ^after:
+    //   ...
+    Block *beforeBlock = assertOp->getBlock();
+    Block *assertBlock =
+        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
+    Block *afterBlock =
+        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
+    rewriter.setInsertionPointToEnd(beforeBlock);
+    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
+                                      assertBlock);
+    rewriter.setInsertionPointToEnd(assertBlock);
+    rewriter.create<cf::BranchOp>(loc, afterBlock);
+
+    // Continue cf.assert lowering.
+    rewriter.setInsertionPoint(assertOp);
+
+    // Populate file name, file number and function name from the location of
+    // the AssertOp.
+    StringRef fileName = "(unknown)";
+    StringRef funcName = "(unknown)";
+    int32_t fileLine = 0;
+    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
+      fileName = fileLineColLoc.getFilename().strref();
+      fileLine = fileLineColLoc.getStartLine();
+    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+      funcName = nameLoc.getName().strref();
+      if (auto fileLineColLoc =
+              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
+        fileName = fileLineColLoc.getFilename().strref();
+        fileLine = fileLineColLoc.getStartLine();
+      }
+    }
+
+    Value assertLine =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), fileLine);
+    // Interpolate the fmt str AOT because current gpu.printf lowering doesn't
+    // handle %s. Build an owning std::string: a named llvm::Twine only refers
+    // to its operands and would dangle once this statement ends.
+    std::string fmtStr = escapePercent(fileName) + ":%u: " +
+                         escapePercent(funcName) + " Device-side assertion `" +
+                         escapePercent(assertOp.getMsg()) + "' failed.\n";
+    rewriter.create<gpu::PrintfOp>(loc, rewriter.getStringAttr(fmtStr),
+                                   ValueRange{assertLine});
+    rewriter.replaceOpWithNewOp<LLVM::Trap>(assertOp);
+
+    return success();
+  }
+
+private:
+  /// Returns `text` with every '%' doubled so that it prints verbatim when
+  /// spliced into a printf format string.
+  static std::string escapePercent(StringRef text) {
+    std::string escaped;
+    escaped.reserve(text.size());
+    for (char c : text) {
+      if (c == '%')
+        escaped.push_back('%');
+      escaped.push_back(c);
+    }
+    return escaped;
+  }
+};
+
 /// Import the GPU Ops to ROCDL Patterns.
 #include "GPUToROCDL.cpp.inc"
 
@@ -297,7 +367,7 @@ struct LowerGpuOpsToROCDLOpsPass
     populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
     populateMathToLLVMConversionPatterns(converter, llvmPatterns);
     cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
-    cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
+    llvmPatterns.add<AssertOpToBuiltinTrapLowering>(converter);
     populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
     populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
diff --git a/mlir/test/Integration/GPU/ROCM/assert.mlir b/mlir/test/Integration/GPU/ROCM/assert.mlir
new file mode 100644
index 00000000000000..e1b07d454e61ac
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/assert.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -mlir-print-debuginfo \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(convert-gpu-to-rocdl{index-bitwidth=32 runtime=HIP}),rocdl-attach-target{chip=%chip})' \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void 2>&1 \
+// RUN: | FileCheck %s
+// Debug locations are kept (no strip-debuginfo) so file:line can be verified.
+// CHECK-DAG: thread 0: print after passing assertion
+// CHECK-DAG: thread 1: print after passing assertion
+// CHECK-DAG: mlir/test/Integration/GPU/ROCM/assert.mlir:{{.*}}: (unknown) Device-side assertion `failing assertion' failed.
+// CHECK-DAG: mlir/test/Integration/GPU/ROCM/assert.mlir:{{.*}}: (unknown) Device-side assertion `failing assertion' failed.
+// CHECK-NOT: print after failing assertion
+
+module attributes {gpu.container_module} {
+gpu.module @kernels {
+gpu.func @test_assert(%c0: i1, %c1: i1) kernel {
+  %0 = gpu.thread_id x
+  cf.assert %c1, "passing assertion"
+  gpu.printf "thread %lld: print after passing assertion\n" %0 : index
+  cf.assert %c0, "failing assertion"
+  gpu.printf "thread %lld: print after failing assertion\n" %0 : index
+  gpu.return
+}
+}
+
+func.func @main() {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c0_i1 = arith.constant 0 : i1
+  %c1_i1 = arith.constant 1 : i1
+  gpu.launch_func @kernels::@test_assert
+      blocks in (%c1, %c1, %c1)
+      threads in (%c2, %c1, %c1)
+      args(%c0_i1 : i1, %c1_i1 : i1)
+  return
+}
+}



More information about the llvm-branch-commits mailing list