[Mlir-commits] [mlir] [mlir][spirv] Add spirv-to-llvm conversion for OpControlBarrier (PR #111864)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 11 08:28:56 PDT 2024
https://github.com/FMarno updated https://github.com/llvm/llvm-project/pull/111864
>From b24f03378cf1e741fa3f31a0f8e63b80c11cc79d Mon Sep 17 00:00:00 2001
From: Finlay Marno <finlay.marno at codeplay.com>
Date: Mon, 7 Oct 2024 11:45:52 +0100
Subject: [PATCH 1/2] [mlir] Add spirv-to-llvm translation for OpControlBarrier
The translation is based on the LLVM function expected from the
LLVM/SPIR-V translation tool.
---
.../mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td | 2 +-
.../mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td | 2 +-
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 70 ++++++++++++++++++-
.../SPIRVToLLVM/barrier-ops-to-llvm.mlir | 23 ++++++
4 files changed, 94 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
index 1ebea94fced0a3..14593305490661 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
@@ -54,7 +54,7 @@ def SPIRV_ControlBarrierOp : SPIRV_Op<"ControlBarrier", []> {
#### Example:
```mlir
- spirv.ControlBarrier "Workgroup", "Device", "Acquire|UniformMemory"
+ spirv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
```
}];
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
index 71ecabfb444bd0..022cbbbb6720fb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
@@ -1,4 +1,4 @@
-//===-- SPIRVBarrierOps.td - MLIR SPIR-V Barrier Ops -------*- tablegen -*-===//
+//===-- SPIRVMiscOps.td - MLIR SPIR-V Misc Ops -------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 74c169c9a7e76a..50d090ddad901f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1024,6 +1024,71 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
}
};
+/// Returns the `llvm.func` named `name` found in `symbolTable`, creating an
+/// external declaration with the given signature if none exists yet.
+///
+/// A newly created declaration is inserted at the start of the first block of
+/// `symbolTable`'s region 0. It uses the `spir_func` calling convention, is
+/// marked `nounwind` and `willreturn`, and is marked `convergent` unless
+/// `convergent` is false. An existing `llvm.func` is returned unchanged.
+///
+/// NOTE(review): if `name` is already used by a symbol that is not an
+/// `llvm.func`, the lookup yields null and a new declaration with the same
+/// name is created; presumably callers never hit this, confirm.
+static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
+                                              StringRef name,
+                                              ArrayRef<Type> paramTypes,
+                                              Type resultType,
+                                              bool convergent = true) {
+  assert(symbolTable && "expected a symbol table to hold the SPIR-V builtin");
+  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(symbolTable, name));
+  if (func)
+    return func;
+
+  // Declare the builtin at the beginning of the symbol table's body.
+  OpBuilder b(symbolTable->getRegion(0));
+  func = b.create<LLVM::LLVMFuncOp>(
+      symbolTable->getLoc(), name,
+      LLVM::LLVMFunctionType::get(resultType, paramTypes));
+  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+  func.setConvergent(convergent);
+  func.setNoUnwind(true);
+  func.setWillReturn(true);
+  return func;
+}
+
+/// Creates an `llvm.call` to `func` with `args` at `loc`. The callee's calling
+/// convention and its `convergent`, `nounwind` and `willreturn` attributes are
+/// copied onto the call site.
+///
+/// Takes a plain `OpBuilder` so it can be used from conversion patterns (a
+/// `ConversionPatternRewriter` is an `OpBuilder`) as well as outside of them.
+static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
+                                           LLVM::LLVMFuncOp func,
+                                           ValueRange args) {
+  auto call = builder.create<LLVM::CallOp>(loc, func, args);
+  // Keep the call-site attributes in sync with the callee declaration.
+  call.setCConv(func.getCConv());
+  call.setConvergentAttr(func.getConvergentAttr());
+  call.setNoUnwindAttr(func.getNoUnwindAttr());
+  call.setWillReturnAttr(func.getWillReturnAttr());
+  return call;
+}
+
+/// Converts `spirv.ControlBarrier` into a call to the SPIR-V builtin
+/// `__spirv_ControlBarrier(int, int, int)`, Itanium-mangled as
+/// `_Z22__spirv_ControlBarrieriii`. The execution scope, memory scope and
+/// memory semantics are passed as `i32` constants. The builtin is declared in
+/// the nearest enclosing symbol table if it is not already present.
+class ControlBarrierPattern
+    : public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii";
+    Operation *symbolTable =
+        controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
+    // The builtin declaration needs a symbol table to live in. Fail the match
+    // (before creating any IR) rather than dereferencing a null parent.
+    if (!symbolTable)
+      return rewriter.notifyMatchFailure(
+          controlBarrierOp, "expected op to be nested in a symbol table");
+
+    Type i32 = rewriter.getI32Type();
+    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
+    LLVM::LLVMFuncOp func =
+        lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
+
+    // Arguments, in order: execution scope, memory scope, memory semantics.
+    // Each enum is passed as its integer value.
+    Location loc = controlBarrierOp->getLoc();
+    Value execution = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
+    Value memory = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
+    Value semantics = rewriter.create<LLVM::ConstantOp>(
+        loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
+
+    auto call = createSPIRVBuiltinCall(loc, rewriter, func,
+                                       {execution, memory, semantics});
+
+    rewriter.replaceOp(controlBarrierOp, call);
+    return success();
+  }
+};
+
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
@@ -1648,7 +1713,10 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
// Return ops
- ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
+ ReturnPattern, ReturnValuePattern,
+
+ // Barrier ops
+ ControlBarrierPattern>(patterns.getContext(), typeConverter);
patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..d53afeeea15d10
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.ControlBarrierOp
+//===----------------------------------------------------------------------===//
+
+// CHECK: llvm.func spir_funccc @_Z22__spirv_ControlBarrieriii(i32, i32, i32) attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL: @control_barrier
+spirv.func @control_barrier() "None" {
+  // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+  spirv.ControlBarrier <Workgroup>, <Workgroup>, <CrossWorkgroupMemory|WorkgroupMemory>
+
+  // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+  spirv.ControlBarrier <Workgroup>, <Workgroup>, <WorkgroupMemory>
+
+  // Distinct execution and memory scopes, so swapped operands are caught:
+  // Workgroup = 2, Device = 1, Acquire|UniformMemory = 0x2|0x40 = 66.
+  // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(66 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+  spirv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
+  spirv.Return
+}
>From 36b51fe57e76f213d06df8f420822372337ff632 Mon Sep 17 00:00:00 2001
From: Finlay Marno <finlay.marno at codeplay.com>
Date: Fri, 11 Oct 2024 16:28:37 +0100
Subject: [PATCH 2/2] fixup! [mlir] Add spirv-to-llvm translation for
OpControlBarrier
---
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 28 +++++++++----------
1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 50d090ddad901f..219d7a93b855c7 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1030,21 +1030,21 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
Type resultType) {
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(symbolTable, name));
- if (!func) {
- OpBuilder b(symbolTable->getRegion(0));
- func = b.create<LLVM::LLVMFuncOp>(
- symbolTable->getLoc(), name,
- LLVM::LLVMFunctionType::get(resultType, paramTypes));
- func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
- func.setConvergent(true);
- func.setNoUnwind(true);
- func.setWillReturn(true);
- }
+ if (func)
+ return func;
+
+ OpBuilder b(symbolTable->getRegion(0));
+ func = b.create<LLVM::LLVMFuncOp>(
+ symbolTable->getLoc(), name,
+ LLVM::LLVMFunctionType::get(resultType, paramTypes));
+ func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+ func.setConvergent(true);
+ func.setNoUnwind(true);
+ func.setWillReturn(true);
return func;
}
-static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
- ConversionPatternRewriter &rewriter,
+static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &rewriter,
LLVM::LLVMFuncOp func,
ValueRange args) {
auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
@@ -1063,7 +1063,7 @@ class ControlBarrierPattern
LogicalResult
matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- constexpr StringRef funcName = "_Z22__spirv_ControlBarrieriii";
+ constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii";
Operation *symbolTable =
controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
@@ -1073,7 +1073,7 @@ class ControlBarrierPattern
LLVM::LLVMFuncOp func =
lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
- auto loc = controlBarrierOp->getLoc();
+ Location loc = controlBarrierOp->getLoc();
Value execution = rewriter.create<LLVM::ConstantOp>(
loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
Value memory = rewriter.create<LLVM::ConstantOp>(
More information about the Mlir-commits
mailing list