[Mlir-commits] [mlir] b12bcf3 - [MLIR] Add pass to deduplicate functions

Frederik Gossen llvmlistbot at llvm.org
Mon Feb 27 08:00:13 PST 2023


Author: Frederik Gossen
Date: 2023-02-27T10:59:53-05:00
New Revision: b12bcf3fb7fa7f968e8bd77b466057509ab2e04b

URL: https://github.com/llvm/llvm-project/commit/b12bcf3fb7fa7f968e8bd77b466057509ab2e04b
DIFF: https://github.com/llvm/llvm-project/commit/b12bcf3fb7fa7f968e8bd77b466057509ab2e04b.diff

LOG: [MLIR] Add pass to deduplicate functions

Deduplicate functions that are equivalent in all aspects but their symbol name.
The pass chooses one representative per equivalence class, erases the remainder, and updates function calls accordingly.

Differential Revision: https://reviews.llvm.org/D144738

Added: 
    mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
    mlir/test/Dialect/Func/duplicate-function-elimination.mlir

Modified: 
    mlir/include/mlir/Dialect/Func/Transforms/Passes.h
    mlir/include/mlir/Dialect/Func/Transforms/Passes.td
    mlir/lib/Dialect/Func/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
index 10968d4ad9c33..011ad3e3d0be4 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
@@ -32,6 +32,9 @@ namespace func {
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
+/// Creates a pass that deduplicates functions equivalent up to symbol name.
+std::unique_ptr<Pass> createDuplicateFunctionEliminationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
index 54fe4fdd6bbe6..8f6dbcb1ee653 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
@@ -40,4 +40,15 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
                            "memref::MemRefDialect"];
 }
 
+// Module-level pass; instances are built by the C++ factory named in
+// `constructor` below.
+def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
+    "ModuleOp"> {
+  let summary = "Deduplicate functions";
+  let description = [{
+    Deduplicate functions that are equivalent in all aspects but their symbol
+    name. The pass chooses one representative per equivalence class, erases
+    the remainder, and updates function calls accordingly.
+  }];
+  let constructor = "mlir::func::createDuplicateFunctionEliminationPass()";
+}
+
 #endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD

diff  --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 5de2fb8bbe9f2..9a5b38ba6ea2c 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRFuncTransforms
   DecomposeCallGraphTypes.cpp
+  DuplicateFunctionElimination.cpp
   FuncBufferize.cpp
   FuncConversions.cpp
 

diff  --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
new file mode 100644
index 0000000000000..b83d67e2ef14a
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -0,0 +1,124 @@
+//===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace {
+
+#define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
+#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
+
+// Define a notion of function equivalence that allows for reuse. Ignore the
+// symbol name for this purpose.
+//
+// Two distinct functions are equivalent if they carry the same attributes
+// (other than the symbol name) and structurally equivalent bodies, ignoring
+// locations. Declarations (functions without a body) are only ever equivalent
+// to themselves: each one stands for a different external definition, so
+// merging two of them would redirect calls to the wrong implementation.
+struct DuplicateFuncOpEquivalenceInfo
+    : public llvm::DenseMapInfo<func::FuncOp> {
+
+  /// Hashes the attributes (minus the symbol name) and the body structure.
+  /// Operands and results are not hashed because value identity differs
+  /// between functions; `isEqual` compares them structurally instead.
+  static unsigned getHashValue(const func::FuncOp cFunc) {
+    // Null handles fall back to the default pointer hash.
+    if (!cFunc) {
+      return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
+    }
+
+    // FuncOp is a value handle; a plain copy gives access to the non-const
+    // accessors without casting away constness.
+    func::FuncOp func = cFunc;
+
+    // Aggregate attributes, ignoring the symbol name.
+    llvm::hash_code hash = {};
+    StringAttr symNameAttrName = func.getSymNameAttrName();
+    for (NamedAttribute namedAttr : func->getAttrs()) {
+      if (namedAttr.getName() == symNameAttrName)
+        continue;
+      hash = llvm::hash_combine(hash, namedAttr);
+    }
+
+    // Also hash the func body.
+    func.getBody().walk([&](Operation *op) {
+      hash = llvm::hash_combine(
+          hash, OperationEquivalence::computeHash(
+                    op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
+                    /*hashResults=*/OperationEquivalence::ignoreHashValue,
+                    OperationEquivalence::IgnoreLocations));
+    });
+
+    return hash;
+  }
+
+  /// Returns true if both handles refer to the same function, or to two
+  /// function definitions that differ at most in symbol name and locations.
+  static bool isEqual(const func::FuncOp cLhs, const func::FuncOp cRhs) {
+    if (cLhs == cRhs) {
+      return true;
+    }
+    if (cLhs == getTombstoneKey() || cLhs == getEmptyKey() ||
+        cRhs == getTombstoneKey() || cRhs == getEmptyKey()) {
+      return false;
+    }
+
+    func::FuncOp lhs = cLhs;
+    func::FuncOp rhs = cRhs;
+
+    // Declarations name distinct external symbols; an empty body says nothing
+    // about what the function does, so they must never be merged.
+    if (lhs.isExternal() || rhs.isExternal()) {
+      return false;
+    }
+
+    // Check attributes equivalence, ignoring the symbol name. Equal dictionary
+    // sizes plus every non-name attribute of `lhs` matching in `rhs` means the
+    // remaining attribute of `rhs` can only be its symbol name.
+    if (lhs->getAttrDictionary().size() != rhs->getAttrDictionary().size()) {
+      return false;
+    }
+    StringAttr symNameAttrName = lhs.getSymNameAttrName();
+    for (NamedAttribute namedAttr : lhs->getAttrs()) {
+      StringAttr attrName = namedAttr.getName();
+      if (attrName == symNameAttrName) {
+        continue;
+      }
+      if (namedAttr.getValue() != rhs->getAttr(attrName)) {
+        return false;
+      }
+    }
+
+    // Compare inner workings.
+    return OperationEquivalence::isRegionEquivalentTo(
+        &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
+  }
+};
+
+/// Deduplicates func.func ops that are equivalent according to
+/// DuplicateFuncOpEquivalenceInfo. Within each symbol table, the first function
+/// of every equivalence class is kept as the representant; all symbol uses of
+/// the other members are redirected to it and the duplicates are erased.
+struct DuplicateFunctionEliminationPass
+    : public impl::DuplicateFunctionEliminationPassBase<
+          DuplicateFunctionEliminationPass> {
+
+  using DuplicateFunctionEliminationPassBase<
+      DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Find a unique representant per class of equivalent func ops. Functions
+    // are only compared with functions of the same symbol table: the
+    // replacement is a symbol name, which has to resolve in the table that
+    // held the erased duplicate. Pairs are (duplicate, representant), kept in
+    // a vector so that rewriting and erasure happen in a deterministic order.
+    SmallVector<std::pair<func::FuncOp, func::FuncOp>> duplicates;
+    module->walk([&](Operation *op) {
+      if (!op->hasTrait<OpTrait::SymbolTable>())
+        return;
+      DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
+      for (func::FuncOp f : op->getRegion(0).getOps<func::FuncOp>()) {
+        auto [repr, inserted] = uniqueFuncOps.insert(f);
+        if (!inserted)
+          duplicates.emplace_back(f, *repr);
+      }
+    });
+    if (duplicates.empty())
+      return;
+
+    // Redirect every symbol use of a duplicate to its representant: not only
+    // calls, but any op or attribute referencing the symbol. All uses are
+    // rewritten before anything is erased because the recorded users may live
+    // inside the bodies of duplicates that are about to be erased.
+    SymbolTableCollection symbolTables;
+    SymbolUserMap userMap(symbolTables, module);
+    for (auto &[duplicate, representant] : duplicates)
+      userMap.replaceAllUsesWith(duplicate, representant.getSymNameAttr());
+
+    // Erase redundant func ops.
+    for (auto &entry : duplicates)
+      entry.first.erase();
+  }
+};
+
+} // namespace
+
+/// Returns a freshly constructed duplicate-function-elimination pass.
+std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() {
+  std::unique_ptr<Pass> pass =
+      std::make_unique<DuplicateFunctionEliminationPass>();
+  return pass;
+}
+
+} // namespace mlir

diff  --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
new file mode 100644
index 0000000000000..acf2bfb97cdb9
--- /dev/null
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -0,0 +1,367 @@
+// RUN: mlir-opt %s --split-input-file --duplicate-function-elimination | \
+// RUN: FileCheck %s
+
+// Returns its argument unchanged.
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+// Same signature and body as @identity; only the symbol name differs.
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+// Same signature and body as @identity; only the symbol name differs.
+func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+// Calls each identity variant once, feeding every result into the next call.
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = call @identity(%arg0) : (tensor<f32>) -> tensor<f32>
+  %1 = call @also_identity(%0) : (tensor<f32>) -> tensor<f32>
+  %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// Only @identity survives, and all three calls are redirected to it.
+// CHECK:          @identity
+// CHECK-NOT:      @also_identity
+// CHECK-NOT:      @yet_another_identity
+// CHECK:          @user
+// CHECK-COUNT-3:  call @identity
+
+// -----
+
+// Adds its arguments, with %arg0 as the first addf operand.
+func.func @add_lr(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = arith.addf %arg0, %arg1 : f32
+  return %0 : f32
+}
+
+// Same signature and body as @add_lr; only the symbol name differs.
+func.func @also_add_lr(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = arith.addf %arg0, %arg1 : f32
+  return %0 : f32
+}
+
+// Like @add_lr but with the addf operands swapped. The expectations at the end
+// of this section require it to be merged into @add_lr as well.
+func.func @add_rl(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = arith.addf %arg1, %arg0 : f32
+  return %0 : f32
+}
+
+// Same signature and body as @add_rl; only the symbol name differs.
+func.func @also_add_rl(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = arith.addf %arg1, %arg0 : f32
+  return %0 : f32
+}
+
+// Calls all four adders; every call is expected to end up targeting @add_lr.
+func.func @user(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = call @add_lr(%arg0, %arg1) : (f32, f32) -> f32
+  %1 = call @also_add_lr(%arg0, %arg1) : (f32, f32) -> f32
+  %2 = call @add_rl(%0, %1) : (f32, f32) -> f32
+  %3 = call @also_add_rl(%arg0, %2) : (f32, f32) -> f32
+  return %3 : f32
+}
+
+// CHECK:          @add_lr
+// CHECK-NOT:      @also_add_lr
+// CHECK-NOT:      @add_rl
+// CHECK-NOT:      @also_add_rl
+// CHECK:          @user
+// CHECK-COUNT-4:  call @add_lr
+
+// -----
+
+// Returns %then when %pred is true and %else otherwise.
+func.func @ite(%pred: i1, %then: f32, %else: f32) -> f32 {
+  %0 = scf.if %pred -> f32 {
+    scf.yield %then : f32
+  } else {
+    scf.yield %else : f32
+  }
+  return %0 : f32
+}
+
+// Same signature and body as @ite; only the symbol name differs.
+func.func @also_ite(%pred: i1, %then: f32, %else: f32) -> f32 {
+  %0 = scf.if %pred -> f32 {
+    scf.yield %then : f32
+  } else {
+    scf.yield %else : f32
+  }
+  return %0 : f32
+}
+
+// Like @ite with the yielded values swapped between the branches, so it is not
+// equivalent to @ite and must be kept.
+func.func @reverse_ite(%pred: i1, %then: f32, %else: f32) -> f32 {
+  %0 = scf.if %pred -> f32 {
+    scf.yield %else : f32
+  } else {
+    scf.yield %then : f32
+  }
+  return %0 : f32
+}
+
+// Calls @also_ite, @ite and @reverse_ite; only the first two collapse.
+func.func @user(%pred : i1, %arg0: f32, %arg1: f32) -> f32 {
+  %0 = call @also_ite(%pred, %arg0, %arg1) : (i1, f32, f32) -> f32
+  %1 = call @ite(%pred, %arg0, %arg1) : (i1, f32, f32) -> f32
+  %2 = call @reverse_ite(%pred, %0, %1) : (i1, f32, f32) -> f32
+  return %2 : f32
+}
+
+// CHECK:          @ite
+// CHECK-NOT:      @also_ite
+// CHECK:          @reverse_ite
+// CHECK:          @user
+// CHECK-COUNT-2:  call @ite
+// CHECK:          call @reverse_ite
+
+// -----
+
+// Four nested scf.if levels over %p0..%p3. Yields %even when an even number of
+// the predicates are false and %odd otherwise.
+func.func @deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, %odd: f32) 
+    -> f32 {
+  %0 = scf.if %p0 -> f32 {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  } else {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  }
+  return %0 : f32
+}
+
+// Same signature and body as @deep_tree; only the symbol name (and the line
+// wrapping of the signature) differs.
+func.func @also_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, 
+    %odd: f32) -> f32 {
+  %0 = scf.if %p0 -> f32 {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  } else {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  }
+  return %0 : f32
+}
+
+// Same nesting as @deep_tree with %even and %odd swapped in every leaf, so it
+// is not equivalent to @deep_tree and must be kept.
+func.func @reverse_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, 
+    %odd: f32) -> f32 {
+  %0 = scf.if %p0 -> f32 {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  } else {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  }
+  return %0 : f32
+}
+
+// Calls all three trees; the first two calls are expected to end up on
+// @deep_tree while @reverse_deep_tree stays separate.
+func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
+    -> (f32, f32, f32) {
+  %0 = call @deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
+      : (i1, i1, i1, i1, f32, f32) -> f32
+  %1 = call @also_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
+      : (i1, i1, i1, i1, f32, f32) -> f32
+  %2 = call @reverse_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
+      : (i1, i1, i1, i1, f32, f32) -> f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// CHECK:          @deep_tree
+// CHECK-NOT:      @also_deep_tree
+// CHECK:          @reverse_deep_tree
+// CHECK:          @user
+// CHECK-COUNT-2:  call @deep_tree
+// CHECK:          call @reverse_deep_tree


        


More information about the Mlir-commits mailing list