[llvm] [mlir] [mlir] Add Normalize pass (PR #162266)
Shourya Goel via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 28 11:08:15 PDT 2025
================
@@ -0,0 +1,461 @@
+//===- Normalize.cpp - Conversion from MLIR to its canonical form ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/Normalize/Normalize.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <iomanip>
+#include <sstream>
+
+namespace mlir {
+#define GEN_PASS_DEF_NORMALIZE
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "normalize"
+
+namespace {
+/// NormalizePass aims to transform MLIR into its normal form. For every
+/// top-level operation of the module it collects the "output" operations,
+/// reorders the operations feeding them and then attaches a deterministic
+/// name (as a NameLoc) to every operation reachable from those outputs.
+struct NormalizePass : public impl::NormalizeBase<NormalizePass> {
+  NormalizePass() = default;
+
+  void runOnOperation() override;
+
+private:
+  //===--- Output discovery ---------------------------------------------===//
+
+  /// Returns true when `op` is treated as an output operation.
+  bool isOutput(Operation &op) const noexcept;
+  /// Appends every output operation found in `block` to `outputs`.
+  void
+  collectOutputOperations(Block &block,
+                          SmallVector<Operation *, 16> &outputs) const noexcept;
+  /// Computes the output footprint of `op`; `visited` guards the traversal.
+  llvm::SetVector<int>
+  getOutputFootprint(Operation *op,
+                     llvm::SmallPtrSet<const Operation *, 32> &visited) const;
+
+  //===--- Reordering ---------------------------------------------------===//
+
+  /// Reorders the producers of every operation in `outputs`.
+  void reorderOperations(const SmallVector<Operation *, 16> &outputs);
+  /// Moves `used` relative to `user`, then recurses into its producers.
+  void reorderOperation(Operation *used, Operation *user,
+                        llvm::SmallPtrSet<const Operation *, 32> &visited);
+  /// Sorts the operands of a commutative operation by their printed names.
+  void reorderOperationOperandsByName(Operation *op);
+
+  //===--- Renaming -----------------------------------------------------===//
+
+  /// Renames every operation in `outputs` together with its producers.
+  void renameOperations(const SmallVector<Operation *, 16> &outputs);
+  /// Renames `op` once; `visited` records the operations already handled.
+  void renameOperation(Operation *op,
+                       SmallPtrSet<const Operation *, 32> &visited);
+  /// True when `op` has users and only immediate operands.
+  bool isInitialOperation(Operation *const op) const noexcept;
+  /// True when no operand of `op` is produced by a non-ConstantLike op.
+  bool hasOnlyImmediateOperands(Operation *const op) const noexcept;
+  /// Applies the "vl" naming scheme to `op`.
+  void
+  nameAsInitialOperation(Operation *op,
+                         llvm::SmallPtrSet<const Operation *, 32> &visited);
+  /// Applies the "op" naming scheme to `op`.
+  void
+  nameAsRegularOperation(Operation *op,
+                         llvm::SmallPtrSet<const Operation *, 32> &visited);
+  /// Appends the shortened operand names to `name` and installs it on `op`.
+  void appendRenamedOperands(Operation *op, SmallString<512> &name);
+
+  //===--- State --------------------------------------------------------===//
+
+  // Random constant for hashing, so the state isn't zero.
+  const uint64_t magicHashConstant = 0x6acaa36bef8325c5ULL;
+  // Printing flags shared by every helper that prints an operation or value.
+  OpPrintingFlags flags{};
+};
+} // namespace
+
+/// Entry method to the NormalizePass.
+///
+/// For each operation directly nested in the module, gathers the output
+/// operations of all blocks in all of its regions into a single list, then
+/// reorders and renames the operations reachable from those outputs.
+void NormalizePass::runOnOperation() {
+  // The renaming helpers store names in NameLocs and later read them back
+  // from the printed form, so names must be printed as SSA value prefixes.
+  flags.printNameLocAsPrefix(true);
+
+  ModuleOp module = getOperation();
+
+  for (Operation &op : module.getOps()) {
+    SmallVector<Operation *, 16> outputs;
+
+    for (Region &region : op.getRegions())
+      for (Block &block : region)
+        collectOutputOperations(block, outputs);
+
+    reorderOperations(outputs);
+    renameOperations(outputs);
+  }
+}
+
+/// Renames every output operation (and, through renameOperation, everything
+/// feeding it). One visited set is shared across all outputs so that each
+/// operation is renamed at most once.
+void NormalizePass::renameOperations(
+    const SmallVector<Operation *, 16> &outputs) {
+  llvm::SmallPtrSet<const Operation *, 32> alreadyRenamed;
+
+  for (size_t idx = 0, count = outputs.size(); idx != count; ++idx)
+    renameOperation(outputs[idx], alreadyRenamed);
+}
+
+/// Renames `op` and, recursively through the naming helpers, the operations
+/// defining its operands. Traversal follows the def-use graph from the output
+/// (top-most user) operations down to the initial operations (defs).
+/// Operations already present in `visited` are left untouched. After naming,
+/// the operands of commutative operations are reordered by name.
+void NormalizePass::renameOperation(
+    Operation *op, SmallPtrSet<const Operation *, 32> &visited) {
+  // insert() reports whether the element was newly added.
+  const bool firstVisit = visited.insert(op).second;
+  if (!firstVisit)
+    return;
+
+  if (isInitialOperation(op))
+    nameAsInitialOperation(op, visited);
+  else
+    nameAsRegularOperation(op, visited);
+
+  if (op->hasTrait<OpTrait::IsCommutative>())
+    reorderOperationOperandsByName(op);
+}
+
+/// Returns true when `op` has at least one use and all of its operands are
+/// immediate (see hasOnlyImmediateOperands).
+bool NormalizePass::isInitialOperation(Operation *const op) const noexcept {
+  if (op->use_empty())
+    return false;
+  return hasOnlyImmediateOperands(op);
+}
+
+/// Returns true unless some operand of `op` is defined by an operation that
+/// lacks the ConstantLike trait. Operands without a defining operation (for
+/// example block arguments) do not disqualify `op`.
+bool NormalizePass::hasOnlyImmediateOperands(
+    Operation *const op) const noexcept {
+  for (Value operand : op->getOperands()) {
+    Operation *producer = operand.getDefiningOp();
+    if (producer && !producer->hasTrait<OpTrait::ConstantLike>())
+      return false;
+  }
+  return true;
+}
+
+/// Formats the low 20 bits of `hash` as exactly five lowercase hexadecimal
+/// digits, zero-padded on the left.
+std::string inline toString(uint64_t const hash) noexcept {
+  constexpr char hexDigits[] = "0123456789abcdef";
+  std::string text(5, '0');
+  uint64_t remaining = hash;
+  // Fill from the least significant digit (rightmost character) upwards.
+  for (auto pos = text.rbegin(); pos != text.rend(); ++pos) {
+    *pos = hexDigits[remaining & 0xf];
+    remaining >>= 4;
+  }
+  return text;
+}
+
+/// Computes the 64-bit FNV-1a hash of `data`.
+///
+/// Each character is hashed as an unsigned byte. Converting a plain `char`
+/// straight to uint64_t would sign-extend bytes >= 0x80 on targets where
+/// `char` is signed, making the result platform dependent.
+uint64_t inline strHash(std::string_view data) noexcept {
+  constexpr uint64_t fnvOffsetBasis = 0xcbf29ce484222325ULL;
+  constexpr uint64_t fnvPrime = 0x100000001b3ULL;
+  uint64_t hash = fnvOffsetBasis;
+  for (const char c : data) {
+    hash ^= static_cast<unsigned char>(c);
+    hash *= fnvPrime;
+  }
+  return hash;
+}
+
+/// Splits `str` on `delimiter` and returns the piece at position `indx`
+/// (zero-based) with every ':' in that piece replaced by '_'.
+///
+/// Returns an empty string when `str` has no piece at `indx`. The previous
+/// code returned `nullptr` here, which constructs a std::string from a null
+/// pointer (undefined behaviour) inside a noexcept function.
+std::string inline split(std::string_view str, const char &delimiter,
+                         int indx = 0) noexcept {
+  std::stringstream ss{std::string{str}};
+  std::string item;
+  for (int cnt = 0; std::getline(ss, item, delimiter); ++cnt) {
+    if (cnt != indx)
+      continue;
+    std::replace(item.begin(), item.end(), ':', '_');
+    return item;
+  }
+  return std::string();
+}
+
+/// Names an initial operation following the scheme:
+/// vl00000Callee$Operands$
+///
+/// 00000 stands for the first five decimal digits of a hash that combines the
+/// operation's opcode with its output footprint. The callee name is inserted
+/// only for func::CallOp. With operands present, the rest of the name is
+/// produced by appendRenamedOperands. Without operands, a call gets "void"
+/// and any other operation gets a hash of the text that follows '=' in its
+/// printed form. The final name is attached to the operation as a NameLoc.
+void NormalizePass::nameAsInitialOperation(
+    Operation *op, llvm::SmallPtrSet<const Operation *, 32> &visited) {
+
+  // Producers are named first.
+  for (Value operand : op->getOperands()) {
+    Operation *producer = operand.getDefiningOp();
+    if (producer)
+      renameOperation(producer, visited);
+  }
+
+  // Seed with the magic constant, then mix in the opcode.
+  const uint64_t opcodeHash = strHash(op->getName().getStringRef().str());
+  uint64_t hash =
+      llvm::hashing::detail::hash_16_bytes(magicHashConstant, opcodeHash);
+
+  // Mix in the output footprint, using a visited set local to this query.
+  SmallPtrSet<const Operation *, 32> footprintVisited;
+  SetVector<int> footprint = getOutputFootprint(op, footprintVisited);
+  for (int outputIndex : footprint)
+    hash = llvm::hashing::detail::hash_16_bytes(hash, outputIndex);
+
+  SmallString<512> name;
+  name.append("vl" + std::to_string(hash).substr(0, 5));
+
+  auto call = dyn_cast<func::CallOp>(op);
+  if (call)
+    name.append(call.getCallee().str());
+
+  if (op->getNumOperands() != 0) {
+    appendRenamedOperands(op, name);
+    return;
+  }
+
+  // No operands: build the "$...$" suffix here.
+  name.append("$");
+  if (call) {
+    name.append("void");
+  } else {
+    std::string printed;
+    AsmState state(op, flags);
+    llvm::raw_string_ostream os(printed);
+    op->print(os, state);
+    name.append(toString(strHash(split(os.str(), '=', 1))));
+  }
+  name.append("$");
+
+  OpBuilder builder(op->getContext());
+  StringAttr nameAttr = builder.getStringAttr(name);
+  op->setLoc(NameLoc::get(nameAttr, op->getLoc()));
+}
+
+/// Names a regular operation following the scheme:
+/// op00000Callee$Operands$
+///
+/// 00000 stands for the first five decimal digits of a hash that combines the
+/// operation's opcode with the opcodes of the operations defining its
+/// operands (sorted first when the operation is commutative). The callee name
+/// is inserted only for func::CallOp. The operand part of the name and the
+/// final NameLoc are handled by appendRenamedOperands.
+void NormalizePass::nameAsRegularOperation(
+    Operation *op, llvm::SmallPtrSet<const Operation *, 32> &visited) {
+
+  // Name each producer and remember the hash of its opcode. Operands without
+  // a defining operation contribute nothing.
+  SmallVector<uint64_t, 4> producerOpcodeHashes;
+  for (Value operand : op->getOperands()) {
+    Operation *producer = operand.getDefiningOp();
+    if (!producer)
+      continue;
+    renameOperation(producer, visited);
+    producerOpcodeHashes.push_back(
+        strHash(producer->getName().getStringRef().str()));
+  }
+
+  // Operand order must not influence the name of a commutative operation.
+  if (op->hasTrait<OpTrait::IsCommutative>())
+    llvm::sort(producerOpcodeHashes);
+
+  const uint64_t opcodeHash = strHash(op->getName().getStringRef().str());
+  uint64_t hash =
+      llvm::hashing::detail::hash_16_bytes(magicHashConstant, opcodeHash);
+  for (const uint64_t producerHash : producerOpcodeHashes)
+    hash = llvm::hashing::detail::hash_16_bytes(hash, producerHash);
+
+  SmallString<512> name;
+  name.append("op" + std::to_string(hash).substr(0, 5));
+
+  if (auto call = dyn_cast<func::CallOp>(op))
+    name.append(call.getCallee().str());
+
+  appendRenamedOperands(op, name);
+}
+
+/// Returns true when `base` begins with `check`. An empty `check` matches
+/// every `base`.
+bool inline startsWith(std::string_view base, std::string_view check) noexcept {
+  // substr clamps the length, so a too-short `base` simply compares unequal.
+  return base.substr(0, check.size()) == check;
+}
+
+/// Appends "$operand-operand-...$" to `name` and attaches the result to `op`
+/// as a NameLoc. Does nothing when `op` has no operands.
+///
+/// Each operand contributes one entry:
+///  - an operand defined by an operation is represented by the text before
+///    '=' in that operation's printed form; when that text starts with "%op"
+///    or "%vl" it is shortened to the seven characters following the '%';
+///  - a block argument becomes "funcArg<N>" when it belongs to the first
+///    block of a func::FuncOp, and "blockArg<N>" otherwise.
+/// The entries are sorted when `op` is commutative.
+void NormalizePass::appendRenamedOperands(Operation *op,
+                                          SmallString<512> &name) {
+  if (op->getNumOperands() == 0)
+    return;
+
+  SmallVector<std::string, 4> operandNames;
+
+  for (Value operand : op->getOperands()) {
+    if (Operation *producer = operand.getDefiningOp()) {
+      std::string printed;
+      AsmState state(producer, flags);
+      llvm::raw_string_ostream os(printed);
+      producer->print(os, state);
+      std::string resultName = split(os.str(), '=', 0);
+
+      const bool alreadyNormalized =
+          startsWith(resultName, "%op") || startsWith(resultName, "%vl");
+      operandNames.push_back(alreadyNormalized ? resultName.substr(1, 7)
+                                               : resultName);
+      continue;
+    }
+
+    auto blockArg = dyn_cast<BlockArgument>(operand);
+    if (!blockArg)
+      continue;
+
+    Block *owner = blockArg.getOwner();
+    const unsigned argIndex = blockArg.getArgNumber();
+    bool isFunctionEntryArg = false;
+    if (auto func = dyn_cast<func::FuncOp>(owner->getParentOp()))
+      isFunctionEntryArg = (&func.front() == owner);
+
+    const char *prefix = isFunctionEntryArg ? "funcArg" : "blockArg";
+    operandNames.push_back(prefix + std::to_string(argIndex));
+  }
+
+  if (op->hasTrait<OpTrait::IsCommutative>())
+    llvm::sort(operandNames);
+
+  // Join the entries with '-' between a pair of '$' markers.
+  name.append("$");
+  bool needsSeparator = false;
+  for (const std::string &operandName : operandNames) {
+    if (needsSeparator)
+      name.append("-");
+    name.append(operandName);
+    needsSeparator = true;
+  }
+  name.append("$");
+
+  OpBuilder builder(op->getContext());
+  StringAttr nameAttr = builder.getStringAttr(name);
+  op->setLoc(NameLoc::get(nameAttr, op->getLoc()));
+}
+
+/// Rewrites the operand list of `op` in case-insensitive alphabetical order
+/// of the operands' printed names. Callers are expected to pass a commutative
+/// operation; for any other operation the operands are written back in their
+/// existing order.
+void NormalizePass::reorderOperationOperandsByName(Operation *op) {
+  const unsigned numOperands = op->getNumOperands();
+  if (numOperands == 0)
+    return;
+
+  // Pair every operand with its printed name.
+  SmallVector<std::pair<std::string, Value>, 4> namedOperands;
+  for (Value operand : op->getOperands()) {
+    std::string printed;
+    llvm::raw_string_ostream os(printed);
+    operand.printAsOperand(os, flags);
+    namedOperands.push_back({os.str(), operand});
+  }
+
+  if (op->hasTrait<OpTrait::IsCommutative>()) {
+    auto byNameIgnoringCase = [](const auto &lhs, const auto &rhs) {
+      return llvm::StringRef(lhs.first).compare_insensitive(rhs.first) < 0;
+    };
+    llvm::sort(namedOperands, byNameIgnoringCase);
+  }
+
+  for (unsigned idx = 0; idx != numOperands; ++idx)
+    op->setOperand(idx, namedOperands[idx].second);
+}
+
+/// Reorders operations by walking up the def-use chain from each operand of
+/// every output operation, reducing the distance between defs and uses. One
+/// visited set is shared across all outputs, so each operation is moved at
+/// most once.
+void NormalizePass::reorderOperations(
+    const SmallVector<Operation *, 16> &outputs) {
+  llvm::SmallPtrSet<const Operation *, 32> alreadyMoved;
+  for (Operation *output : outputs) {
+    for (Value operand : output->getOperands()) {
+      Operation *producer = operand.getDefiningOp();
+      if (!producer)
+        continue;
+      reorderOperation(producer, output, alreadyMoved);
+    }
+  }
+}
+
+/// Moves `used` closer to `user` and then treats every operation defining an
+/// operand of `used` the same way. When both live in the same block, `used`
+/// is placed directly before `user`; otherwise it is placed before the last
+/// operation of its own block. Operations already in `visited` are skipped.
+void NormalizePass::reorderOperation(
+    Operation *used, Operation *user,
+    llvm::SmallPtrSet<const Operation *, 32> &visited) {
+  if (!visited.insert(used).second)
+    return;
+
+  Block *definingBlock = used->getBlock();
+  Operation *anchor =
+      definingBlock == user->getBlock() ? user : &definingBlock->back();
+  used->moveBefore(anchor);
+
+  for (Value operand : used->getOperands()) {
+    Operation *producer = operand.getDefiningOp();
+    if (producer)
+      reorderOperation(producer, used, visited);
+  }
+}
+
+/// Appends, in block order, a pointer to every operation of `block` that
+/// isOutput accepts. Only the operations directly in `block` are inspected.
+void NormalizePass::collectOutputOperations(
+    Block &block, SmallVector<Operation *, 16> &outputs) const noexcept {
+  for (Operation &candidate : block) {
+    if (!isOutput(candidate))
+      continue;
+    outputs.emplace_back(&candidate);
+  }
+}
+
+/// Decides whether `op` counts as an output operation. An operation is an
+/// output when any of the following holds:
+///  - it is a terminator;
+///  - it implements MemoryEffectOpInterface and reports at least one
+///    MemoryEffects::Write effect;
+///  - it is a func::CallOp (calls are conservatively assumed to possibly
+///    produce side effects).
+bool NormalizePass::isOutput(Operation &op) const noexcept {
+  if (op.hasTrait<OpTrait::IsTerminator>())
+    return true;
+
+  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(&op)) {
+    SmallVector<MemoryEffects::EffectInstance, 4> reportedEffects;
+    memOp.getEffects(reportedEffects);
+    for (const auto &instance : reportedEffects) {
+      if (isa<MemoryEffects::Write>(instance.getEffect()))
+        return true;
+    }
+  }
+
+  return isa<func::CallOp>(&op);
+}
+
+/// Helper method returning indices (distance from the beginning of the basic
+/// block) of output operations using the given operation. Walks down the
+/// def-use tree recursively
+llvm::SetVector<int> NormalizePass::getOutputFootprint(
+ Operation *op, llvm::SmallPtrSet<const Operation *, 32> &visited) const {
+ llvm::SetVector<int> outputsVec;
+ if (!visited.count(op)) {
+ visited.insert(op);
+
+ if (isOutput(*op)) {
+ func::FuncOp func = op->getParentOfType<func::FuncOp>();
----------------
Sh0g0-1758 wrote:
But we need the enclosing function to get different footprints for output operations at different depths. If we measured the footprint from the block instead, two output operations at the same depth within their respective enclosing blocks would have the same footprint.
https://github.com/llvm/llvm-project/pull/162266
More information about the llvm-commits
mailing list