[Mlir-commits] [mlir] [mlir][spirv] Replace hardcoded attribute strings with op methods (resolves #77627) (PR #81443)
Lei Zhang
llvmlistbot at llvm.org
Sat Feb 17 22:16:19 PST 2024
https://github.com/antiagainst updated https://github.com/llvm/llvm-project/pull/81443
>From 733dbe7175fc225953dd31edb9b883c8b20d0c14 Mon Sep 17 00:00:00 2001
From: SahilPatidar <patidarsahil at 2001gmail.com>
Date: Mon, 12 Feb 2024 11:44:07 +0530
Subject: [PATCH 1/3] [mlir][spirv] Replace hardcoded attribute strings with op
 methods (resolves #77627)
---
.../SPIRV/Deserialization/Deserializer.cpp | 7 +++++--
.../SPIRV/Serialization/SerializeOps.cpp | 20 +++++++++++--------
.../Target/SPIRV/Serialization/Serializer.cpp | 7 +++++--
3 files changed, 22 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 02d03b3a0faeee..ee2f8c2fb37308 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -215,11 +215,14 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc, "OpMemoryModel must have two operands");
+ StringRef addressing_model = module->getAddressingModelAttrName().strref();
(*module)->setAttr(
- "addressing_model",
+ addressing_model,
opBuilder.getAttr<spirv::AddressingModelAttr>(
static_cast<spirv::AddressingModel>(operands.front())));
- (*module)->setAttr("memory_model",
+
+ StringRef memory_model = module->getMemoryModelAttrName().strref();
+ (*module)->setAttr(memory_model,
opBuilder.getAttr<spirv::MemoryModelAttr>(
static_cast<spirv::MemoryModel>(operands.back())));
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 41d2c0310d0008..7e6f2f0c5bff27 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -708,33 +708,37 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
operands.push_back(id);
}
- if (auto attr = op->getAttr("memory_access")) {
+ StringRef memory_access = op.getMemoryAccessAttrName().strref();
+ if (auto attr = op->getAttr(memory_access)) {
operands.push_back(
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
- elidedAttrs.push_back("memory_access");
+ elidedAttrs.push_back(memory_access);
- if (auto attr = op->getAttr("alignment")) {
+ StringRef alignment = op.getAlignmentAttrName().strref();
+ if (auto attr = op->getAttr(alignment)) {
operands.push_back(static_cast<uint32_t>(
cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
- elidedAttrs.push_back("alignment");
+ elidedAttrs.push_back(alignment);
- if (auto attr = op->getAttr("source_memory_access")) {
+ StringRef source_memory_access = op.getSourceMemoryAccessAttrName().strref();
+ if (auto attr = op->getAttr(source_memory_access)) {
operands.push_back(
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
- elidedAttrs.push_back("source_memory_access");
+ elidedAttrs.push_back(source_memory_access);
- if (auto attr = op->getAttr("source_alignment")) {
+ StringRef source_alignment = op.getSourceAlignmentAttrName().strref();
+ if (auto attr = op->getAttr(source_alignment)) {
operands.push_back(static_cast<uint32_t>(
cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
- elidedAttrs.push_back("source_alignment");
+ elidedAttrs.push_back(source_alignment);
if (failed(emitDebugLine(functionBody, op.getLoc())))
return failure();
encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 40337e007bbf74..51fab95ad3a847 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -197,10 +197,13 @@ void Serializer::processExtension() {
}
void Serializer::processMemoryModel() {
+ StringRef memory_model = module.getMemoryModelAttrName().strref();
auto mm = static_cast<uint32_t>(
- module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
+ module->getAttrOfType<spirv::MemoryModelAttr>(memory_model).getValue());
+
+ StringRef addressing_model = module.getAddressingModelAttrName().strref();
auto am = static_cast<uint32_t>(
- module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
+ module->getAttrOfType<spirv::AddressingModelAttr>(addressing_model)
.getValue());
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
>From 77dd0c85e350abcdcf9e768e59c87bfd066ecba0 Mon Sep 17 00:00:00 2001
From: SahilPatidar <patidarsahil at 2001gmail.com>
Date: Wed, 14 Feb 2024 14:36:03 +0530
Subject: [PATCH 2/3] Use StringAttr attribute names instead of strref() StringRefs
---
.../SPIRV/Deserialization/DeserializeOps.cpp | 3 ++-
.../SPIRV/Deserialization/Deserializer.cpp | 4 ++--
.../Target/SPIRV/Serialization/SerializeOps.cpp | 16 ++++++++--------
.../Target/SPIRV/Serialization/Serializer.cpp | 4 ++--
4 files changed, 14 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index a678124bf48322..5b2903824c9e76 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -489,7 +489,8 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
auto attrValue = words[wordIndex++];
auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
static_cast<spirv::MemoryAccess>(attrValue));
- attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
+ attributes.push_back(
+ opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
isAlignedAttr = (attrValue == 2);
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index ee2f8c2fb37308..defc17ac4ef600 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -215,13 +215,13 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc, "OpMemoryModel must have two operands");
- StringRef addressing_model = module->getAddressingModelAttrName().strref();
+ StringAttr addressing_model = module->getAddressingModelAttrName();
(*module)->setAttr(
addressing_model,
opBuilder.getAttr<spirv::AddressingModelAttr>(
static_cast<spirv::AddressingModel>(operands.front())));
- StringRef memory_model = module->getMemoryModelAttrName().strref();
+ StringAttr memory_model = module->getMemoryModelAttrName();
(*module)->setAttr(memory_model,
opBuilder.getAttr<spirv::MemoryModelAttr>(
static_cast<spirv::MemoryModel>(operands.back())));
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 7e6f2f0c5bff27..08d55845cb290e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -708,37 +708,37 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
operands.push_back(id);
}
- StringRef memory_access = op.getMemoryAccessAttrName().strref();
+ StringAttr memory_access = op.getMemoryAccessAttrName();
if (auto attr = op->getAttr(memory_access)) {
operands.push_back(
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
- elidedAttrs.push_back(memory_access);
+ elidedAttrs.push_back(memory_access.strref());
- StringRef alignment = op.getAlignmentAttrName().strref();
+ StringAttr alignment = op.getAlignmentAttrName();
if (auto attr = op->getAttr(alignment)) {
operands.push_back(static_cast<uint32_t>(
cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
- elidedAttrs.push_back(alignment);
+ elidedAttrs.push_back(alignment.strref());
- StringRef source_memory_access = op.getSourceMemoryAccessAttrName().strref();
+ StringAttr source_memory_access = op.getSourceMemoryAccessAttrName();
if (auto attr = op->getAttr(source_memory_access)) {
operands.push_back(
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
- elidedAttrs.push_back(source_memory_access);
+ elidedAttrs.push_back(source_memory_access.strref());
- StringRef source_alignment = op.getSourceAlignmentAttrName().strref();
+ StringAttr source_alignment = op.getSourceAlignmentAttrName();
if (auto attr = op->getAttr(source_alignment)) {
operands.push_back(static_cast<uint32_t>(
cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
- elidedAttrs.push_back(source_alignment);
+ elidedAttrs.push_back(source_alignment.strref());
if (failed(emitDebugLine(functionBody, op.getLoc())))
return failure();
encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 51fab95ad3a847..2508d6628658af 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -197,11 +197,11 @@ void Serializer::processExtension() {
}
void Serializer::processMemoryModel() {
- StringRef memory_model = module.getMemoryModelAttrName().strref();
+ StringAttr memory_model = module.getMemoryModelAttrName();
auto mm = static_cast<uint32_t>(
module->getAttrOfType<spirv::MemoryModelAttr>(memory_model).getValue());
- StringRef addressing_model = module.getAddressingModelAttrName().strref();
+ StringAttr addressing_model = module.getAddressingModelAttrName();
auto am = static_cast<uint32_t>(
module->getAttrOfType<spirv::AddressingModelAttr>(addressing_model)
.getValue());
>From ebf81e6c06bb5659aa3c530e92246033ef007963 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sat, 17 Feb 2024 21:56:17 -0800
Subject: [PATCH 3/3] Fix naming style
---
.../SPIRV/Deserialization/Deserializer.cpp | 6 ++----
.../SPIRV/Serialization/SerializeOps.cpp | 18 +++++++++---------
.../Target/SPIRV/Serialization/Serializer.cpp | 9 +++++----
3 files changed, 16 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index defc17ac4ef600..83ef01b4e3a467 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -215,14 +215,12 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc, "OpMemoryModel must have two operands");
- StringAttr addressing_model = module->getAddressingModelAttrName();
(*module)->setAttr(
- addressing_model,
+ module->getAddressingModelAttrName(),
opBuilder.getAttr<spirv::AddressingModelAttr>(
static_cast<spirv::AddressingModel>(operands.front())));
- StringAttr memory_model = module->getMemoryModelAttrName();
- (*module)->setAttr(memory_model,
+ (*module)->setAttr(module->getMemoryModelAttrName(),
opBuilder.getAttr<spirv::MemoryModelAttr>(
static_cast<spirv::MemoryModel>(operands.back())));
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 0c2366c8ab12b7..c283e64fa185a0 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -709,13 +709,13 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
operands.push_back(id);
}
- StringAttr memory_access = op.getMemoryAccessAttrName();
- if (auto attr = op->getAttr(memory_access)) {
+ StringAttr memoryAccess = op.getMemoryAccessAttrName();
+ if (auto attr = op->getAttr(memoryAccess)) {
operands.push_back(
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
- elidedAttrs.push_back(memory_access.strref());
+ elidedAttrs.push_back(memoryAccess.strref());
StringAttr alignment = op.getAlignmentAttrName();
if (auto attr = op->getAttr(alignment)) {
@@ -725,21 +725,21 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
elidedAttrs.push_back(alignment.strref());
- StringAttr source_memory_access = op.getSourceMemoryAccessAttrName();
- if (auto attr = op->getAttr(source_memory_access)) {
+ StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
+ if (auto attr = op->getAttr(sourceMemoryAccess)) {
operands.push_back(
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
- elidedAttrs.push_back(source_memory_access.strref());
+ elidedAttrs.push_back(sourceMemoryAccess.strref());
- StringAttr source_alignment = op.getSourceAlignmentAttrName();
- if (auto attr = op->getAttr(source_alignment)) {
+ StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
+ if (auto attr = op->getAttr(sourceAlignment)) {
operands.push_back(static_cast<uint32_t>(
cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
- elidedAttrs.push_back(source_alignment.strref());
+ elidedAttrs.push_back(sourceAlignment.strref());
if (failed(emitDebugLine(functionBody, op.getLoc())))
return failure();
encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 2508d6628658af..4a4e878d8af915 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -197,13 +197,14 @@ void Serializer::processExtension() {
}
void Serializer::processMemoryModel() {
- StringAttr memory_model = module.getMemoryModelAttrName();
+ StringAttr memoryModelName = module.getMemoryModelAttrName();
auto mm = static_cast<uint32_t>(
- module->getAttrOfType<spirv::MemoryModelAttr>(memory_model).getValue());
+ module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
+ .getValue());
- StringAttr addressing_model = module.getAddressingModelAttrName();
+ StringAttr addressingModelName = module.getAddressingModelAttrName();
auto am = static_cast<uint32_t>(
- module->getAttrOfType<spirv::AddressingModelAttr>(addressing_model)
+ module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
.getValue());
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
More information about the Mlir-commits
mailing list