[clang] [mlir] [llvm] [clang-tools-extra] [mlir][spirv] Fix spirv dialect to support Specialization constants as GlobalVar initializer (PR #75660)
Lei Zhang via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 5 16:13:58 PST 2024
https://github.com/antiagainst updated https://github.com/llvm/llvm-project/pull/75660
>From fd8c637f2b146ffce657307841f84a4123e351af Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Wed, 13 Dec 2023 22:33:23 +0000
Subject: [PATCH 1/6] [mlir][spirv] Fix spirv dialect to support Specialization
constants in GlobalVar initializer
Changes include:
- SPIR-V serialization and deserialization must handle the case where a
GlobalVariableOp initializer is defined by a spirv.SpecConstant or
spirv.SpecConstantComposite op. Previously, although the verifier already
accepted a SpecConstant initializer, (de)serialization only looked up the
initializer's symbol reference in the GlobalVariable map. This change fixes
that lookup and extends the support to SpecConstantComposite initializers.
- Adds tests to make sure a GlobalVariable can be initialized using specialization constants.
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 5 ++--
.../SPIRV/Deserialization/Deserializer.cpp | 19 ++++++++++-----
.../SPIRV/Serialization/SerializeOps.cpp | 22 ++++++++++++-----
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 15 +++++++++++-
mlir/test/Target/SPIRV/global-variable.mlir | 24 +++++++++++++++++++
5 files changed, 70 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..66f1d6b2e12206 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1163,9 +1163,10 @@ LogicalResult spirv::GlobalVariableOp::verify() {
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
if (!initOp ||
- !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
+ !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
return emitOpError("initializer must be result of a "
- "spirv.SpecConstant or spirv.GlobalVariable op");
+ "spirv.SpecConstant or spirv.GlobalVariable or "
+ "spirv.SpecConstantCompositeOp op");
}
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..ccea690a7c3ded 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -637,14 +637,21 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
// Initializer.
FlatSymbolRefAttr initializer = nullptr;
+
if (wordIndex < operands.size()) {
- auto initializerOp = getGlobalVariable(operands[wordIndex]);
- if (!initializerOp) {
- return emitError(unknownLoc, "unknown <id> ")
- << operands[wordIndex] << "used as initializer";
- }
+ Operation *op = nullptr;
+
+ if((op = getGlobalVariable(operands[wordIndex])))
+ initializer = SymbolRefAttr::get((dyn_cast<spirv::GlobalVariableOp>(op)).getOperation());
+ else if ((op = getSpecConstant(operands[wordIndex])))
+ initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantOp>(op)).getOperation());
+ else if((op = getSpecConstantComposite(operands[wordIndex])))
+ initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
+ else
+ return emitError(unknownLoc,
+ "Unknown op used as initializer");
+
wordIndex++;
- initializer = SymbolRefAttr::get(initializerOp.getOperation());
}
if (wordIndex != operands.size()) {
return emitError(unknownLoc,
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..bd1fc7a84fbd6a 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -383,12 +383,22 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
// Encode initialization.
- if (auto initializer = varOp.getInitializer()) {
- auto initializerID = getVariableID(*initializer);
- if (!initializerID) {
- return emitError(varOp.getLoc(),
- "invalid usage of undefined variable as initializer");
- }
+ if (auto initializerName = varOp.getInitializer()) {
+
+ uint32_t initializerID = 0;
+ auto init = varOp->getAttrOfType<FlatSymbolRefAttr>("initializer");
+ Operation *initOp = SymbolTable::lookupNearestSymbolFrom(varOp->getParentOp(), init.getAttr());
+
+ // Check if initializer is GlobalVariable or SpecConstant/SpecConstantComposite
+ if(isa<spirv::GlobalVariableOp>(initOp))
+ initializerID = getVariableID(*initializerName);
+ else
+ initializerID = getSpecConstID(*initializerName);
+
+ if (!initializerID)
+ return emitError(varOp.getLoc(),
+ "invalid usage of undefined variable as initializer");
+
operands.push_back(initializerID);
elidedAttrs.push_back("initializer");
}
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..77b605050e1442 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -349,6 +349,19 @@ spirv.SpecConstant @sc = 4.0 : f32
// CHECK: spirv.GlobalVariable @var initializer(@sc)
spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<f32, Private>
+
+// -----
+// Allow SpecConstantComposite as initializer
+ spirv.module Logical GLSL450 {
+ spirv.SpecConstant @sc1 = 1 : i8
+ spirv.SpecConstant @sc2 = 2 : i8
+ spirv.SpecConstant @sc3 = 3 : i8
+ spirv.SpecConstantComposite @scc (@sc1, @sc2, @sc3) : !spirv.array<3 x i8>
+
+ // CHECK: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
+ spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
+}
+
// -----
spirv.module Logical GLSL450 {
@@ -410,7 +423,7 @@ spirv.module Logical GLSL450 {
// -----
spirv.module Logical GLSL450 {
- // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable op}}
+ // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable or spirv.SpecConstantCompositeOp op}}
spirv.GlobalVariable @var0 initializer(@var1) : !spirv.ptr<f32, Private>
}
diff --git a/mlir/test/Target/SPIRV/global-variable.mlir b/mlir/test/Target/SPIRV/global-variable.mlir
index 66d0782c205c7d..f22d2a9b3d14d9 100644
--- a/mlir/test/Target/SPIRV/global-variable.mlir
+++ b/mlir/test/Target/SPIRV/global-variable.mlir
@@ -23,6 +23,30 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+ // CHECK: spirv.SpecConstant @sc = 1 : i8
+ // CHECK-NEXT: spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
+ spirv.SpecConstant @sc = 1 : i8
+
+ spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+ // CHECK: spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
+ // CHECK-NEXT: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
+ spirv.SpecConstant @sc0 = 1 : i8
+ spirv.SpecConstant @sc1 = 2 : i8
+ spirv.SpecConstant @sc2 = 3 : i8
+
+ spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
+
+ spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
+}
+
+// -----
+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.GlobalVariable @globalInvocationID built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
spirv.func @foo() "None" {
>From 8c4d24bffdb51e27815550346bb580286d44c151 Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Fri, 15 Dec 2023 21:52:39 +0000
Subject: [PATCH 2/6] Use original error message in case of unknown initializer
---
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index ccea690a7c3ded..a66344d6008843 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -648,8 +648,8 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
else if((op = getSpecConstantComposite(operands[wordIndex])))
initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
else
- return emitError(unknownLoc,
- "Unknown op used as initializer");
+ return emitError(unknownLoc, "unknown <id> ")
+ << operands[wordIndex] << "used as initializer";
wordIndex++;
}
>From f154e41d80f3d3b4894fe1f1b90c687c5c3d91cc Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Fri, 15 Dec 2023 22:13:27 +0000
Subject: [PATCH 3/6] fix coding standard errors
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 4 ++--
.../SPIRV/Deserialization/Deserializer.cpp | 17 ++++++++++-------
.../Target/SPIRV/Serialization/SerializeOps.cpp | 14 ++++++++------
3 files changed, 20 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 66f1d6b2e12206..5343a12132a912 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1162,8 +1162,8 @@ LogicalResult spirv::GlobalVariableOp::verify() {
// TODO: Currently only variable initialization with specialization
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
- if (!initOp ||
- !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
+ if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
+ spirv::SpecConstantCompositeOp>(initOp)) {
return emitOpError("initializer must be result of a "
"spirv.SpecConstant or spirv.GlobalVariable or "
"spirv.SpecConstantCompositeOp op");
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index a66344d6008843..98d867facae8c9 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -637,16 +637,19 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
// Initializer.
FlatSymbolRefAttr initializer = nullptr;
-
+
if (wordIndex < operands.size()) {
Operation *op = nullptr;
- if((op = getGlobalVariable(operands[wordIndex])))
- initializer = SymbolRefAttr::get((dyn_cast<spirv::GlobalVariableOp>(op)).getOperation());
- else if ((op = getSpecConstant(operands[wordIndex])))
- initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantOp>(op)).getOperation());
- else if((op = getSpecConstantComposite(operands[wordIndex])))
- initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
+ if ((op = getGlobalVariable(operands[wordIndex])))
+ initializer = SymbolRefAttr::get(
+ (dyn_cast<spirv::GlobalVariableOp>(op)).getOperation());
+ else if ((op = getSpecConstant(operands[wordIndex])))
+ initializer = SymbolRefAttr::get(
+ (dyn_cast<spirv::SpecConstantOp>(op)).getOperation());
+ else if ((op = getSpecConstantComposite(operands[wordIndex])))
+ initializer = SymbolRefAttr::get(
+ (dyn_cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
else
return emitError(unknownLoc, "unknown <id> ")
<< operands[wordIndex] << "used as initializer";
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index bd1fc7a84fbd6a..cbede7d8a39b29 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -387,18 +387,20 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
uint32_t initializerID = 0;
auto init = varOp->getAttrOfType<FlatSymbolRefAttr>("initializer");
- Operation *initOp = SymbolTable::lookupNearestSymbolFrom(varOp->getParentOp(), init.getAttr());
+ Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
+ varOp->getParentOp(), init.getAttr());
- // Check if initializer is GlobalVariable or SpecConstant/SpecConstantComposite
- if(isa<spirv::GlobalVariableOp>(initOp))
+ // Check if initializer is GlobalVariable or
+ // SpecConstant/SpecConstantComposite
+ if (isa<spirv::GlobalVariableOp>(initOp))
initializerID = getVariableID(*initializerName);
else
initializerID = getSpecConstID(*initializerName);
if (!initializerID)
- return emitError(varOp.getLoc(),
- "invalid usage of undefined variable as initializer");
-
+ return emitError(varOp.getLoc(),
+ "invalid usage of undefined variable as initializer");
+
operands.push_back(initializerID);
elidedAttrs.push_back("initializer");
}
>From b70eb93658b50da9bd1a72f4a6dfe44a7f79c8ab Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Mon, 18 Dec 2023 18:45:33 +0000
Subject: [PATCH 4/6] Use cast<> instead of dyn_cast<> for initializer ops
---
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 98d867facae8c9..eb9a5f2b5cd3f0 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -643,13 +643,13 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
if ((op = getGlobalVariable(operands[wordIndex])))
initializer = SymbolRefAttr::get(
- (dyn_cast<spirv::GlobalVariableOp>(op)).getOperation());
+ (cast<spirv::GlobalVariableOp>(op)).getOperation());
else if ((op = getSpecConstant(operands[wordIndex])))
initializer = SymbolRefAttr::get(
- (dyn_cast<spirv::SpecConstantOp>(op)).getOperation());
+ (cast<spirv::SpecConstantOp>(op)).getOperation());
else if ((op = getSpecConstantComposite(operands[wordIndex])))
initializer = SymbolRefAttr::get(
- (dyn_cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
+ (cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
else
return emitError(unknownLoc, "unknown <id> ")
<< operands[wordIndex] << "used as initializer";
>From a3eb78f4d8f6f7f539f17656f7c15dbed7c6428b Mon Sep 17 00:00:00 2001
From: Dimple Prajapati <dimpalben.r.prajapati at intel.com>
Date: Mon, 18 Dec 2023 19:00:50 +0000
Subject: [PATCH 5/6] formatting fix
---
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index eb9a5f2b5cd3f0..2b374d01557452 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -645,8 +645,8 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
initializer = SymbolRefAttr::get(
(cast<spirv::GlobalVariableOp>(op)).getOperation());
else if ((op = getSpecConstant(operands[wordIndex])))
- initializer = SymbolRefAttr::get(
- (cast<spirv::SpecConstantOp>(op)).getOperation());
+ initializer =
+ SymbolRefAttr::get((cast<spirv::SpecConstantOp>(op)).getOperation());
else if ((op = getSpecConstantComposite(operands[wordIndex])))
initializer = SymbolRefAttr::get(
(cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
>From 80c88a6c3527e8519363556cf95d5604962a449c Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Fri, 5 Jan 2024 16:11:05 -0800
Subject: [PATCH 6/6] Light style fix
---
.../SPIRV/Deserialization/Deserializer.cpp | 16 +++++++---------
.../SPIRV/Serialization/SerializeOps.cpp | 19 +++++++++----------
2 files changed, 16 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 2b374d01557452..00645d2c45519e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -641,19 +641,17 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
if (wordIndex < operands.size()) {
Operation *op = nullptr;
- if ((op = getGlobalVariable(operands[wordIndex])))
- initializer = SymbolRefAttr::get(
- (cast<spirv::GlobalVariableOp>(op)).getOperation());
- else if ((op = getSpecConstant(operands[wordIndex])))
- initializer =
- SymbolRefAttr::get((cast<spirv::SpecConstantOp>(op)).getOperation());
- else if ((op = getSpecConstantComposite(operands[wordIndex])))
- initializer = SymbolRefAttr::get(
- (cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
+ if (auto initOp = getGlobalVariable(operands[wordIndex]))
+ op = initOp;
+ else if (auto initOp = getSpecConstant(operands[wordIndex]))
+ op = initOp;
+ else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
+ op = initOp;
else
return emitError(unknownLoc, "unknown <id> ")
<< operands[wordIndex] << "used as initializer";
+ initializer = SymbolRefAttr::get(op);
wordIndex++;
}
if (wordIndex != operands.size()) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index cbede7d8a39b29..7bfcca5b4dcdca 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -383,32 +383,31 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
// Encode initialization.
- if (auto initializerName = varOp.getInitializer()) {
-
+ StringRef initAttrName = varOp.getInitializerAttrName().getValue();
+ if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
uint32_t initializerID = 0;
- auto init = varOp->getAttrOfType<FlatSymbolRefAttr>("initializer");
+ auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
- varOp->getParentOp(), init.getAttr());
+ varOp->getParentOp(), initRef.getAttr());
- // Check if initializer is GlobalVariable or
- // SpecConstant/SpecConstantComposite
+ // Check if initializer is GlobalVariable or SpecConstant* cases.
if (isa<spirv::GlobalVariableOp>(initOp))
- initializerID = getVariableID(*initializerName);
+ initializerID = getVariableID(*initSymbolName);
else
- initializerID = getSpecConstID(*initializerName);
+ initializerID = getSpecConstID(*initSymbolName);
if (!initializerID)
return emitError(varOp.getLoc(),
"invalid usage of undefined variable as initializer");
operands.push_back(initializerID);
- elidedAttrs.push_back("initializer");
+ elidedAttrs.push_back(initAttrName);
}
if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
return failure();
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
- elidedAttrs.push_back("initializer");
+ elidedAttrs.push_back(initAttrName);
// Encode decorations.
for (auto attr : varOp->getAttrs()) {
More information about the llvm-commits
mailing list