[Mlir-commits] [mlir] [mlir][spirv] Add spirv-to-llvm conversion for OpControlBarrier (PR #111864)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 11 08:28:56 PDT 2024


https://github.com/FMarno updated https://github.com/llvm/llvm-project/pull/111864

>From b24f03378cf1e741fa3f31a0f8e63b80c11cc79d Mon Sep 17 00:00:00 2001
From: Finlay Marno <finlay.marno at codeplay.com>
Date: Mon, 7 Oct 2024 11:45:52 +0100
Subject: [PATCH 1/2] [mlir] Add spirv-to-llvm translation for OpControlBarrier

The conversion emits a call to the mangled `__spirv_ControlBarrier` builtin,
the LLVM function expected by the LLVM/SPIR-V translation tool.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td  |  2 +-
 .../mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td     |  2 +-
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    | 70 ++++++++++++++++++-
 .../SPIRVToLLVM/barrier-ops-to-llvm.mlir      | 23 ++++++
 4 files changed, 94 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
index 1ebea94fced0a3..14593305490661 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
@@ -54,7 +54,7 @@ def SPIRV_ControlBarrierOp : SPIRV_Op<"ControlBarrier", []> {
     #### Example:
 
     ```mlir
-    spirv.ControlBarrier "Workgroup", "Device", "Acquire|UniformMemory"
+    spirv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
     ```
   }];
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
index 71ecabfb444bd0..022cbbbb6720fb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
@@ -1,4 +1,4 @@
-//===-- SPIRVBarrierOps.td - MLIR SPIR-V Barrier Ops -------*- tablegen -*-===//
+//===-- SPIRVMiscOps.td - MLIR SPIR-V Misc Ops -------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 74c169c9a7e76a..50d090ddad901f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1024,6 +1024,71 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
   }
 };
 
+static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
+                                              StringRef name,
+                                              ArrayRef<Type> paramTypes,
+                                              Type resultType) {
+  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(symbolTable, name));
+  if (!func) {
+    OpBuilder b(symbolTable->getRegion(0));
+    func = b.create<LLVM::LLVMFuncOp>(
+        symbolTable->getLoc(), name,
+        LLVM::LLVMFunctionType::get(resultType, paramTypes));
+    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+    func.setConvergent(true);
+    func.setNoUnwind(true);
+    func.setWillReturn(true);
+  }
+  return func;
+}
+
+static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
+                                           ConversionPatternRewriter &rewriter,
+                                           LLVM::LLVMFuncOp func,
+                                           ValueRange args) {
+  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+  call.setCConv(func.getCConv());
+  call.setConvergentAttr(func.getConvergentAttr());
+  call.setNoUnwindAttr(func.getNoUnwindAttr());
+  call.setWillReturnAttr(func.getWillReturnAttr());
+  return call;
+}
+
+class ControlBarrierPattern
+    : public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    constexpr StringRef funcName = "_Z22__spirv_ControlBarrieriii";
+    Operation *symbolTable =
+        controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
+
+    Type i32 = rewriter.getI32Type();
+
+    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
+    LLVM::LLVMFuncOp func =
+        lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
+
+    auto loc = controlBarrierOp->getLoc();
+    Value execution = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
+    Value memory = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
+    Value semantics = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
+
+    auto call = createSPIRVBuiltinCall(loc, rewriter, func,
+                                       {execution, memory, semantics});
+
+    rewriter.replaceOp(controlBarrierOp, call);
+    return success();
+  }
+};
+
 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
 /// should be reachable for conversion to succeed. The structure of the loop in
 /// LLVM dialect will be the following:
@@ -1648,7 +1713,10 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
 
       // Return ops
-      ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
+      ReturnPattern, ReturnValuePattern,
+
+      // Barrier ops
+      ControlBarrierPattern>(patterns.getContext(), typeConverter);
 
   patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
                                       typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..d53afeeea15d10
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.ControlBarrierOp
+//===----------------------------------------------------------------------===//
+
+// CHECK:           llvm.func spir_funccc @_Z22__spirv_ControlBarrieriii(i32, i32, i32) attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL: @control_barrier
+spirv.func @control_barrier() "None" {
+  // CHECK:         [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32
+  // CHECK:         llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+  spirv.ControlBarrier <Workgroup>, <Workgroup>, <CrossWorkgroupMemory|WorkgroupMemory>
+
+  // CHECK:         [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32
+  // CHECK:         llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+  spirv.ControlBarrier <Workgroup>, <Workgroup>, <WorkgroupMemory>
+  spirv.Return
+}

>From 36b51fe57e76f213d06df8f420822372337ff632 Mon Sep 17 00:00:00 2001
From: Finlay Marno <finlay.marno at codeplay.com>
Date: Fri, 11 Oct 2024 16:28:37 +0100
Subject: [PATCH 2/2] fixup! [mlir] Add spirv-to-llvm translation for
 OpControlBarrier

---
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 50d090ddad901f..219d7a93b855c7 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1030,21 +1030,21 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                               Type resultType) {
   auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
       SymbolTable::lookupSymbolIn(symbolTable, name));
-  if (!func) {
-    OpBuilder b(symbolTable->getRegion(0));
-    func = b.create<LLVM::LLVMFuncOp>(
-        symbolTable->getLoc(), name,
-        LLVM::LLVMFunctionType::get(resultType, paramTypes));
-    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
-    func.setConvergent(true);
-    func.setNoUnwind(true);
-    func.setWillReturn(true);
-  }
+  if (func)
+    return func;
+
+  OpBuilder b(symbolTable->getRegion(0));
+  func = b.create<LLVM::LLVMFuncOp>(
+      symbolTable->getLoc(), name,
+      LLVM::LLVMFunctionType::get(resultType, paramTypes));
+  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+  func.setConvergent(true);
+  func.setNoUnwind(true);
+  func.setWillReturn(true);
   return func;
 }
 
-static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
-                                           ConversionPatternRewriter &rewriter,
+static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &rewriter,
                                            LLVM::LLVMFuncOp func,
                                            ValueRange args) {
   auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
@@ -1063,7 +1063,7 @@ class ControlBarrierPattern
   LogicalResult
   matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    constexpr StringRef funcName = "_Z22__spirv_ControlBarrieriii";
+    constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii";
     Operation *symbolTable =
         controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
 
@@ -1073,7 +1073,7 @@ class ControlBarrierPattern
     LLVM::LLVMFuncOp func =
         lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
 
-    auto loc = controlBarrierOp->getLoc();
+    Location loc = controlBarrierOp->getLoc();
     Value execution = rewriter.create<LLVM::ConstantOp>(
         loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
     Value memory = rewriter.create<LLVM::ConstantOp>(



More information about the Mlir-commits mailing list