[Mlir-commits] [mlir] 1e4cfe5 - [mlir][SPIRVToLLVM] Propagate location attribute from spv.GlobalVariable to llvm.mlir.global

Weiwei Li llvmlistbot at llvm.org
Mon Oct 4 09:15:22 PDT 2021


Author: Weiwei Li
Date: 2021-10-05T00:09:09+08:00
New Revision: 1e4cfe5e4f21a0667cab4baecb65e3ea791ee695

URL: https://github.com/llvm/llvm-project/commit/1e4cfe5e4f21a0667cab4baecb65e3ea791ee695
DIFF: https://github.com/llvm/llvm-project/commit/1e4cfe5e4f21a0667cab4baecb65e3ea791ee695.diff

LOG: [mlir][SPIRVToLLVM] Propagate location attribute from spv.GlobalVariable to llvm.mlir.global

This patch mainly propagates the location attribute from spv.GlobalVariable to llvm.mlir.global.

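It also contains three small changes:

1. Remove the restriction on the UniformConstant storage class in SPIRVToLLVM.cpp; such globals are now lowered to a constant llvm.mlir.global.
2. Remove the error on the RelaxedPrecision decoration when deserializing SPIR-V in Deserializer.cpp by handling it as a single-target decoration; Serializer.cpp handles it the same way.
3. In SPIRVOps.cpp, let extractValueFromConstOp read signed integer constants (via getSInt) as well as signless ones.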

Co-authored-by: Alan Liu <alanliu.yf at gmail.com> and Xinyi Liu <xyliuhelen at gmail.com>

Reviewed by: antiagainst

Differential revision: https://reviews.llvm.org/D110207

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
    mlir/test/Target/SPIRV/decorations.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 23d981a589010..a057bf784693e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -379,12 +379,25 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
   let arguments = (ins
     TypeAttr:$type,
     StrAttr:$sym_name,
-    OptionalAttr<FlatSymbolRefAttr>:$initializer
+    OptionalAttr<FlatSymbolRefAttr>:$initializer,
+    OptionalAttr<I32Attr>:$location,
+    OptionalAttr<I32Attr>:$binding,
+    OptionalAttr<I32Attr>:$descriptorSet,
+    OptionalAttr<StrAttr>:$builtin
   );
 
   let results = (outs);
 
   let builders = [
+    OpBuilder<(ins "TypeAttr":$type,
+                   "StringAttr":$sym_name,
+                   CArg<"FlatSymbolRefAttr", "nullptr">:$initializer),
+    [{
+      $_state.addAttribute("type", type);
+      $_state.addAttribute(sym_nameAttrName($_state.name), sym_name);
+      if (initializer)
+        $_state.addAttribute(initializerAttrName($_state.name), initializer);
+    }]>,
     OpBuilder<(ins "TypeAttr":$type, "ArrayRef<NamedAttribute>":$namedAttrs),
     [{
       $_state.addAttribute("type", type);
@@ -393,7 +406,16 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
     OpBuilder<(ins "Type":$type, "StringRef":$name,
       "unsigned":$descriptorSet, "unsigned":$binding)>,
     OpBuilder<(ins "Type":$type, "StringRef":$name,
-      "spirv::BuiltIn":$builtin)>
+      "spirv::BuiltIn":$builtin)>,
+    OpBuilder<(ins "Type":$type,
+                   "StringRef":$sym_name,
+                    CArg<"FlatSymbolRefAttr", "{}">:$initializer),
+    [{
+      $_state.addAttribute("type", TypeAttr::get(type));
+      $_state.addAttribute(sym_nameAttrName($_state.name), $_builder.getStringAttr(sym_name));
+      if (initializer)
+        $_state.addAttribute(initializerAttrName($_state.name), initializer);
+    }]>
   ];
 
   let hasOpcode = 0;

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 348d8ad714f0d..abec151145850 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -733,17 +733,22 @@ class GlobalVariablePattern
     // required by SPIR-V runner.
     // This is okay because multiple invocations are not supported yet.
     auto storageClass = srcType.getStorageClass();
-    if (storageClass != spirv::StorageClass::Input &&
-        storageClass != spirv::StorageClass::Private &&
-        storageClass != spirv::StorageClass::Output &&
-        storageClass != spirv::StorageClass::StorageBuffer) {
+    switch (storageClass) {
+    case spirv::StorageClass::Input:
+    case spirv::StorageClass::Private:
+    case spirv::StorageClass::Output:
+    case spirv::StorageClass::StorageBuffer:
+    case spirv::StorageClass::UniformConstant:
+      break;
+    default:
       return failure();
     }
 
     // LLVM dialect spec: "If the global value is a constant, storing into it is
-    // not allowed.". This corresponds to SPIR-V 'Input' storage class that is
-    // read-only.
-    bool isConstant = storageClass == spirv::StorageClass::Input;
+    // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
+    // storage class that is read-only.
+    bool isConstant = (storageClass == spirv::StorageClass::Input) ||
+                      (storageClass == spirv::StorageClass::UniformConstant);
     // SPIR-V spec: "By default, functions and global variables are private to a
     // module and cannot be accessed by other modules. However, a module may be
     // written to export or import functions and global (module scope)
@@ -752,9 +757,14 @@ class GlobalVariablePattern
     auto linkage = storageClass == spirv::StorageClass::Private
                        ? LLVM::Linkage::Private
                        : LLVM::Linkage::External;
-    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
+    auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
         /*alignment=*/0);
+
+    // Attach location attribute if applicable
+    if (op.locationAttr())
+      newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
+
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 898d0637c6225..f711c183f122a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -92,7 +92,12 @@ static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
   if (!integerValueAttr) {
     return failure();
   }
-  value = integerValueAttr.getInt();
+
+  if (integerValueAttr.getType().isSignlessInteger())
+    value = integerValueAttr.getInt();
+  else
+    value = integerValueAttr.getSInt();
+
   return success();
 }
 
@@ -2066,8 +2071,7 @@ Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
                                     Type type, StringRef name,
                                     unsigned descriptorSet, unsigned binding) {
-  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
-        nullptr);
+  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
   state.addAttribute(
       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
       builder.getI32IntegerAttr(descriptorSet));
@@ -2079,8 +2083,7 @@ void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
                                     Type type, StringRef name,
                                     spirv::BuiltIn builtin) {
-  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
-        nullptr);
+  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
   state.addAttribute(
       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
       builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 1fbfbf09c066e..b74b4873442bd 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -262,6 +262,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   case spirv::Decoration::NonWritable:
   case spirv::Decoration::NoPerspective:
   case spirv::Decoration::Restrict:
+  case spirv::Decoration::RelaxedPrecision:
     if (words.size() != 2) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 7d83f6a39254e..24eea6c317116 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -241,6 +241,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
   case spirv::Decoration::NonWritable:
   case spirv::Decoration::NoPerspective:
   case spirv::Decoration::Restrict:
+  case spirv::Decoration::RelaxedPrecision:
     // For unit attributes, the args list has no values so we do nothing
     if (auto unitAttr = attr.second.dyn_cast<UnitAttr>())
       break;

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
index e2962dfd57a6a..effc9befb2889 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
@@ -67,6 +67,26 @@ spv.module @name Logical GLSL450 {
   }
 }
 
+spv.module Logical GLSL450 {
+  // CHECK: llvm.mlir.global external @bar() {location = 1 : i32} : i32
+  // CHECK-LABEL: @foo
+  spv.GlobalVariable @bar {location = 1 : i32} : !spv.ptr<i32, Output>
+  spv.func @foo() "None" {
+    %0 = spv.mlir.addressof @bar : !spv.ptr<i32, Output>
+    spv.Return
+  }
+}
+
+spv.module Logical GLSL450 {
+  // CHECK: llvm.mlir.global external constant @bar() {location = 3 : i32} : f32
+  // CHECK-LABEL: @foo
+  spv.GlobalVariable @bar {descriptor_set = 0 : i32, location = 3 : i32} : !spv.ptr<f32, UniformConstant>
+  spv.func @foo() "None" {
+    %0 = spv.mlir.addressof @bar : !spv.ptr<f32, UniformConstant>
+    spv.Return
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Load
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 689d527c13c26..02109efaafd3e 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -49,3 +49,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   spv.GlobalVariable @var bind(0, 0) {restrict} : !spv.ptr<!spv.struct<(!spv.array<4xf32, stride=4>[0])>, StorageBuffer>
 }
 
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  // CHECK: relaxed_precision
+  spv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spv.ptr<vector<4xf32>, Output>
+}
+


        


More information about the Mlir-commits mailing list