[Mlir-commits] [mlir] e850558 - [MLIR][SPIRVToLLVM] Added a hook for descriptor set / binding encoding
George Mitenkov
llvmlistbot at llvm.org
Wed Aug 26 22:50:30 PDT 2020
Author: George Mitenkov
Date: 2020-08-27T08:27:42+03:00
New Revision: e850558cdc673edc82d13e602d1c819141ce9b3f
URL: https://github.com/llvm/llvm-project/commit/e850558cdc673edc82d13e602d1c819141ce9b3f
DIFF: https://github.com/llvm/llvm-project/commit/e850558cdc673edc82d13e602d1c819141ce9b3f.diff
LOG: [MLIR][SPIRVToLLVM] Added a hook for descriptor set / binding encoding
This patch introduces a hook to encode the descriptor set
and binding number into `spv.globalVariable`'s symbolic name. This
makes it possible to preserve this information while, at the same time,
legalizing the global variable for the conversion to the LLVM dialect.
This is required for `mlir-spirv-cpu-runner` to convert kernel
arguments into LLVM.
Also, a couple of nits were addressed:
- removed unused comment
- changed to a capital letter in the comment
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D86515
Added:
Modified:
mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
index 4092083f2607..dd9fb68fa80b 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -32,6 +32,10 @@ class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
LLVMTypeConverter &typeConverter;
};
+/// Encodes the descriptor set and binding numbers of `spv.globalVariable` ops
+/// into their symbol names. Only variables that carry both a `descriptor_set`
+/// and a `binding` attribute are renamed; uses of each renamed variable within
+/// its enclosing `spv.module` are updated and the two attributes are removed.
+void encodeBindAttribute(ModuleOp module);
+
/// Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 9c2ba26274e9..8dcdaab2d7ec 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -23,6 +23,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "spirv-to-llvm-pattern"
@@ -1332,8 +1333,6 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
// TODO: Support EntryPoint/ExecutionMode properly.
ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
- // Function Call op
-
// GLSL extended instruction set ops
DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
@@ -1386,3 +1385,42 @@ void mlir::populateSPIRVToLLVMModuleConversionPatterns(
patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
context, typeConverter);
}
+
+//===----------------------------------------------------------------------===//
+// Pre-conversion hooks
+//===----------------------------------------------------------------------===//
+
+/// Hook for descriptor set and binding number encoding.
+///
+/// Names of the attributes under which `spv.globalVariable` stores its
+/// descriptor set and binding numbers.
+static constexpr StringRef kBinding = "binding";
+static constexpr StringRef kDescriptorSet = "descriptor_set";
+
+/// For every `spv.globalVariable` inside each `spv.module` directly nested in
+/// `module` that carries both `descriptor_set` and `binding` integer
+/// attributes, renames the variable to
+/// `[<spv.module name>_]<symbol>_descriptor_set<set>_binding<binding>`,
+/// rewrites its uses inside the enclosing `spv.module`, and removes the two
+/// attributes so the variable can be legalized for the LLVM dialect.
+void mlir::encodeBindAttribute(ModuleOp module) {
+  for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
+    spvModule.walk([&](spirv::GlobalVariableOp op) {
+      IntegerAttr descriptorSet =
+          op.getAttrOfType<IntegerAttr>(kDescriptorSet);
+      IntegerAttr binding = op.getAttrOfType<IntegerAttr>(kBinding);
+      // Only variables that have both numbers get encoded.
+      if (!descriptorSet || !binding)
+        return;
+
+      // Prefix the variable's symbol with the SPIR-V module name, if any.
+      std::string moduleAndName = op.sym_name().str();
+      if (Optional<StringRef> moduleName = spvModule.getName())
+        moduleAndName = moduleName->str() + "_" + moduleAndName;
+
+      // formatv renders integer arguments directly, so no std::to_string
+      // temporaries are needed.
+      std::string name =
+          llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
+                        descriptorSet.getInt(), binding.getInt());
+
+      // Replace all symbol uses and set the new symbol name. Finally, remove
+      // descriptor set and binding attributes.
+      if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
+        op.emitError("unable to replace all symbol uses for ") << name;
+      SymbolTable::setSymbolName(op, name);
+      op.removeAttr(kDescriptorSet);
+      op.removeAttr(kBinding);
+    });
+  }
+}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
index 73d64b01df5d..329989b9e795 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -33,6 +33,9 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
ModuleOp module = getOperation();
LLVMTypeConverter converter(&getContext());
+ // Encode global variable's descriptor set and binding if they exist.
+ encodeBindAttribute(module);
+
OwningRewritePatternList patterns;
populateSPIRVToLLVMTypeConversion(converter);
@@ -45,7 +48,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
target.addIllegalDialect<spirv::SPIRVDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
- // set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
+ // Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
// conversion.
target.addLegalOp<ModuleOp>();
target.addLegalOp<ModuleTerminatorOp>();
diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
index 68d59393cfb7..c5e498a06f56 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
@@ -37,7 +37,7 @@ spv.module Logical GLSL450 {
spv.module Logical GLSL450 {
// CHECK: llvm.mlir.global private @struct() : !llvm.struct<packed (float, array<10 x float>)>
// CHECK-LABEL: @func
- // CHECK: llvm.mlir.addressof @struct : !llvm.ptr<struct<packed (float, array<10 x float>)>>
+ // CHECK: llvm.mlir.addressof @struct : !llvm.ptr<struct<packed (float, array<10 x float>)>>
spv.globalVariable @struct : !spv.ptr<!spv.struct<f32, !spv.array<10xf32>>, Private>
spv.func @func() "None" {
%0 = spv._address_of @struct : !spv.ptr<!spv.struct<f32, !spv.array<10xf32>>, Private>
@@ -45,6 +45,28 @@ spv.module Logical GLSL450 {
}
}
+spv.module Logical GLSL450 {
+ // CHECK: llvm.mlir.global external @bar_descriptor_set0_binding0() : !llvm.i32
+ // CHECK-LABEL: @foo
+ // CHECK: llvm.mlir.addressof @bar_descriptor_set0_binding0 : !llvm.ptr<i32>
+ spv.globalVariable @bar bind(0, 0) : !spv.ptr<i32, StorageBuffer>
+ spv.func @foo() "None" {
+ %0 = spv._address_of @bar : !spv.ptr<i32, StorageBuffer>
+ spv.Return
+ }
+}
+
+spv.module @name Logical GLSL450 {
+ // CHECK: llvm.mlir.global external @name_bar_descriptor_set0_binding0() : !llvm.i32
+ // CHECK-LABEL: @foo
+ // CHECK: llvm.mlir.addressof @name_bar_descriptor_set0_binding0 : !llvm.ptr<i32>
+ spv.globalVariable @bar bind(0, 0) : !spv.ptr<i32, StorageBuffer>
+ spv.func @foo() "None" {
+ %0 = spv._address_of @bar : !spv.ptr<i32, StorageBuffer>
+ spv.Return
+ }
+}
+
//===----------------------------------------------------------------------===//
// spv.Load
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list