[Mlir-commits] [mlir] 341f3c1 - [MLIR][SPIRV] ModuleCombiner: deduplicate global vars, spec consts, and funcs.
Lei Zhang
llvmlistbot at llvm.org
Thu Nov 19 07:09:37 PST 2020
Author: ergawy
Date: 2020-11-19T10:06:04-05:00
New Revision: 341f3c1120dfa8879e5f714a07fc8b16c8887a7f
URL: https://github.com/llvm/llvm-project/commit/341f3c1120dfa8879e5f714a07fc8b16c8887a7f
DIFF: https://github.com/llvm/llvm-project/commit/341f3c1120dfa8879e5f714a07fc8b16c8887a7f.diff
LOG: [MLIR][SPIRV] ModuleCombiner: deduplicate global vars, spec consts, and funcs.
This commit extends the functionality of the SPIR-V module combiner
library by adding new deduplication capabilities. In particular, it
introduces deduplication of global variables, specialization constants,
and functions.
For global variables, 2 variables are considered duplicate if they either
have the same descriptor set + binding or the same built_in attribute.
For specialization constants, 2 spec constants are considered duplicate if
they have the same spec_id attribute.
2 functions are deduplicated if they are identical. 2 functions are
identical if they have the same prototype, attributes, and body.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D90951
Added:
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
index b7ecd57d103d..36071f00b94d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
+++ b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
@@ -28,7 +28,7 @@ class ModuleOp;
/// combination process proceeds in 2 phases:
///
/// (1) resolve conflicts between pairs of ops from different modules
-/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
+/// (2) deduplicate equivalent ops/sub-ops in the merged module.
///
/// For the conflict resolution phase, the following rules are employed to
/// resolve such conflicts:
@@ -39,13 +39,22 @@ class ModuleOp;
/// other symbol.
/// - If none of the 2 conflicting ops are spv.func, then rename either.
///
-/// In all cases, the references to the updated symbol are also updated to
-/// reflect the change.
+/// For deduplication, the following 3 cases are taken into consideration:
+///
+/// - If 2 spv.globalVariable's have either the same descriptor set + binding
+///   or the same built_in attribute value, and their other attributes
+///   (e.g. type) also match, then replace one of them using the other.
+/// - If 2 spv.specConstant's have the same spec_id attribute value, and their
+///   other attributes also match, then replace one of them using the other.
+/// - If 2 spv.func's are identical, then replace one of them using the other.
+///
+/// In all cases, the references to the updated symbol (whether renamed or
+/// deduplicated) are also updated to reflect the change.
///
/// \param modules the list of modules to combine. Input modules are not
/// modified.
/// \param combinedMdouleBuilder an OpBuilder to be used for
/// building up the combined module.
/// \param symbRenameListener a listener that gets called everytime a symbol in
/// one of the input modules is renamed. The arguments
/// passed to the listener are: the input
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 7687ab27e753..2df9e56a6940 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -12,10 +12,12 @@
#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
@@ -59,6 +61,59 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
return success();
}
+/// Looks up `key` in `deduplicationMap`, recording `symbolOp` under it when
+/// the key has not been seen before.
+///
+/// \return a null op if `symbolOp` was just recorded (first occurrence of
+/// `key`); otherwise the op previously recorded under `key`, which the caller
+/// can use as the replacement for `symbolOp`.
+template <typename KeyTy, typename SymbolOpTy>
+static SymbolOpTy
+emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
+                              DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
+  auto slotAndInserted = deduplicationMap.try_emplace(key, symbolOp);
+  SymbolOpTy &recorded = slotAndInserted.first->second;
+  // On a fresh insertion `recorded` is `symbolOp` itself, so there is nothing
+  // to replace it with.
+  return slotAndInserted.second ? SymbolOpTy() : recorded;
+}
+
+/// Computes a hash code from the operations in the argument Block.
+///
+/// Each op contributes OperationEquivalence::computeHash with IgnoreOperands,
+/// so how ops are wired to each other is NOT part of this hash.
+///
+/// \return 0 if the block contains an op with regions (not properly supported
+/// yet) or if the block is empty; 0 is a "do not deduplicate" sentinel.
+static llvm::hash_code computeHash(Block *block) {
+  // TODO: Consider extracting BlockEquivalenceData into a common header and
+  // re-using it here.
+  llvm::hash_code hash(0);
+
+  for (Operation &op : *block) {
+    // TODO: Properly handle operations with regions.
+    if (op.getNumRegions() > 0)
+      return 0;
+
+    hash = llvm::hash_combine(
+        hash, OperationEquivalence::computeHash(
+                  &op, OperationEquivalence::Flags::IgnoreOperands));
+  }
+
+  return hash;
+}
+
+/// Folds the bodies of all regions attached to `op` into `hashCode`: the
+/// per-block op hashes, block argument types, and the use-def wiring, i.e.
+/// which block argument or op result (numbered in definition order) feeds each
+/// operand and which block each successor refers to. Without the wiring, bodies
+/// such as `a - b` and `b - a` would hash identically.
+///
+/// \return failure if a body cannot be hashed reliably: a block hash is 0 (see
+/// computeHash(Block *)), or an operand/successor refers to something outside
+/// `op`'s own regions.
+static LogicalResult hashRegionBodies(Operation *op,
+                                      llvm::hash_code &hashCode) {
+  // Number every block and value up front so that uses which textually precede
+  // their definition (possible across blocks) still resolve.
+  DenseMap<Block *, unsigned> blockIds;
+  DenseMap<Value, unsigned> valueIds;
+  unsigned nextBlockId = 0;
+  unsigned nextValueId = 0;
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      blockIds[&block] = nextBlockId++;
+      for (BlockArgument arg : block.getArguments())
+        valueIds[arg] = nextValueId++;
+      for (Operation &nestedOp : block)
+        for (Value result : nestedOp.getResults())
+          valueIds[result] = nextValueId++;
+    }
+  }
+
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      llvm::hash_code blockHash = computeHash(&block);
+      if (!blockHash)
+        return failure();
+      hashCode = llvm::hash_combine(hashCode, blockHash);
+
+      for (BlockArgument arg : block.getArguments())
+        hashCode = llvm::hash_combine(hashCode, arg.getType());
+
+      for (Operation &nestedOp : block) {
+        for (Value operand : nestedOp.getOperands()) {
+          auto it = valueIds.find(operand);
+          if (it == valueIds.end())
+            return failure();
+          hashCode = llvm::hash_combine(hashCode, it->second);
+        }
+        for (Block *successor : nestedOp.getSuccessors()) {
+          auto it = blockIds.find(successor);
+          if (it == blockIds.end())
+            return failure();
+          hashCode = llvm::hash_combine(hashCode, it->second);
+        }
+      }
+    }
+  }
+
+  return success();
+}
+
+/// Computes a hash code to represent the argument SymbolOpInterface based on
+/// all the Op's attributes except for the symbol name, and on the bodies of
+/// any regions the Op has (e.g. the body of a spv.func).
+///
+/// \return the hash code computed as described above, or 0 if the Op has a
+/// region body that cannot be hashed reliably (see hashRegionBodies). A 0
+/// result means the Op must not be deduplicated.
+///
+/// Note: We use the operation's name (not the symbol name) as part of the hash
+/// computation. This prevents, for example, mistakenly considering a global
+/// variable and a spec constant as duplicates because their descriptor set +
+/// binding and spec_id, respectively, happen to hash to the same value.
+///
+/// NOTE(review): the deduplication loop treats equal hash codes as equivalent
+/// ops, so a hash collision would merge two different symbols. Consider
+/// verifying equivalence on a hash match.
+static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
+  Operation *op = symbolOp.getOperation();
+  llvm::hash_code hashCode = llvm::hash_combine(op->getName());
+
+  for (auto attr : op->getAttrs()) {
+    if (attr.first == SymbolTable::getSymbolAttrName())
+      continue;
+    hashCode = llvm::hash_combine(hashCode, attr);
+  }
+
+  // Hash region bodies here rather than leaving it to callers, so that a 0
+  // result reliably flags an Op whose body is not supported yet. Previously the
+  // 0 sentinel from computeHash(Block *) was combined into the hash only after
+  // callers had already checked for 0, so it was never acted upon.
+  if (failed(hashRegionBodies(op, hashCode)))
+    return 0;
+
+  return hashCode;
+}
+
namespace mlir {
namespace spirv {
@@ -174,6 +229,48 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
combinedModuleBuilder.insert(op.clone());
}
+ // Deduplicate identical global variables, spec constants, and functions.
+ DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
+ SmallVector<SymbolOpInterface, 0> eraseList;
+
+ for (auto &op : combinedModule.getBlock().without_terminator()) {
+ llvm::hash_code hashCode(0);
+ SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
+
+ if (!symbolOp)
+ continue;
+
+ hashCode = computeHash(symbolOp);
+
+ // A 0 hash code means the op is not suitable for deduplication and should
+ // be skipped. An example of this is when a function has ops with regions
+ // which are not properly supported yet.
+ if (!hashCode)
+ continue;
+
+ if (auto funcOp = dyn_cast<FuncOp>(op))
+ for (auto &blk : funcOp)
+ hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
+
+ SymbolOpInterface replacementSymOp =
+ emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
+
+ if (!replacementSymOp)
+ continue;
+
+ if (failed(SymbolTable::replaceAllSymbolUses(
+ symbolOp, replacementSymOp.getName(), combinedModule))) {
+ symbolOp.emitError("unable to update all symbol uses for ")
+ << symbolOp.getName() << " to " << replacementSymOp.getName();
+ return nullptr;
+ }
+
+ eraseList.push_back(symbolOp);
+ }
+
+ for (auto symbolOp : eraseList)
+ symbolOp.erase();
+
return combinedModule;
}
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
index 4d9727cc11b3..c41e7b9486e1 100644
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
@@ -39,10 +39,12 @@ spv.module Logical GLSL450 {
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @foo_1
+// CHECK-NEXT: spv.FAdd
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @foo_2
+// CHECK-NEXT: spv.ISub
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -57,13 +59,15 @@ spv.module Logical GLSL450 {
spv.module Logical GLSL450 {
spv.func @foo(%arg0 : f32) -> f32 "None" {
- spv.ReturnValue %arg0 : f32
+ %0 = spv.FAdd %arg0, %arg0 : f32
+ spv.ReturnValue %0 : f32
}
}
spv.module Logical GLSL450 {
spv.func @foo(%arg0 : i32) -> i32 "None" {
- spv.ReturnValue %arg0 : i32
+ %0 = spv.ISub %arg0, %arg0 : i32
+ spv.ReturnValue %0 : i32
}
}
}
@@ -578,9 +582,9 @@ spv.module Logical GLSL450 {
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
-// CHECK-NEXT: spv.globalVariable @foo_1
+// CHECK-NEXT: spv.globalVariable @foo_1 bind(1, 0)
-// CHECK-NEXT: spv.globalVariable @foo
+// CHECK-NEXT: spv.globalVariable @foo bind(2, 0)
// CHECK-NEXT: }
module {
@@ -589,7 +593,26 @@ spv.module Logical GLSL450 {
}
spv.module Logical GLSL450 {
- spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+ spv.globalVariable @foo bind(2, 0) : !spv.ptr<f32, Input>
+}
+}
+
+// -----
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo_1 built_in("GlobalInvocationId")
+
+// CHECK-NEXT: spv.globalVariable @foo built_in("LocalInvocationId")
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
}
}
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
new file mode 100644
index 000000000000..e919bbaf17b7
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
@@ -0,0 +1,244 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Deduplicate 2 global variables with the same descriptor set and binding.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo
+
+// CHECK-NEXT: spv.func @use_foo
+// CHECK-NEXT: spv.mlir.addressof @foo
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @use_bar
+// CHECK-NEXT: spv.mlir.addressof @foo
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.FAdd
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+
+ spv.func @use_foo() -> f32 "None" {
+ %0 = spv.mlir.addressof @foo : !spv.ptr<f32, Input>
+ %1 = spv.Load "Input" %0 : f32
+ spv.ReturnValue %1 : f32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
+
+ spv.func @use_bar() -> f32 "None" {
+ %0 = spv.mlir.addressof @bar : !spv.ptr<f32, Input>
+ %1 = spv.Load "Input" %0 : f32
+ %2 = spv.FAdd %1, %1 : f32
+ spv.ReturnValue %2 : f32
+ }
+}
+}
+
+// -----
+
+// Do not deduplicate 2 global variables with the same descriptor set and binding but different types.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo bind(1, 0)
+
+// CHECK-NEXT: spv.globalVariable @bar bind(1, 0)
+
+// CHECK-NEXT: spv.func @use_bar
+// CHECK-NEXT: spv.mlir.addressof @bar
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo bind(1, 0) : !spv.ptr<i32, Input>
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
+
+ spv.func @use_bar() -> f32 "None" {
+ %0 = spv.mlir.addressof @bar : !spv.ptr<f32, Input>
+ %1 = spv.Load "Input" %0 : f32
+ spv.ReturnValue %1 : f32
+ }
+}
+}
+
+// -----
+
+// Deduplicate 2 global variables with the same built-in attribute.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo built_in("GlobalInvocationId")
+// CHECK-NEXT: spv.func @use_bar
+// CHECK-NEXT: spv.mlir.addressof @foo
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+}
+
+spv.module Logical GLSL450 {
+ spv.globalVariable @bar built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+
+ spv.func @use_bar() -> vector<3xi32> "None" {
+ %0 = spv.mlir.addressof @bar : !spv.ptr<vector<3xi32>, Input>
+ %1 = spv.Load "Input" %0 : vector<3xi32>
+ spv.ReturnValue %1 : vector<3xi32>
+ }
+}
+}
+
+// -----
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @foo spec_id(5)
+
+// CHECK-NEXT: spv.func @use_foo()
+// CHECK-NEXT: %0 = spv.mlir.referenceof @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @use_bar()
+// CHECK-NEXT: %0 = spv.mlir.referenceof @foo
+// CHECK-NEXT: spv.FAdd
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.specConstant @foo spec_id(5) = 1. : f32
+
+ spv.func @use_foo() -> (f32) "None" {
+ %0 = spv.mlir.referenceof @foo : f32
+ spv.ReturnValue %0 : f32
+ }
+}
+
+spv.module Logical GLSL450 {
+ spv.specConstant @bar spec_id(5) = 1. : f32
+
+ spv.func @use_bar() -> (f32) "None" {
+ %0 = spv.mlir.referenceof @bar : f32
+ %1 = spv.FAdd %0, %0 : f32
+ spv.ReturnValue %1 : f32
+ }
+}
+}
+
+// -----
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @bar spec_id(5)
+
+// CHECK-NEXT: spv.func @foo(%arg0: f32)
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32)
+// CHECK-NEXT: spv.mlir.referenceof
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @baz(%arg0: i32)
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32)
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @baz_no_return_different_control
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @baz_no_return_another_control
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @kernel
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @kernel_different_attr
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+ spv.specConstant @bar spec_id(5) = 1. : f32
+
+ spv.func @foo(%arg0: f32) -> (f32) "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+
+ spv.func @foo_duplicate(%arg0: f32) -> (f32) "None" {
+ spv.ReturnValue %arg0 : f32
+ }
+
+ spv.func @foo_
diff erent_body(%arg0: f32) -> (f32) "None" {
+ %0 = spv.mlir.referenceof @bar : f32
+ spv.ReturnValue %arg0 : f32
+ }
+
+ spv.func @baz(%arg0: i32) -> (i32) "None" {
+ spv.ReturnValue %arg0 : i32
+ }
+
+ spv.func @baz_no_return(%arg0: i32) "None" {
+ spv.Return
+ }
+
+ spv.func @baz_no_return_duplicate(%arg0: i32) -> () "None" {
+ spv.Return
+ }
+
+ spv.func @baz_no_return_
diff erent_control(%arg0: i32) -> () "Inline" {
+ spv.Return
+ }
+
+ spv.func @baz_no_return_another_control(%arg0: i32) -> () "Inline|Pure" {
+ spv.Return
+ }
+
+ spv.func @kernel(
+ %arg0: f32,
+ %arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
+ attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+ spv.Return
+ }
+
+ spv.func @kernel_
diff erent_attr(
+ %arg0: f32,
+ %arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
+ attributes {spv.entry_point_abi = {local_size = dense<[64, 1, 1]> : vector<3xi32>}} {
+ spv.Return
+ }
+}
+}
More information about the Mlir-commits
mailing list