[Mlir-commits] [mlir] 222d7fc - [mlir][spirv] Avoid duplicated Block decoration during serialization

Lei Zhang llvmlistbot at llvm.org
Fri Dec 10 16:21:13 PST 2021


Author: Lei Zhang
Date: 2021-12-10T19:20:49-05:00
New Revision: 222d7fc7f81936edf488bc85acf3ea5d5300bca3

URL: https://github.com/llvm/llvm-project/commit/222d7fc7f81936edf488bc85acf3ea5d5300bca3
DIFF: https://github.com/llvm/llvm-project/commit/222d7fc7f81936edf488bc85acf3ea5d5300bca3.diff

LOG: [mlir][spirv] Avoid duplicated Block decoration during serialization

Duplicated decorations are legal per the Vulkan / SPIR-V spec; still, it's
better to avoid them to produce a cleaner blob and reduce the binary size.

Reviewed By: scotttodd

Differential Revision: https://reviews.llvm.org/D115532

Added: 
    

Modified: 
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index d877c1b70311e..ef558aa6c3929 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -316,18 +316,6 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
     return failure();
   }
 
-  if (isInterfaceStructPtrType(varOp.type())) {
-    auto structType = varOp.type()
-                          .cast<spirv::PointerType>()
-                          .getPointeeType()
-                          .cast<spirv::StructType>();
-    if (failed(
-            emitDecoration(getTypeID(structType), spirv::Decoration::Block))) {
-      return varOp.emitError("cannot decorate ")
-             << structType << " with Block decoration";
-    }
-  }
-
   elidedAttrs.push_back("type");
   SmallVector<uint32_t, 4> operands;
   operands.push_back(resultTypeID);

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index bcead6e527d5e..bd618ec4884b8 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -331,9 +331,9 @@ LogicalResult
 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
                             SetVector<StringRef> &serializationCtx) {
   typeID = getTypeID(type);
-  if (typeID) {
+  if (typeID)
     return success();
-  }
+
   typeID = getNextID();
   SmallVector<uint32_t, 4> operands;
 
@@ -499,6 +499,14 @@ LogicalResult Serializer::prepareBasicType(
     typeEnum = spirv::Opcode::OpTypePointer;
     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
     operands.push_back(pointeeTypeID);
+
+    if (isInterfaceStructPtrType(ptrType)) {
+      if (failed(emitDecoration(getTypeID(pointeeStruct),
+                                spirv::Decoration::Block)))
+        return emitError(loc, "cannot decorate ")
+               << pointeeStruct << " with Block decoration";
+    }
+
     return success();
   }
 

diff  --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 9222b0cd3654a..f17bc53c5bd8b 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -76,27 +76,29 @@ class SerializationTest : public ::testing::Test {
         builder.getStringAttr(name), nullptr);
   }
 
+  /// Handles a SPIR-V instruction with the given `opcode` and `operands`.
+  /// Returns true to stop scanning the remaining instructions.
+  using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
+                                           ArrayRef<uint32_t> operands)>;
+
+  /// Scans the SPIR-V blob; returns true if `handleFn` stopped the scan.
-  bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
-                                               ArrayRef<uint32_t> operands)>
-                           matchFn) {
+  bool scanInstruction(HandleFn handleFn) {
     auto binarySize = binary.size();
     auto *begin = binary.begin();
     auto currOffset = spirv::kHeaderWordCount;
 
     while (currOffset < binarySize) {
       auto wordCount = binary[currOffset] >> 16;
-      if (!wordCount || (currOffset + wordCount > binarySize)) {
+      if (!wordCount || (currOffset + wordCount > binarySize))
         return false;
-      }
+
       spirv::Opcode opcode =
           static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
-
-      if (matchFn(opcode,
-                  llvm::ArrayRef<uint32_t>(begin + currOffset + 1,
-                                           begin + currOffset + wordCount))) {
+      llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1,
+                                        begin + currOffset + wordCount);
+      if (handleFn(opcode, operands))
         return true;
-      }
+
       currOffset += wordCount;
     }
     return false;
@@ -119,12 +121,32 @@ TEST_F(SerializationTest, ContainsBlockDecoration) {
   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
 
   auto hasBlockDecoration = [](spirv::Opcode opcode,
-                               ArrayRef<uint32_t> operands) -> bool {
-    if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
-      return false;
-    return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
+                               ArrayRef<uint32_t> operands) {
+    return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
+           operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
+  };
+  EXPECT_TRUE(scanInstruction(hasBlockDecoration));
+}
+
+TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) {
+  auto structType = getFloatStructType();
+  // Two global variables using the same type should not decorate the type with
+  // duplicated `Block` decorations.
+  addGlobalVar(structType, "var0");
+  addGlobalVar(structType, "var1");
+
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+
+  unsigned count = 0;
+  auto countBlockDecoration = [&count](spirv::Opcode opcode,
+                                       ArrayRef<uint32_t> operands) {
+    if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
+        operands[1] == static_cast<uint32_t>(spirv::Decoration::Block))
+      ++count;
+    return false;
   };
-  EXPECT_TRUE(findInstruction(hasBlockDecoration));
+  ASSERT_FALSE(scanInstruction(countBlockDecoration));
+  EXPECT_EQ(count, 1u);
 }
 
 TEST_F(SerializationTest, ContainsSymbolName) {
@@ -140,7 +162,7 @@ TEST_F(SerializationTest, ContainsSymbolName) {
     return opcode == spirv::Opcode::OpName &&
            spirv::decodeStringLiteral(operands, index) == "var0";
   };
-  EXPECT_TRUE(findInstruction(hasVarName));
+  EXPECT_TRUE(scanInstruction(hasVarName));
 }
 
 TEST_F(SerializationTest, DoesNotContainSymbolName) {
@@ -156,5 +178,5 @@ TEST_F(SerializationTest, DoesNotContainSymbolName) {
     return opcode == spirv::Opcode::OpName &&
            spirv::decodeStringLiteral(operands, index) == "var0";
   };
-  EXPECT_FALSE(findInstruction(hasVarName));
+  EXPECT_FALSE(scanInstruction(hasVarName));
 }


        


More information about the Mlir-commits mailing list