[Mlir-commits] [mlir] 341f3c1 - [MLIR][SPIRV] ModuleCombiner: deduplicate global vars, spec consts, and funcs.

Lei Zhang llvmlistbot at llvm.org
Thu Nov 19 07:09:37 PST 2020


Author: ergawy
Date: 2020-11-19T10:06:04-05:00
New Revision: 341f3c1120dfa8879e5f714a07fc8b16c8887a7f

URL: https://github.com/llvm/llvm-project/commit/341f3c1120dfa8879e5f714a07fc8b16c8887a7f
DIFF: https://github.com/llvm/llvm-project/commit/341f3c1120dfa8879e5f714a07fc8b16c8887a7f.diff

LOG: [MLIR][SPIRV] ModuleCombiner: deduplicate global vars, spec consts, and funcs.

This commit extends the functionality of the SPIR-V module combiner
library by adding new deduplication capabilities. In particular, it
implements deduplication of global variables, specialization constants,
and functions.

For global variables, 2 variables are considered duplicate if they either
have the same descriptor set + binding or the same built_in attribute.

For specialization constants, 2 spec constants are considered duplicate if
they have the same spec_id attribute.

2 functions are deduplicated if they are identical. 2 functions are
identical if they have the same prototype, attributes, and body.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D90951

Added: 
    mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
    mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
    mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
index b7ecd57d103d..36071f00b94d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
+++ b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
@@ -28,7 +28,7 @@ class ModuleOp;
 /// combination process proceeds in 2 phases:
 ///
 ///   (1) resolve conflicts between pairs of ops from different modules
-///   (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
+///   (2) deduplicate equivalent ops/sub-ops in the merged module.
 ///
 /// For the conflict resolution phase, the following rules are employed to
 /// resolve such conflicts:
@@ -39,13 +39,22 @@ class ModuleOp;
 ///   other symbol.
 ///   - If none of the 2 conflicting ops are spv.func, then rename either.
 ///
-/// In all cases, the references to the updated symbol are also updated to
-/// reflect the change.
+/// For deduplication, the following 3 cases are taken into consideration:
+///
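+///   - If 2 spv.globalVariable's have either the same descriptor set + binding
+///   or the same built_in attribute value, then replace one of them using the
+///   other.
+///   - If 2 spv.specConstant's have the same spec_id attribute value, then
+///   replace one of them using the other.
+///   - If 2 spv.func's are identical, replace one of them using the other.
+///
+///   Note: ops are compared using all of their attributes except the symbol
+///   name, so, e.g., 2 global variables with the same binding but different
+///   types are not considered duplicates.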
+///
+/// In all cases, the references to the updated symbol (whether renamed or
+/// deduplicated) are also updated to reflect the change.
 ///
 /// \param modules the list of modules to combine. Input modules are not
 /// modified.
 /// \param combinedMdouleBuilder an OpBuilder to be used for
-/// building up the combined module.
+///                              building up the combined module.
 /// \param symbRenameListener a listener that gets called everytime a symbol in
 ///                           one of the input modules is renamed. The arguments
 ///                           passed to the listener are: the input

diff  --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 7687ab27e753..2df9e56a6940 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -12,10 +12,12 @@
 
 #include "mlir/Dialect/SPIRV/ModuleCombiner.h"
 
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/SymbolTable.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/StringExtras.h"
 
 using namespace mlir;
@@ -59,6 +61,59 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
   return success();
 }
 
+/// Looks up `key` in `deduplicationMap`, recording `symbolOp` under it if the
+/// key is not present yet.
+///
+/// \return a null op when `symbolOp` was recorded (i.e. `key` was seen for the
+/// first time); otherwise the op already recorded under `key`, which the
+/// caller uses as the replacement for `symbolOp`.
+template <typename KeyTy, typename SymbolOpTy>
+static SymbolOpTy
+emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
+                              DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
+  auto insertion = deduplicationMap.try_emplace(key, symbolOp);
+  bool isFirstOccurrence = insertion.second;
+  return isFirstOccurrence ? SymbolOpTy() : insertion.first->second;
+}
+
+/// Computes a hash code to represent the argument SymbolOpInterface based on
+/// all the Op's attributes except for the symbol name.
+///
+/// \return the hash code computed from the Op's attributes as described
+/// above, or 0 if the op is not suitable for deduplication. The latter is the
+/// case when any block in the op's regions (e.g. a function body) contains an
+/// op that itself has regions: such ops are not handled by the block hashing,
+/// so the caller must skip the symbol rather than risk merging functions that
+/// differ only inside those nested regions.
+///
+/// Note: We use the operation's name (not the symbol name) as part of the hash
+/// computation. This prevents, for example, mistakenly considering a global
+/// variable and a spec constant as duplicates because their descriptor set +
+/// binding and spec_id, respectively, happen to hash to the same value.
+static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
+  Operation *op = symbolOp.getOperation();
+
+  // Honor the "0 means skip" contract checked by the caller right after this
+  // call. Without this check, the 0 that computeHash(Block *) returns for such
+  // bodies would only be folded into an already non-zero hash afterwards.
+  for (Region &region : op->getRegions())
+    for (Block &block : region)
+      for (Operation &nested : block)
+        if (nested.getNumRegions() > 0)
+          return llvm::hash_code(0);
+
+  llvm::hash_code hashCode = llvm::hash_combine(op->getName());
+  for (auto attr : op->getAttrs()) {
+    // The symbol name is expected to differ between duplicates.
+    if (attr.first == SymbolTable::getSymbolAttrName())
+      continue;
+    hashCode = llvm::hash_combine(hashCode, attr);
+  }
+
+  return hashCode;
+}
+
+/// Computes a hash code from the argument Block.
+///
+/// Every op contributes its OperationEquivalence hash (computed with
+/// IgnoreOperands) plus a positional description of each of its operands:
+/// the argument number for block arguments, and the defining op's index in
+/// this block plus the result number for op results. Operand Values
+/// themselves cannot be hashed, since they differ between otherwise identical
+/// functions. Ignoring operands entirely, however, would make blocks that
+/// apply the same ops to differently wired values (e.g. `a - b` vs. `b - a`)
+/// hash identically, and the caller treats equal hashes as duplicates.
+///
+/// \return the hash code, or 0 if the block contains an op with regions,
+/// which is not supported yet.
+static llvm::hash_code computeHash(Block *block) {
+  // TODO: Consider extracting BlockEquivalenceData into a common header and
+  // re-using it here.
+  llvm::hash_code hash(0);
+
+  // Index of every op already visited in this block.
+  DenseMap<Operation *, unsigned> opIndices;
+  unsigned nextOpIndex = 0;
+
+  for (Operation &op : *block) {
+    // TODO: Properly handle operations with regions.
+    if (op.getNumRegions() > 0)
+      return 0;
+
+    hash = llvm::hash_combine(
+        hash, OperationEquivalence::computeHash(
+                  &op, OperationEquivalence::Flags::IgnoreOperands));
+
+    for (Value operand : op.getOperands()) {
+      if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
+        hash = llvm::hash_combine(hash, 0u, blockArg.getArgNumber());
+        continue;
+      }
+      // Results of ops defined outside this block share a sentinel index:
+      // coarser than exact, but identical blocks still hash identically.
+      auto defIt = opIndices.find(operand.getDefiningOp());
+      unsigned defIndex = defIt == opIndices.end() ? ~0u : defIt->second;
+      hash = llvm::hash_combine(hash, 1u, defIndex,
+                                operand.cast<OpResult>().getResultNumber());
+    }
+
+    opIndices[&op] = nextOpIndex++;
+  }
+
+  return hash;
+}
+
 namespace mlir {
 namespace spirv {
 
@@ -174,6 +229,48 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
       combinedModuleBuilder.insert(op.clone());
   }
 
+  // Deduplicate identical global variables, spec constants, and functions.
+  DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
+  SmallVector<SymbolOpInterface, 0> eraseList;
+
+  for (auto &op : combinedModule.getBlock().without_terminator()) {
+    llvm::hash_code hashCode(0);
+    SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
+
+    if (!symbolOp)
+      continue;
+
+    hashCode = computeHash(symbolOp);
+
+    // A 0 hash code means the op is not suitable for deduplication and should
+    // be skipped. An example of this is when a function has ops with regions
+    // which are not properly supported yet.
+    if (!hashCode)
+      continue;
+
+    if (auto funcOp = dyn_cast<FuncOp>(op))
+      for (auto &blk : funcOp)
+        hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
+
+    SymbolOpInterface replacementSymOp =
+        emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
+
+    if (!replacementSymOp)
+      continue;
+
+    if (failed(SymbolTable::replaceAllSymbolUses(
+            symbolOp, replacementSymOp.getName(), combinedModule))) {
+      symbolOp.emitError("unable to update all symbol uses for ")
+          << symbolOp.getName() << " to " << replacementSymOp.getName();
+      return nullptr;
+    }
+
+    eraseList.push_back(symbolOp);
+  }
+
+  for (auto symbolOp : eraseList)
+    symbolOp.erase();
+
   return combinedModule;
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
index 4d9727cc11b3..c41e7b9486e1 100644
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
@@ -39,10 +39,12 @@ spv.module Logical GLSL450 {
 // CHECK-NEXT:     }
 
 // CHECK-NEXT:     spv.func @foo_1
+// CHECK-NEXT:       spv.FAdd
 // CHECK-NEXT:       spv.ReturnValue
 // CHECK-NEXT:     }
 
 // CHECK-NEXT:     spv.func @foo_2
+// CHECK-NEXT:       spv.ISub
 // CHECK-NEXT:       spv.ReturnValue
 // CHECK-NEXT:     }
 // CHECK-NEXT:   }
@@ -57,13 +59,15 @@ spv.module Logical GLSL450 {
 
 spv.module Logical GLSL450 {
   spv.func @foo(%arg0 : f32) -> f32 "None" {
-    spv.ReturnValue %arg0 : f32
+    %0 = spv.FAdd %arg0, %arg0 : f32
+    spv.ReturnValue %0 : f32
   }
 }
 
 spv.module Logical GLSL450 {
   spv.func @foo(%arg0 : i32) -> i32 "None" {
-    spv.ReturnValue %arg0 : i32
+    %0 = spv.ISub %arg0, %arg0 : i32
+    spv.ReturnValue %0 : i32
   }
 }
 }
@@ -578,9 +582,9 @@ spv.module Logical GLSL450 {
 
 // CHECK:      module {
 // CHECK-NEXT:   spv.module Logical GLSL450 {
-// CHECK-NEXT:     spv.globalVariable @foo_1
+// CHECK-NEXT:     spv.globalVariable @foo_1 bind(1, 0)
 
-// CHECK-NEXT:     spv.globalVariable @foo
+// CHECK-NEXT:     spv.globalVariable @foo bind(2, 0)
 // CHECK-NEXT: }
 
 module {
@@ -589,7 +593,26 @@ spv.module Logical GLSL450 {
 }
 
 spv.module Logical GLSL450 {
-  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+  spv.globalVariable @foo bind(2, 0) : !spv.ptr<f32, Input>
+}
+}
+
+// -----
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.globalVariable @foo_1 built_in("GlobalInvocationId")
+
+// CHECK-NEXT:     spv.globalVariable @foo built_in("LocalInvocationId")
+// CHECK-NEXT: }
+
+// Same symbol name but different built_in values: the variables are not
+// duplicates, so the name conflict is resolved by renaming one of them.
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
 }
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
new file mode 100644
index 000000000000..e919bbaf17b7
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
@@ -0,0 +1,244 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Deduplicate 2 global variables with the same descriptor set and binding.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.globalVariable @foo
+
+// CHECK-NEXT:     spv.func @use_foo
+// CHECK-NEXT:       spv.mlir.addressof @foo
+// CHECK-NEXT:       spv.Load
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @use_bar
+// CHECK-NEXT:       spv.mlir.addressof @foo
+// CHECK-NEXT:       spv.Load
+// CHECK-NEXT:       spv.FAdd
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+
+  spv.func @use_foo() -> f32 "None" {
+    %0 = spv.mlir.addressof @foo : !spv.ptr<f32, Input>
+    %1 = spv.Load "Input" %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+
+spv.module Logical GLSL450 {
+  // Same bind(1, 0) as @foo: expected to be erased, with the use in @use_bar
+  // rewritten to reference @foo.
+  spv.globalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
+
+  spv.func @use_bar() -> f32 "None" {
+    %0 = spv.mlir.addressof @bar : !spv.ptr<f32, Input>
+    %1 = spv.Load "Input" %0 : f32
+    %2 = spv.FAdd %1, %1 : f32
+    spv.ReturnValue %2 : f32
+  }
+}
+}
+
+// -----
+
+// 2 global variables with the same descriptor set and binding but different types are not deduplicated.
+
+// CHECK:      module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT:   spv.globalVariable @foo bind(1, 0)
+
+// CHECK-NEXT:   spv.globalVariable @bar bind(1, 0)
+
+// CHECK-NEXT:   spv.func @use_bar
+// CHECK-NEXT:     spv.mlir.addressof @bar
+// CHECK-NEXT:     spv.Load
+// CHECK-NEXT:     spv.ReturnValue
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<i32, Input>
+}
+
+spv.module Logical GLSL450 {
+  // Same bind(1, 0) as @foo but a different type (f32 vs. i32 pointee), so
+  // @bar is not a duplicate and is kept.
+  spv.globalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
+
+  spv.func @use_bar() -> f32 "None" {
+    %0 = spv.mlir.addressof @bar : !spv.ptr<f32, Input>
+    %1 = spv.Load "Input" %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+}
+
+// -----
+
+// Deduplicate 2 global variables with the same built-in attribute.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.globalVariable @foo built_in("GlobalInvocationId")
+// CHECK-NEXT:     spv.func @use_bar
+// CHECK-NEXT:       spv.mlir.addressof @foo
+// CHECK-NEXT:       spv.Load
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+}
+
+spv.module Logical GLSL450 {
+  // Same built_in as @foo: expected to be erased, with the use in @use_bar
+  // rewritten to reference @foo.
+  spv.globalVariable @bar built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+
+  spv.func @use_bar() -> vector<3xi32> "None" {
+    %0 = spv.mlir.addressof @bar : !spv.ptr<vector<3xi32>, Input>
+    %1 = spv.Load "Input" %0 : vector<3xi32>
+    spv.ReturnValue %1 : vector<3xi32>
+  }
+}
+}
+
+// -----
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @foo spec_id(5)
+
+// CHECK-NEXT:     spv.func @use_foo()
+// CHECK-NEXT:       %0 = spv.mlir.referenceof @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @use_bar()
+// CHECK-NEXT:       %0 = spv.mlir.referenceof @foo
+// CHECK-NEXT:       spv.FAdd
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @foo spec_id(5) = 1. : f32
+
+  spv.func @use_foo() -> (f32) "None" {
+    %0 = spv.mlir.referenceof @foo : f32
+    spv.ReturnValue %0 : f32
+  }
+}
+
+spv.module Logical GLSL450 {
+  // Same spec_id(5) as @foo: expected to be erased, with the use in @use_bar
+  // rewritten to reference @foo.
+  spv.specConstant @bar spec_id(5) = 1. : f32
+
+  spv.func @use_bar() -> (f32) "None" {
+    %0 = spv.mlir.referenceof @bar : f32
+    %1 = spv.FAdd %0, %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+}
+
+// -----
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @bar spec_id(5)
+
+// CHECK-NEXT:     spv.func @foo(%arg0: f32)
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo_different_body(%arg0: f32)
+// CHECK-NEXT:       spv.mlir.referenceof
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @baz(%arg0: i32)
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @baz_no_return(%arg0: i32)
+// CHECK-NEXT:       spv.Return
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @baz_no_return_different_control
+// CHECK-NEXT:       spv.Return
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @baz_no_return_another_control
+// CHECK-NEXT:       spv.Return
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @kernel
+// CHECK-NEXT:       spv.Return
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @kernel_different_attr
+// CHECK-NEXT:       spv.Return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT:   }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @bar spec_id(5) = 1. : f32
+
+  // Expected result: @foo_duplicate and @baz_no_return_duplicate are identical
+  // to the function preceding each of them and get removed. All other
+  // functions differ from one another in body, signature, function control,
+  // or attributes, and are kept.
+  spv.func @foo(%arg0: f32) -> (f32) "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.func @foo_duplicate(%arg0: f32) -> (f32) "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.func @foo_different_body(%arg0: f32) -> (f32) "None" {
+    %0 = spv.mlir.referenceof @bar : f32
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.func @baz(%arg0: i32) -> (i32) "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+
+  spv.func @baz_no_return(%arg0: i32) "None" {
+    spv.Return
+  }
+
+  spv.func @baz_no_return_duplicate(%arg0: i32) -> () "None" {
+    spv.Return
+  }
+
+  spv.func @baz_no_return_different_control(%arg0: i32) -> () "Inline" {
+    spv.Return
+  }
+
+  spv.func @baz_no_return_another_control(%arg0: i32) -> () "Inline|Pure" {
+    spv.Return
+  }
+
+  spv.func @kernel(
+    %arg0: f32,
+    %arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
+  attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+    spv.Return
+  }
+
+  spv.func @kernel_different_attr(
+    %arg0: f32,
+    %arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
+  attributes {spv.entry_point_abi = {local_size = dense<[64, 1, 1]> : vector<3xi32>}} {
+    spv.Return
+  }
+}
+}


        


More information about the Mlir-commits mailing list