[Mlir-commits] [mlir] 5e54319 - [mlir][spirv] Support spec constants as GlobalVar initializer (#75660)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 5 16:27:34 PST 2024


Author: Dimple Prajapati
Date: 2024-01-05T16:27:30-08:00
New Revision: 5e54319b7be3e8aa035836098e0a9defc0a41c3a

URL: https://github.com/llvm/llvm-project/commit/5e54319b7be3e8aa035836098e0a9defc0a41c3a
DIFF: https://github.com/llvm/llvm-project/commit/5e54319b7be3e8aa035836098e0a9defc0a41c3a.diff

LOG: [mlir][spirv] Support spec constants as GlobalVar initializer (#75660)

Changes include:

- SPIR-V serialization and deserialization need to handle the case where a
GlobalVariableOp initializer is defined by a spirv.SpecConstant or
spirv.SpecConstantComposite op. Previously, even though the verifier
allowed a SpecConstant initializer, (de)serialization only looked in the
GlobalVariable map to resolve the initializer symbol reference. This change
fixes that and extends the support to SpecConstantComposite as an initializer.
- Adds tests to make sure a GlobalVariable can be initialized using
specialization constants.

---------

Co-authored-by: Lei Zhang <antiagainst at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
    mlir/test/Target/SPIRV/global-variable.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..5343a12132a912 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1162,10 +1162,11 @@ LogicalResult spirv::GlobalVariableOp::verify() {
     // TODO: Currently only variable initialization with specialization
     // constants and other variables is supported. They could be normal
     // constants in the module scope as well.
-    if (!initOp ||
-        !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
+    if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
+                        spirv::SpecConstantCompositeOp>(initOp)) {
       return emitOpError("initializer must be result of a "
-                         "spirv.SpecConstant or spirv.GlobalVariable op");
+                         "spirv.SpecConstant or spirv.GlobalVariable or "
+                         "spirv.SpecConstantCompositeOp op");
     }
   }
 

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..00645d2c45519e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -637,14 +637,22 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
 
   // Initializer.
   FlatSymbolRefAttr initializer = nullptr;
+
   if (wordIndex < operands.size()) {
-    auto initializerOp = getGlobalVariable(operands[wordIndex]);
-    if (!initializerOp) {
+    Operation *op = nullptr;
+
+    if (auto initOp = getGlobalVariable(operands[wordIndex]))
+      op = initOp;
+    else if (auto initOp = getSpecConstant(operands[wordIndex]))
+      op = initOp;
+    else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
+      op = initOp;
+    else
       return emitError(unknownLoc, "unknown <id> ")
              << operands[wordIndex] << "used as initializer";
-    }
+
+    initializer = SymbolRefAttr::get(op);
     wordIndex++;
-    initializer = SymbolRefAttr::get(initializerOp.getOperation());
   }
   if (wordIndex != operands.size()) {
     return emitError(unknownLoc,

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..7bfcca5b4dcdca 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -383,20 +383,31 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
   operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
 
   // Encode initialization.
-  if (auto initializer = varOp.getInitializer()) {
-    auto initializerID = getVariableID(*initializer);
-    if (!initializerID) {
+  StringRef initAttrName = varOp.getInitializerAttrName().getValue();
+  if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
+    uint32_t initializerID = 0;
+    auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
+    Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
+        varOp->getParentOp(), initRef.getAttr());
+
+    // Check if initializer is GlobalVariable or SpecConstant* cases.
+    if (isa<spirv::GlobalVariableOp>(initOp))
+      initializerID = getVariableID(*initSymbolName);
+    else
+      initializerID = getSpecConstID(*initSymbolName);
+
+    if (!initializerID)
       return emitError(varOp.getLoc(),
                        "invalid usage of undefined variable as initializer");
-    }
+
     operands.push_back(initializerID);
-    elidedAttrs.push_back("initializer");
+    elidedAttrs.push_back(initAttrName);
   }
 
   if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
     return failure();
   encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
-  elidedAttrs.push_back("initializer");
+  elidedAttrs.push_back(initAttrName);
 
   // Encode decorations.
   for (auto attr : varOp->getAttrs()) {

diff  --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..77b605050e1442 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -349,6 +349,19 @@ spirv.SpecConstant @sc = 4.0 : f32
 // CHECK: spirv.GlobalVariable @var initializer(@sc)
 spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<f32, Private>
 
+
+// -----
+// Allow SpecConstantComposite as initializer
+  spirv.module Logical GLSL450 {
+  spirv.SpecConstant @sc1 = 1 : i8
+  spirv.SpecConstant @sc2 = 2 : i8
+  spirv.SpecConstant @sc3 = 3 : i8
+  spirv.SpecConstantComposite @scc (@sc1, @sc2, @sc3) : !spirv.array<3 x i8>
+
+  // CHECK: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
+  spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
+}
+
 // -----
 
 spirv.module Logical GLSL450 {
@@ -410,7 +423,7 @@ spirv.module Logical GLSL450 {
 // -----
 
 spirv.module Logical GLSL450 {
-  // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable op}}
+  // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable or spirv.SpecConstantCompositeOp op}}
   spirv.GlobalVariable @var0 initializer(@var1) : !spirv.ptr<f32, Private>
 }
 

diff  --git a/mlir/test/Target/SPIRV/global-variable.mlir b/mlir/test/Target/SPIRV/global-variable.mlir
index 66d0782c205c7d..f22d2a9b3d14d9 100644
--- a/mlir/test/Target/SPIRV/global-variable.mlir
+++ b/mlir/test/Target/SPIRV/global-variable.mlir
@@ -23,6 +23,30 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 
 // -----
 
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  // CHECK:         spirv.SpecConstant @sc = 1 : i8
+  // CHECK-NEXT:    spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
+  spirv.SpecConstant @sc = 1 : i8
+
+  spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  // CHECK:         spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
+  // CHECK-NEXT:    spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
+  spirv.SpecConstant @sc0 = 1 : i8
+  spirv.SpecConstant @sc1 = 2 : i8
+  spirv.SpecConstant @sc2 = 3 : i8
+
+  spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
+
+  spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
+}
+
+// -----
+
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   spirv.GlobalVariable @globalInvocationID built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
   spirv.func @foo() "None" {


        


More information about the Mlir-commits mailing list