[Mlir-commits] [mlir] [mlir][spirv] Add spirv-to-llvm translation for OpControlBarrier (PR #111864)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 10 09:24:21 PDT 2024


https://github.com/FMarno created https://github.com/llvm/llvm-project/pull/111864

The translation is based on the expected LLVM function from the LLVM/SPIRV translation tool.

>From b24f03378cf1e741fa3f31a0f8e63b80c11cc79d Mon Sep 17 00:00:00 2001
From: Finlay Marno <finlay.marno at codeplay.com>
Date: Mon, 7 Oct 2024 11:45:52 +0100
Subject: [PATCH] [mlir] Add spirv-to-llvm translation for OpControlBarrier

The translation is based on the expected LLVM function from the
LLVM/SPIRV translation tool.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td  |  2 +-
 .../mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td     |  2 +-
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    | 70 ++++++++++++++++++-
 .../SPIRVToLLVM/barrier-ops-to-llvm.mlir      | 23 ++++++
 4 files changed, 94 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
index 1ebea94fced0a3..14593305490661 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
@@ -54,7 +54,7 @@ def SPIRV_ControlBarrierOp : SPIRV_Op<"ControlBarrier", []> {
     #### Example:
 
     ```mlir
-    spirv.ControlBarrier "Workgroup", "Device", "Acquire|UniformMemory"
+    spirv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
     ```
   }];
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
index 71ecabfb444bd0..022cbbbb6720fb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
@@ -1,4 +1,4 @@
-//===-- SPIRVBarrierOps.td - MLIR SPIR-V Barrier Ops -------*- tablegen -*-===//
+//===-- SPIRVMiscOps.td - MLIR SPIR-V Misc Ops -------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 74c169c9a7e76a..50d090ddad901f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1024,6 +1024,71 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
   }
 };
 
+static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
+                                              StringRef name,
+                                              ArrayRef<Type> paramTypes,
+                                              Type resultType) {
+  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(symbolTable, name));
+  if (!func) {
+    OpBuilder b(symbolTable->getRegion(0));
+    func = b.create<LLVM::LLVMFuncOp>(
+        symbolTable->getLoc(), name,
+        LLVM::LLVMFunctionType::get(resultType, paramTypes));
+    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+    func.setConvergent(true);
+    func.setNoUnwind(true);
+    func.setWillReturn(true);
+  }
+  return func;
+}
+
+static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
+                                           ConversionPatternRewriter &rewriter,
+                                           LLVM::LLVMFuncOp func,
+                                           ValueRange args) {
+  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+  call.setCConv(func.getCConv());
+  call.setConvergentAttr(func.getConvergentAttr());
+  call.setNoUnwindAttr(func.getNoUnwindAttr());
+  call.setWillReturnAttr(func.getWillReturnAttr());
+  return call;
+}
+
+class ControlBarrierPattern
+    : public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    constexpr StringRef funcName = "_Z22__spirv_ControlBarrieriii";
+    Operation *symbolTable =
+        controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
+
+    Type i32 = rewriter.getI32Type();
+
+    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
+    LLVM::LLVMFuncOp func =
+        lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
+
+    auto loc = controlBarrierOp->getLoc();
+    Value execution = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
+    Value memory = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
+    Value semantics = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
+
+    auto call = createSPIRVBuiltinCall(loc, rewriter, func,
+                                       {execution, memory, semantics});
+
+    rewriter.replaceOp(controlBarrierOp, call);
+    return success();
+  }
+};
+
 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
 /// should be reachable for conversion to succeed. The structure of the loop in
 /// LLVM dialect will be the following:
@@ -1648,7 +1713,10 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
 
       // Return ops
-      ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
+      ReturnPattern, ReturnValuePattern,
+
+      // Barrier ops
+      ControlBarrierPattern>(patterns.getContext(), typeConverter);
 
   patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
                                       typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..d53afeeea15d10
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.ControlBarrier
+//===----------------------------------------------------------------------===//
+
+// CHECK:           llvm.func spir_funccc @_Z22__spirv_ControlBarrieriii(i32, i32, i32) attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL: @control_barrier
+spirv.func @control_barrier() "None" {
+  // CHECK:         [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32
+  // CHECK:         llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+  spirv.ControlBarrier <Workgroup>, <Workgroup>, <CrossWorkgroupMemory|WorkgroupMemory>
+
+  // CHECK:         [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32
+  // CHECK:         llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+  spirv.ControlBarrier <Workgroup>, <Workgroup>, <WorkgroupMemory>
+  spirv.Return
+}



More information about the Mlir-commits mailing list