[Mlir-commits] [mlir] b805087 - [mlir][spirv] Introduce OwningSPIRVModuleRef for ownership
Lei Zhang
llvmlistbot at llvm.org
Tue Jul 7 05:32:49 PDT 2020
Author: Lei Zhang
Date: 2020-07-07T08:29:27-04:00
New Revision: b80508703fd7f88a7922c9c8f02c696be1db8034
URL: https://github.com/llvm/llvm-project/commit/b80508703fd7f88a7922c9c8f02c696be1db8034
DIFF: https://github.com/llvm/llvm-project/commit/b80508703fd7f88a7922c9c8f02c696be1db8034.diff
LOG: [mlir][spirv] Introduce OwningSPIRVModuleRef for ownership
Similar to OwningModuleRef, OwningSPIRVModuleRef signals ownership
transfer clearly. This is useful for APIs like spirv::deserialize,
where a spirv::ModuleOp is returned by deserializing a SPIR-V binary
module.
This addresses the ASAN error as reported in
https://bugs.llvm.org/show_bug.cgi?id=46272
Differential Revision: https://reviews.llvm.org/D81652
Added:
mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
mlir/include/mlir/IR/OwningOpRefBase.h
Modified:
mlir/include/mlir/Dialect/SPIRV/Serialization.h
mlir/include/mlir/IR/Module.h
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
new file mode 100644
index 000000000000..a53331eda4fa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
@@ -0,0 +1,29 @@
+//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+#define MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/OwningOpRefBase.h"
+
+namespace mlir {
+namespace spirv {
+
+/// This class acts as an owning reference to a SPIR-V module, and will
+/// automatically destroy the held module on destruction if the held module
+/// is valid.
+///
+/// It is the return type of spirv::deserialize, which makes ownership of the
+/// newly created module explicit. Call release() to hand the module to another
+/// owner, for example by inserting it into an enclosing ModuleOp.
+class OwningSPIRVModuleRef : public OwningOpRefBase<spirv::ModuleOp> {
+public:
+ // Reuse the base-class constructors (null and from-op). Moves go through the
+ // implicitly declared move members, which forward to the base class.
+ using OwningOpRefBase<spirv::ModuleOp>::OwningOpRefBase;
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVMODULE_H
diff --git a/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/mlir/include/mlir/Dialect/SPIRV/Serialization.h
index f6370a1b5ec2..2c91286ca158 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Serialization.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Serialization.h
@@ -22,6 +22,7 @@ class MLIRContext;
namespace spirv {
class ModuleOp;
+class OwningSPIRVModuleRef;
/// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
/// reports errors to the error handler registered with the MLIR context for
@@ -31,9 +32,10 @@ LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary,
/// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
/// in the given `context`. Returns the ModuleOp on success; otherwise, reports
-/// errors to the error handler registered with `context` and returns
-/// llvm::None.
-Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
+/// errors to the error handler registered with `context` and returns a null
+/// module.
+///
+/// The returned OwningSPIRVModuleRef owns the module and erases it when
+/// destroyed. Call release() to transfer the module to another owner.
+OwningSPIRVModuleRef deserialize(ArrayRef<uint32_t> binary,
+ MLIRContext *context);
} // end namespace spirv
} // end namespace mlir
diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h
index 3c61574f2b99..8a5101337586 100644
--- a/mlir/include/mlir/IR/Module.h
+++ b/mlir/include/mlir/IR/Module.h
@@ -13,6 +13,7 @@
#ifndef MLIR_IR_MODULE_H
#define MLIR_IR_MODULE_H
+#include "mlir/IR/OwningOpRefBase.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
@@ -122,40 +123,10 @@ class ModuleTerminatorOp
};
/// This class acts as an owning reference to a module, and will automatically
-/// destroy the held module if valid.
-class OwningModuleRef {
+/// destroy the held module on destruction if the held module is valid.
+class OwningModuleRef : public OwningOpRefBase<ModuleOp> {
public:
- OwningModuleRef(std::nullptr_t = nullptr) {}
- OwningModuleRef(ModuleOp module) : module(module) {}
- OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
- ~OwningModuleRef() {
- if (module)
- module.erase();
- }
-
- // Assign from another module reference.
- OwningModuleRef &operator=(OwningModuleRef &&other) {
- if (module)
- module.erase();
- module = other.release();
- return *this;
- }
-
- /// Allow accessing the internal module.
- ModuleOp get() const { return module; }
- ModuleOp operator*() const { return module; }
- ModuleOp *operator->() { return &module; }
- explicit operator bool() const { return module; }
-
- /// Release the referenced module.
- ModuleOp release() {
- ModuleOp released;
- std::swap(released, module);
- return released;
- }
-
-private:
- ModuleOp module;
+ // get(), release(), operator*, operator-> and operator bool are inherited from
+ // OwningOpRefBase, so existing users of OwningModuleRef keep the same API.
+ using OwningOpRefBase<ModuleOp>::OwningOpRefBase;
};
} // end namespace mlir
diff --git a/mlir/include/mlir/IR/OwningOpRefBase.h b/mlir/include/mlir/IR/OwningOpRefBase.h
new file mode 100644
index 000000000000..bfdf98f6f37e
--- /dev/null
+++ b/mlir/include/mlir/IR/OwningOpRefBase.h
@@ -0,0 +1,64 @@
+//===- OwningOpRefBase.h - MLIR OwningOpRefBase -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a base class for owning op refs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OWNINGOPREFBASE_H
+#define MLIR_IR_OWNINGOPREFBASE_H
+
+#include <cstddef>
+#include <utility>
+
+namespace mlir {
+
+/// This class acts as an owning reference to an op, and will automatically
+/// destroy the held op on destruction if the held op is valid.
+///
+/// The reference is move-only. Moving transfers ownership and leaves the
+/// source null. release() hands the op back to the caller without erasing it.
+///
+/// Note that OpBuilder and related functionality should be highly preferred
+/// instead, and this should only be used in situations where existing solutions
+/// are not viable.
+template <typename OpTy>
+class OwningOpRefBase {
+public:
+  /// Constructs a null reference that owns nothing.
+  OwningOpRefBase(std::nullptr_t = nullptr) {}
+  /// Takes ownership of `op`; a null `op` yields a null reference.
+  OwningOpRefBase(OpTy op) : op(op) {}
+  /// Transfers ownership from `other`, leaving `other` null.
+  OwningOpRefBase(OwningOpRefBase &&other) noexcept : op(other.release()) {}
+  /// Copying would leave two owners that both erase the same op.
+  OwningOpRefBase(const OwningOpRefBase &) = delete;
+  OwningOpRefBase &operator=(const OwningOpRefBase &) = delete;
+  ~OwningOpRefBase() {
+    if (op)
+      op.erase();
+  }
+
+  /// Erases the currently held op (if any) and takes ownership of the op held
+  /// by `other`, leaving `other` null.
+  OwningOpRefBase &operator=(OwningOpRefBase &&other) noexcept {
+    // Self-move must be a no-op. Otherwise the op would be erased first and
+    // then re-acquired from `other.release()`, leaving a dangling handle that
+    // the destructor erases a second time.
+    if (this == &other)
+      return *this;
+    if (op)
+      op.erase();
+    op = other.release();
+    return *this;
+  }
+
+  /// Allow accessing the internal op.
+  OpTy get() const { return op; }
+  OpTy operator*() const { return op; }
+  OpTy *operator->() { return &op; }
+  explicit operator bool() const { return op; }
+
+  /// Releases the referenced op without erasing it. The caller becomes
+  /// responsible for the op and this reference becomes null.
+  OpTy release() { return std::exchange(op, OpTy()); }
+
+private:
+  OpTy op;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_OWNINGOPREFBASE_H
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index b5fef1477870..92f2a015930f 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@@ -2516,12 +2517,12 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
} // namespace
+/// Deserializes `binary` into a spirv::ModuleOp created in `context`. The
+/// returned reference owns the module.
-Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
- MLIRContext *context) {
+spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
+ MLIRContext *context) {
Deserializer deserializer(binary, context);
+ // On failure, return a null reference; callers test it with operator bool.
if (failed(deserializer.deserialize()))
- return llvm::None;
+ return nullptr;
- return deserializer.collect();
+ // collect() returns an Optional module (hence getValueOr). An empty result
+ // becomes a null reference. Otherwise the module is adopted by the returned
+ // owner and erased when that owner is destroyed.
+ return deserializer.collect().getValueOr(nullptr);
}
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
index 4c3fb1e8d422..42b458d314ca 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/IR/Builders.h"
@@ -49,13 +50,13 @@ static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
size / sizeof(uint32_t));
- auto spirvModule = spirv::deserialize(binary, context);
+ spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
if (!spirvModule)
return {};
OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
- module->getBody()->push_front(spirvModule->getOperation());
+ module->getBody()->push_front(spirvModule.release());
return module;
}
@@ -136,14 +137,14 @@ static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
return failure();
// Then deserialize to get back a SPIR-V module.
- auto spirvModule = spirv::deserialize(binary, context);
+ spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
if (!spirvModule)
return failure();
// Wrap around in a new MLIR module.
OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get(
/*filename=*/"", /*line=*/0, /*column=*/0, context)));
- dstModule->getBody()->push_front(spirvModule->getOperation());
+ dstModule->getBody()->push_front(spirvModule.release());
dstModule->print(output);
return mlir::success();
diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
index 31fc0e426e24..a81b7741deea 100644
--- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/IR/Diagnostics.h"
@@ -46,7 +47,7 @@ class DeserializationTest : public ::testing::Test {
}
/// Performs deserialization and returns the constructed spv.module op.
- Optional<spirv::ModuleOp> deserialize() {
+ /// The returned reference owns the module and erases it when destroyed. It is
+ /// null if deserialization failed.
+ spirv::OwningSPIRVModuleRef deserialize() {
return spirv::deserialize(binary, &context);
}
@@ -130,27 +131,27 @@ class DeserializationTest : public ::testing::Test {
//===----------------------------------------------------------------------===//
TEST_F(DeserializationTest, EmptyModuleFailure) {
- ASSERT_EQ(llvm::None, deserialize());
+ // No header was added, so deserialization must fail and return a null ref.
+ ASSERT_FALSE(deserialize());
expectDiagnostic("SPIR-V binary module must have a 5-word header");
}
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
addHeader();
binary.front() = 0xdeadbeef; // Change to a wrong magic number
- ASSERT_EQ(llvm::None, deserialize());
+ // A corrupted magic number must make deserialization return a null ref.
+ ASSERT_FALSE(deserialize());
expectDiagnostic("incorrect magic number");
}
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
addHeader();
- EXPECT_NE(llvm::None, deserialize());
+ // A header-only binary deserializes to a non-null ref. The temporary ref
+ // erases the module at the end of the full expression, so nothing leaks.
+ EXPECT_TRUE(deserialize());
}
TEST_F(DeserializationTest, ZeroWordCountFailure) {
addHeader();
binary.push_back(0); // OpNop with zero word count
- ASSERT_EQ(llvm::None, deserialize());
+ // An instruction word with a zero word count is malformed: expect a null ref.
+ ASSERT_FALSE(deserialize());
expectDiagnostic("word count cannot be zero");
}
@@ -160,7 +161,7 @@ TEST_F(DeserializationTest, InsufficientWordFailure) {
static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
// Missing word for type <id>
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("insufficient words for the last instruction");
}
@@ -172,7 +173,7 @@ TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
addHeader();
addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}
@@ -198,7 +199,7 @@ TEST_F(DeserializationTest, OpMemberNameSuccess) {
addInstruction(spirv::Opcode::OpMemberName, operands2);
binary.append(typeDecl.begin(), typeDecl.end());
- EXPECT_NE(llvm::None, deserialize());
+ EXPECT_TRUE(deserialize());
}
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
@@ -215,7 +216,7 @@ TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
addInstruction(spirv::Opcode::OpMemberName, operands1);
binary.append(typeDecl.begin(), typeDecl.end());
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("OpMemberName must have at least 3 operands");
}
@@ -234,7 +235,7 @@ TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
addInstruction(spirv::Opcode::OpMemberName, operands);
binary.append(typeDecl.begin(), typeDecl.end());
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}
@@ -249,7 +250,7 @@ TEST_F(DeserializationTest, FunctionMissingEndFailure) {
addFunction(voidType, fnType);
// Missing OpFunctionEnd
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("expected OpFunctionEnd instruction");
}
@@ -261,7 +262,7 @@ TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
addFunction(voidType, fnType);
// Missing OpFunctionParameter
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("expected OpFunctionParameter instruction");
}
@@ -274,7 +275,7 @@ TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
addReturn();
addFunctionEnd();
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("a basic block must start with OpLabel");
}
@@ -287,6 +288,6 @@ TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
addReturn();
addFunctionEnd();
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("OpLabel should only have result <id>");
}
More information about the Mlir-commits
mailing list