[Mlir-commits] [mlir] ead0a97 - [mlir][spirv] Replace hardcoded strings with op methods (#81443)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 17 22:21:39 PST 2024


Author: SahilPatidar
Date: 2024-02-17T22:21:35-08:00
New Revision: ead0a9777f8ccb5c26d50d96bade6cd5b47f496b

URL: https://github.com/llvm/llvm-project/commit/ead0a9777f8ccb5c26d50d96bade6cd5b47f496b
DIFF: https://github.com/llvm/llvm-project/commit/ead0a9777f8ccb5c26d50d96bade6cd5b47f496b.diff

LOG: [mlir][spirv] Replace hardcoded strings with op methods (#81443)

Progress towards #77627

---------

Co-authored-by: SahilPatidar <patidarsahil at 2001gmail.com>
Co-authored-by: Lei Zhang <antiagainst at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index a678124bf48322..5b2903824c9e76 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -489,7 +489,8 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
     auto attrValue = words[wordIndex++];
     auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
         static_cast<spirv::MemoryAccess>(attrValue));
-    attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
+    attributes.push_back(
+        opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
     isAlignedAttr = (attrValue == 2);
   }
 

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 02d03b3a0faeee..83ef01b4e3a467 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -216,10 +216,11 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
     return emitError(unknownLoc, "OpMemoryModel must have two operands");
 
   (*module)->setAttr(
-      "addressing_model",
+      module->getAddressingModelAttrName(),
       opBuilder.getAttr<spirv::AddressingModelAttr>(
           static_cast<spirv::AddressingModel>(operands.front())));
-  (*module)->setAttr("memory_model",
+
+  (*module)->setAttr(module->getMemoryModelAttrName(),
                      opBuilder.getAttr<spirv::MemoryModelAttr>(
                          static_cast<spirv::MemoryModel>(operands.back())));
 

diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index e68ed5efaca746..c283e64fa185a0 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -709,33 +709,37 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
     operands.push_back(id);
   }
 
-  if (auto attr = op->getAttr("memory_access")) {
+  StringAttr memoryAccess = op.getMemoryAccessAttrName();
+  if (auto attr = op->getAttr(memoryAccess)) {
     operands.push_back(
         static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
   }
 
-  elidedAttrs.push_back("memory_access");
+  elidedAttrs.push_back(memoryAccess.strref());
 
-  if (auto attr = op->getAttr("alignment")) {
+  StringAttr alignment = op.getAlignmentAttrName();
+  if (auto attr = op->getAttr(alignment)) {
     operands.push_back(static_cast<uint32_t>(
         cast<IntegerAttr>(attr).getValue().getZExtValue()));
   }
 
-  elidedAttrs.push_back("alignment");
+  elidedAttrs.push_back(alignment.strref());
 
-  if (auto attr = op->getAttr("source_memory_access")) {
+  StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
+  if (auto attr = op->getAttr(sourceMemoryAccess)) {
     operands.push_back(
         static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
   }
 
-  elidedAttrs.push_back("source_memory_access");
+  elidedAttrs.push_back(sourceMemoryAccess.strref());
 
-  if (auto attr = op->getAttr("source_alignment")) {
+  StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
+  if (auto attr = op->getAttr(sourceAlignment)) {
     operands.push_back(static_cast<uint32_t>(
         cast<IntegerAttr>(attr).getValue().getZExtValue()));
   }
 
-  elidedAttrs.push_back("source_alignment");
+  elidedAttrs.push_back(sourceAlignment.strref());
   if (failed(emitDebugLine(functionBody, op.getLoc())))
     return failure();
   encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 40337e007bbf74..4a4e878d8af915 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -197,10 +197,14 @@ void Serializer::processExtension() {
 }
 
 void Serializer::processMemoryModel() {
+  StringAttr memoryModelName = module.getMemoryModelAttrName();
   auto mm = static_cast<uint32_t>(
-      module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
+      module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
+          .getValue());
+
+  StringAttr addressingModelName = module.getAddressingModelAttrName();
   auto am = static_cast<uint32_t>(
-      module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
+      module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
           .getValue());
 
   encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});


        


More information about the Mlir-commits mailing list