[Mlir-commits] [mlir] 2dd9e43 - [spirv] Use owning module ref to avoid leaks and fix ASAN tests

Lei Zhang llvmlistbot at llvm.org
Thu Jul 16 14:31:18 PDT 2020


Author: Lei Zhang
Date: 2020-07-16T17:30:59-04:00
New Revision: 2dd9e43579b341e5de238de924cc910042b0194e

URL: https://github.com/llvm/llvm-project/commit/2dd9e43579b341e5de238de924cc910042b0194e
DIFF: https://github.com/llvm/llvm-project/commit/2dd9e43579b341e5de238de924cc910042b0194e.diff

LOG: [spirv] Use owning module ref to avoid leaks and fix ASAN tests

Differential Revision: https://reviews.llvm.org/D83982

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 4e6b294beaae..4ba3f16feef0 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -103,7 +103,7 @@ class Deserializer {
   LogicalResult deserialize();
 
   /// Collects the final SPIR-V ModuleOp.
-  Optional<spirv::ModuleOp> collect();
+  spirv::OwningSPIRVModuleRef collect();
 
 private:
   //===--------------------------------------------------------------------===//
@@ -111,7 +111,7 @@ class Deserializer {
   //===--------------------------------------------------------------------===//
 
   /// Initializes the `module` ModuleOp in this deserializer instance.
-  spirv::ModuleOp createModuleOp();
+  spirv::OwningSPIRVModuleRef createModuleOp();
 
   /// Processes SPIR-V module header in `binary`.
   LogicalResult processHeader();
@@ -425,7 +425,7 @@ class Deserializer {
   Location unknownLoc;
 
   /// The SPIR-V ModuleOp.
-  Optional<spirv::ModuleOp> module;
+  spirv::OwningSPIRVModuleRef module;
 
   /// The current function under construction.
   Optional<spirv::FuncOp> curFunction;
@@ -556,13 +556,15 @@ LogicalResult Deserializer::deserialize() {
   return success();
 }
 
-Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
+spirv::OwningSPIRVModuleRef Deserializer::collect() {
+  return std::move(module);
+}
 
 //===----------------------------------------------------------------------===//
 // Module structure
 //===----------------------------------------------------------------------===//
 
-spirv::ModuleOp Deserializer::createModuleOp() {
+spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() {
   OpBuilder builder(context);
   OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
   spirv::ModuleOp::build(builder, state);
@@ -1912,10 +1914,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
   // Go through all ops and remap the operands.
   auto remapOperands = [&](Operation *op) {
     for (auto &operand : op->getOpOperands())
-      if (auto mappedOp = mapper.lookupOrNull(operand.get()))
+      if (Value mappedOp = mapper.lookupOrNull(operand.get()))
         operand.set(mappedOp);
     for (auto &succOp : op->getBlockOperands())
-      if (auto mappedOp = mapper.lookupOrNull(succOp.get()))
+      if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
         succOp.set(mappedOp);
   };
   for (auto &block : body) {
@@ -2354,7 +2356,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
     return emitError(unknownLoc,
                      "missing Execution Model specification in OpEntryPoint");
   }
-  auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+  auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
   if (wordIndex >= words.size()) {
     return emitError(unknownLoc, "missing <id> in OpEntryPoint");
   }
@@ -2382,7 +2384,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
     interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
     wordIndex++;
   }
-  opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
+  opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
                                         opBuilder.getSymbolRefAttr(fnName),
                                         opBuilder.getArrayAttr(interface));
   return success();
@@ -2594,5 +2596,5 @@ spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
   if (failed(deserializer.deserialize()))
     return nullptr;
 
-  return deserializer.collect().getValueOr(nullptr);
+  return deserializer.collect();
 }

diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 340bfd939e1b..3d57e559ca5e 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
@@ -56,7 +57,7 @@ class SerializationTest : public ::testing::Test {
   }
 
   Type getFloatStructType() {
-    OpBuilder opBuilder(module.body());
+    OpBuilder opBuilder(module->body());
     llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
     llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
     auto structType = spirv::StructType::get(elementTypes, offsetInfo);
@@ -64,7 +65,7 @@ class SerializationTest : public ::testing::Test {
   }
 
   void addGlobalVar(Type type, llvm::StringRef name) {
-    OpBuilder opBuilder(module.body());
+    OpBuilder opBuilder(module->body());
     auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
     opBuilder.create<spirv::GlobalVariableOp>(
         UnknownLoc::get(&context), TypeAttr::get(ptrType),
@@ -98,7 +99,7 @@ class SerializationTest : public ::testing::Test {
 
 protected:
   MLIRContext context;
-  spirv::ModuleOp module;
+  spirv::OwningSPIRVModuleRef module;
   SmallVector<uint32_t, 0> binary;
 };
 
@@ -109,7 +110,7 @@ class SerializationTest : public ::testing::Test {
 TEST_F(SerializationTest, BlockDecorationTest) {
   auto structType = getFloatStructType();
   addGlobalVar(structType, "var0");
-  ASSERT_TRUE(succeeded(spirv::serialize(module, binary)));
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
   auto hasBlockDecoration = [](spirv::Opcode opcode,
                                ArrayRef<uint32_t> operands) -> bool {
     if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)


        


More information about the Mlir-commits mailing list