[Mlir-commits] [mlir] e850558 - [MLIR][SPIRVToLLVM] Added a hook for descriptor set / binding encoding

George Mitenkov llvmlistbot at llvm.org
Wed Aug 26 22:50:30 PDT 2020


Author: George Mitenkov
Date: 2020-08-27T08:27:42+03:00
New Revision: e850558cdc673edc82d13e602d1c819141ce9b3f

URL: https://github.com/llvm/llvm-project/commit/e850558cdc673edc82d13e602d1c819141ce9b3f
DIFF: https://github.com/llvm/llvm-project/commit/e850558cdc673edc82d13e602d1c819141ce9b3f.diff

LOG: [MLIR][SPIRVToLLVM] Added a hook for descriptor set / binding encoding

This patch introduces a hook that encodes the descriptor set
and binding number into a `spv.globalVariable`'s symbolic name. This
preserves that information while also legalizing the global
variable for the conversion to the LLVM dialect.

This is required for `mlir-spirv-cpu-runner` to convert kernel
arguments into the LLVM dialect.

Also, a couple of nits are addressed:
- removed an unused comment
- capitalized the first letter of a comment

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86515

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
    mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
index 4092083f2607..dd9fb68fa80b 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -32,6 +32,10 @@ class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
   LLVMTypeConverter &typeConverter;
 };
 
+/// Encodes each spv.globalVariable's descriptor set and binding into its
+/// symbol name when both are present, then drops those two attributes.
+void encodeBindAttribute(ModuleOp module);
+
 /// Populates type conversions with additional SPIR-V types.
 void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
 

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 9c2ba26274e9..8dcdaab2d7ec 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 
 #define DEBUG_TYPE "spirv-to-llvm-pattern"
 
@@ -1332,8 +1333,6 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       // TODO: Support EntryPoint/ExecutionMode properly.
       ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
 
-      // Function Call op
-
       // GLSL extended instruction set ops
       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
       DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
@@ -1386,3 +1385,42 @@ void mlir::populateSPIRVToLLVMModuleConversionPatterns(
   patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
       context, typeConverter);
 }
+
+//===----------------------------------------------------------------------===//
+// Pre-conversion hooks
+//===----------------------------------------------------------------------===//
+
+/// Attribute names of the `bind(set, binding)` pair on spv.globalVariable.
+static constexpr StringRef kBinding = "binding";
+static constexpr StringRef kDescriptorSet = "descriptor_set";
+
+/// Hook for descriptor set and binding number encoding. A variable is left
+/// untouched (and an error emitted) if its new name is already taken or its
+/// uses cannot all be updated, so no dangling symbol references are created.
+void mlir::encodeBindAttribute(ModuleOp module) {
+  for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
+    spvModule.walk([&](spirv::GlobalVariableOp op) {
+      IntegerAttr descriptorSet = op.getAttrOfType<IntegerAttr>(kDescriptorSet);
+      IntegerAttr binding = op.getAttrOfType<IntegerAttr>(kBinding);
+      if (!descriptorSet || !binding)
+        return;
+      // Encode the numbers into the name, prefixed by the module name if any.
+      std::string prefix = spvModule.getName().hasValue()
+                               ? spvModule.getName().getValue().str() + "_"
+                               : std::string();
+      std::string name = llvm::formatv(
+          "{0}{1}_descriptor_set{2}_binding{3}", prefix, op.sym_name(),
+          descriptorSet.getInt(), binding.getInt());
+      if (SymbolTable::lookupSymbolIn(spvModule, name)) {
+        op.emitError("symbol name already in use: ") << name;
+        return;
+      }
+      if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule))) {
+        op.emitError("unable to replace all symbol uses for ") << name;
+        return;
+      }
+      SymbolTable::setSymbolName(op, name);
+      op.removeAttr(kDescriptorSet);
+      op.removeAttr(kBinding);
+    });
+  }
+}

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
index 73d64b01df5d..329989b9e795 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -33,6 +33,9 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
   ModuleOp module = getOperation();
   LLVMTypeConverter converter(&getContext());
 
+  // Encode global variable's descriptor set and binding if they exist.
+  encodeBindAttribute(module);
+
   OwningRewritePatternList patterns;
 
   populateSPIRVToLLVMTypeConversion(converter);
@@ -45,7 +48,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
   target.addIllegalDialect<spirv::SPIRVDialect>();
   target.addLegalDialect<LLVM::LLVMDialect>();
 
-  // set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
+  // Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
   // conversion.
   target.addLegalOp<ModuleOp>();
   target.addLegalOp<ModuleTerminatorOp>();

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
index 68d59393cfb7..c5e498a06f56 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
@@ -37,7 +37,7 @@ spv.module Logical GLSL450 {
 spv.module Logical GLSL450 {
   //       CHECK: llvm.mlir.global private @struct() : !llvm.struct<packed (float, array<10 x float>)>
   // CHECK-LABEL: @func
-  //       CHECK: llvm.mlir.addressof @struct : !llvm.ptr<struct<packed (float, array<10 x float>)>>
+  //       CHECK:   llvm.mlir.addressof @struct : !llvm.ptr<struct<packed (float, array<10 x float>)>>
   spv.globalVariable @struct : !spv.ptr<!spv.struct<f32, !spv.array<10xf32>>, Private>
   spv.func @func() "None" {
     %0 = spv._address_of @struct : !spv.ptr<!spv.struct<f32, !spv.array<10xf32>>, Private>
@@ -45,6 +45,28 @@ spv.module Logical GLSL450 {
   }
 }
 
+spv.module Logical GLSL450 {
+  //       CHECK: llvm.mlir.global external @bar_descriptor_set0_binding0() : !llvm.i32
+  // CHECK-LABEL: @foo
+  //       CHECK:   llvm.mlir.addressof @bar_descriptor_set0_binding0 : !llvm.ptr<i32>
+  spv.globalVariable @bar bind(0, 0) : !spv.ptr<i32, StorageBuffer>
+  spv.func @foo() "None" {
+    %0 = spv._address_of @bar : !spv.ptr<i32, StorageBuffer>
+    spv.Return
+  }
+}
+
+spv.module @name Logical GLSL450 {
+  //       CHECK: llvm.mlir.global external @name_bar_descriptor_set0_binding0() : !llvm.i32
+  // CHECK-LABEL: @foo
+  //       CHECK:   llvm.mlir.addressof @name_bar_descriptor_set0_binding0 : !llvm.ptr<i32>
+  spv.globalVariable @bar bind(0, 0) : !spv.ptr<i32, StorageBuffer>
+  spv.func @foo() "None" {
+    %0 = spv._address_of @bar : !spv.ptr<i32, StorageBuffer>
+    spv.Return
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Load
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list