[llvm] [mlir] [mlir] Add Normalize pass (PR #162266)
Shourya Goel via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 21 05:57:17 PDT 2025
================
@@ -0,0 +1,436 @@
+//===- Normalize.cpp - Conversion from MLIR to its canonical form ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/Normalize/Normalize.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <iomanip>
+#include <sstream>
+
+namespace mlir {
+#define GEN_PASS_DEF_NORMALIZE
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "normalize"
+
+namespace {
+/// Pass that, for every operation directly nested in the module, collects the
+/// "output" operations of its blocks, reorders the operations that feed them,
+/// and then attaches deterministic, hash-derived names (as NameLocs) to the
+/// operations reachable from those outputs.
+struct NormalizePass : public impl::NormalizeBase<NormalizePass> {
+  NormalizePass() = default;
+
+  void runOnOperation() override;
+
+private:
+  /// Seed every name hash starts from. A compile-time constant shared by all
+  /// instances, so it needs no per-object storage.
+  static constexpr uint64_t MagicHashConstant = 0x6acaa36bef8325c5ULL;
+
+  /// Appends to \p Output every operation of \p block for which isOutput()
+  /// returns true.
+  void
+  collectOutputOperations(Block &block,
+                          SmallVector<Operation *, 16> &Output) const noexcept;
+
+  /// Returns true if \p op is treated as an output (a root of the
+  /// normalization walk); terminators always are.
+  bool isOutput(mlir::Operation &op) const noexcept;
+
+  /// Moves the transitive producers of every operation in \p Outputs next to
+  /// their users.
+  void reorderOperations(const SmallVector<mlir::Operation *, 16> &Outputs);
+
+  /// Moves \p used in front of \p user (or to the end of its own block when
+  /// they live in different blocks), then recurses into the producers of
+  /// \p used. \p visited prevents an operation from being moved twice.
+  void
+  reorderOperation(mlir::Operation *used, mlir::Operation *user,
+                   llvm::SmallPtrSet<const mlir::Operation *, 32> &visited);
+
+  /// Names every operation in \p Outputs and, recursively, their producers.
+  void renameOperations(const SmallVector<Operation *, 16> &Outputs);
+
+  /// Names \p op once: picks the initial/regular naming scheme, then folds
+  /// operand names into the name and sorts commutative operands.
+  void RenameOperation(mlir::Operation *op,
+                       SmallPtrSet<const mlir::Operation *, 32> &visited);
+
+  /// True if \p op has at least one use and hasOnlyImmediateOperands().
+  bool isInitialOperation(mlir::Operation *const op) const noexcept;
+
+  /// Gives \p op a "vl"-prefixed name derived from its opcode and its output
+  /// footprint.
+  void nameAsInitialOperation(
+      mlir::Operation *op,
+      llvm::SmallPtrSet<const mlir::Operation *, 32> &visited);
+
+  /// Gives \p op an "op"-prefixed name derived from its opcode and the
+  /// opcodes of the operations defining its operands.
+  void nameAsRegularOperation(
+      mlir::Operation *op,
+      llvm::SmallPtrSet<const mlir::Operation *, 32> &visited);
+
+  /// True if every operand of \p op is either not defined by an operation
+  /// (i.e. a block argument) or is defined by a ConstantLike operation.
+  bool hasOnlyImmediateOperands(mlir::Operation *const op) const noexcept;
+
+  /// Computes the integer footprint that nameAsInitialOperation() mixes into
+  /// the hash of \p op.
+  llvm::SetVector<int> getOutputFootprint(
+      mlir::Operation *op,
+      llvm::SmallPtrSet<const mlir::Operation *, 32> &visited) const;
+
+  /// Rewrites the name of \p op to "<name>$<operand>-<operand>...$".
+  void foldOperation(mlir::Operation *op);
+
+  /// For commutative operations, sorts the operands by their printed names.
+  void reorderOperationOperandsByName(mlir::Operation *op);
+
+  /// Printing flags used whenever an operation or value is rendered to text;
+  /// runOnOperation() enables NameLoc-as-prefix printing on them.
+  mlir::OpPrintingFlags flags{};
+};
+} // namespace
+
+/// Pass entry point. For each operation directly nested in the module, gathers
+/// the output operations found in the blocks of its regions, then reorders and
+/// renames the operations reachable from those outputs.
+void NormalizePass::runOnOperation() {
+  // The later name-based steps read names back by printing operations, so
+  // NameLoc names must be printed as SSA-name prefixes.
+  flags.printNameLocAsPrefix(true);
+
+  ModuleOp module = getOperation();
+
+  for (Operation &op : module.getOps()) {
+    SmallVector<Operation *, 16> Outputs;
+
+    // Only the blocks directly owned by `op` are scanned here; operations
+    // nested deeper are not collected as outputs.
+    for (Region &region : op.getRegions())
+      for (Block &block : region)
+        collectOutputOperations(block, Outputs);
+
+    reorderOperations(Outputs);
+    renameOperations(Outputs);
+  }
+}
+
+/// Names every output operation and, through RenameOperation(), the
+/// operations that produce their operands.
+void NormalizePass::renameOperations(
+    const SmallVector<Operation *, 16> &Outputs) {
+  // A single set is shared by all outputs so that a producer reachable from
+  // several outputs is named only once.
+  llvm::SmallPtrSet<const mlir::Operation *, 32> alreadyNamed;
+  for (Operation *output : Outputs)
+    RenameOperation(output, alreadyNamed);
+}
+
+/// Names \p op unless it has been named already. The naming scheme depends on
+/// isInitialOperation(); afterwards operand names are folded into the name
+/// and the operands of commutative operations are sorted.
+void NormalizePass::RenameOperation(
+    Operation *op, SmallPtrSet<const mlir::Operation *, 32> &visited) {
+  // insert() reports whether the pointer was newly added; a repeat visit is
+  // a no-op.
+  if (!visited.insert(op).second)
+    return;
+
+  if (isInitialOperation(op))
+    nameAsInitialOperation(op, visited);
+  else
+    nameAsRegularOperation(op, visited);
+
+  foldOperation(op);
+  reorderOperationOperandsByName(op);
+}
+
+/// An operation is "initial" when its results are used somewhere and none of
+/// its operands comes from a non-constant operation.
+bool NormalizePass::isInitialOperation(
+    mlir::Operation *const op) const noexcept {
+  if (op->use_empty())
+    return false;
+  return hasOnlyImmediateOperands(op);
+}
+
+/// Returns true when every operand of \p op either has no defining operation
+/// or is defined by an operation carrying the ConstantLike trait. An operation
+/// without operands trivially qualifies.
+bool NormalizePass::hasOnlyImmediateOperands(
+    mlir::Operation *const op) const noexcept {
+  return llvm::all_of(op->getOperands(), [](mlir::Value operand) {
+    mlir::Operation *producer = operand.getDefiningOp();
+    return !producer || producer->hasTrait<OpTrait::ConstantLike>();
+  });
+}
+
+/// Renders \p hash as exactly five lowercase hexadecimal digits: shorter
+/// values are left-padded with '0', longer ones keep only their last five
+/// digits.
+std::string inline to_string(uint64_t const hash) noexcept {
+  std::ostringstream formatter;
+  formatter << std::hex << std::setfill('0') << std::setw(5) << hash;
+  const std::string digits = formatter.str();
+  if (digits.size() <= 5)
+    return digits;
+  return digits.substr(digits.size() - 5, 5);
+}
+
+/// 64-bit FNV-1a hash of \p data.
+///
+/// Each character is widened through `unsigned char` before being mixed in.
+/// Casting a plain `char` straight to `uint64_t` sign-extends bytes >= 0x80 on
+/// targets where `char` is signed, which would make the hash (and every name
+/// derived from it) differ between platforms.
+uint64_t inline strHash(std::string_view data) noexcept {
+  constexpr uint64_t FNV_OFFSET = 0xcbf29ce484222325ULL;
+  constexpr uint64_t FNV_PRIME = 0x100000001b3ULL;
+  uint64_t hash = FNV_OFFSET;
+  for (const char c : data) {
+    hash ^= static_cast<uint64_t>(static_cast<unsigned char>(c));
+    hash *= FNV_PRIME;
+  }
+  return hash;
+}
+
+/// Returns the \p indx-th (0-based) field of \p str when split on
+/// \p delimiter, with every ':' in that field replaced by '_'.
+///
+/// Returns an empty string when \p str has no such field (including a
+/// negative \p indx, an empty \p str, or an empty field after a trailing
+/// delimiter). The previous implementation returned `nullptr` here, which
+/// constructs a std::string from a null pointer: undefined behavior, and a
+/// std::terminate under `noexcept`.
+std::string inline split(std::string_view str, const char &delimiter,
+                         int indx = 0) noexcept {
+  if (indx < 0)
+    return std::string();
+
+  // Skip the first `indx` fields.
+  std::string_view::size_type begin = 0;
+  for (int field = 0; field < indx; ++field) {
+    const std::string_view::size_type cut = str.find(delimiter, begin);
+    if (cut == std::string_view::npos)
+      return std::string();
+    begin = cut + 1;
+  }
+
+  const std::string_view::size_type end = str.find(delimiter, begin);
+  std::string item(str.substr(
+      begin, end == std::string_view::npos ? std::string_view::npos
+                                           : end - begin));
+  for (char &c : item)
+    if (c == ':')
+      c = '_';
+  return item;
+}
+
+/// Names an "initial" operation (see isInitialOperation()).
+///
+/// The name is "vl" followed by the first five decimal digits of a hash of the
+/// opcode and the output footprint, then the callee name for func.call. When
+/// the operation has no operands, a "$...$" suffix is added: "void" for calls,
+/// otherwise a five-hex-digit hash of the printed text between the first and
+/// second '='. The name is attached as a NameLoc wrapping the old location.
+void NormalizePass::nameAsInitialOperation(
+    mlir::Operation *op,
+    llvm::SmallPtrSet<const mlir::Operation *, 32> &visited) {
+
+  // Producers are named before the operation itself.
+  for (mlir::Value operand : op->getOperands())
+    if (mlir::Operation *defOp = operand.getDefiningOp())
+      RenameOperation(defOp, visited);
+
+  uint64_t Hash = MagicHashConstant;
+
+  const uint64_t opcodeHash = strHash(op->getName().getStringRef().str());
+  Hash = llvm::hashing::detail::hash_16_bytes(Hash, opcodeHash);
+
+  // The footprint walk uses its own, fresh visited set; it is independent of
+  // the set tracking which operations have been named.
+  SmallPtrSet<const mlir::Operation *, 32> footprintVisited;
+  SetVector<int> OutputFootprint = getOutputFootprint(op, footprintVisited);
+
+  for (const int Output : OutputFootprint)
+    Hash = llvm::hashing::detail::hash_16_bytes(Hash, Output);
+
+  std::string Name = "vl" + std::to_string(Hash).substr(0, 5);
+
+  // Cast once and reuse the result for both call-specific decisions.
+  auto call = mlir::dyn_cast<mlir::func::CallOp>(op);
+  if (call)
+    Name.append(call.getCallee().str());
+
+  if (op->getNumOperands() == 0) {
+    Name.append("$");
+    if (call) {
+      Name.append("void");
+    } else {
+      std::string TextRepresentation;
+      mlir::AsmState state(op, flags);
+      llvm::raw_string_ostream Stream(TextRepresentation);
+      op->print(Stream, state);
+      // Hash the printed text that follows the first '='.
+      Name.append(to_string(strHash(split(Stream.str(), '=', 1))));
+    }
+    Name.append("$");
+  }
+
+  mlir::OpBuilder b(op->getContext());
+  mlir::StringAttr sat = b.getStringAttr(Name);
+  mlir::Location newLoc = mlir::NameLoc::get(sat, op->getLoc());
+  op->setLoc(newLoc);
+}
+
+/// Names a non-initial operation.
+///
+/// The name is "op" followed by the first five decimal digits of a hash of the
+/// opcode and the opcodes of the operations defining its operands (sorted for
+/// commutative operations), then the callee name for func.call. The name is
+/// attached as a NameLoc wrapping the old location.
+void NormalizePass::nameAsRegularOperation(
+    mlir::Operation *op,
+    llvm::SmallPtrSet<const mlir::Operation *, 32> &visited) {
+
+  // Producers are named before the operation itself.
+  for (mlir::Value input : op->getOperands())
+    if (mlir::Operation *producer = input.getDefiningOp())
+      RenameOperation(producer, visited);
+
+  const uint64_t opcodeHash = strHash(op->getName().getStringRef().str());
+  uint64_t Hash =
+      llvm::hashing::detail::hash_16_bytes(MagicHashConstant, opcodeHash);
+
+  // Operands that are block arguments have no opcode and contribute nothing.
+  SmallVector<uint64_t, 4> producerOpcodes;
+  for (mlir::Value input : op->getOperands()) {
+    mlir::Operation *producer = input.getDefiningOp();
+    if (!producer)
+      continue;
+    producerOpcodes.push_back(
+        strHash(producer->getName().getStringRef().str()));
+  }
+
+  // Operand order must not influence the name of a commutative operation.
+  if (op->hasTrait<OpTrait::IsCommutative>())
+    llvm::sort(producerOpcodes.begin(), producerOpcodes.end());
+
+  for (const uint64_t opcode : producerOpcodes)
+    Hash = llvm::hashing::detail::hash_16_bytes(Hash, opcode);
+
+  SmallString<512> Name;
+  Name.append("op");
+  Name.append(std::to_string(Hash).substr(0, 5));
+
+  if (auto call = mlir::dyn_cast<mlir::func::CallOp>(op))
+    Name.append(call.getCallee().str());
+
+  mlir::OpBuilder builder(op->getContext());
+  mlir::StringAttr nameAttr = builder.getStringAttr(Name);
+  op->setLoc(mlir::NameLoc::get(nameAttr, op->getLoc()));
+}
+
+/// Returns true if \p base begins with \p check. An empty \p check matches
+/// every \p base.
+bool inline starts_with(std::string_view base,
+                        std::string_view check) noexcept {
+  if (base.size() < check.size())
+    return false;
+  return base.compare(0, check.size(), check) == 0;
+}
+
+/// Extends the normalized name of \p op with the names of its operands:
+/// "<name>$<operand0>-<operand1>...$".
+///
+/// Outputs, operations without operands, and operations whose printed name
+/// does not start with "%op" or "%vl" are left untouched. Operands defined by
+/// normalized operations contribute up to seven characters of their name
+/// after the '%'; other defining operations contribute their whole printed
+/// prefix; block arguments contribute "funcArg<N>" (entry block of a
+/// func.func) or "blockArg<N>". Operand names are sorted for commutative
+/// operations.
+void NormalizePass::foldOperation(mlir::Operation *op) {
+  if (isOutput(*op) || op->getNumOperands() == 0)
+    return;
+
+  // Prints `target` and returns the text that precedes the first '='.
+  auto printedPrefix = [this](mlir::Operation *target) {
+    std::string text;
+    mlir::AsmState state(target, flags);
+    llvm::raw_string_ostream os(text);
+    target->print(os, state);
+    return split(os.str(), '=', 0);
+  };
+  auto isNormalized = [](const std::string &name) {
+    return starts_with(name, "%op") || starts_with(name, "%vl");
+  };
+
+  const std::string opName = printedPrefix(op);
+  if (!isNormalized(opName))
+    return;
+
+  SmallVector<std::string, 4> Operands;
+
+  for (mlir::Value operand : op->getOperands()) {
+    if (mlir::Operation *defOp = operand.getDefiningOp()) {
+      const std::string producerName = printedPrefix(defOp);
+      // Drop the leading '%' and keep the 2-letter tag plus hash digits.
+      Operands.push_back(isNormalized(producerName)
+                             ? producerName.substr(1, 7)
+                             : producerName);
+      continue;
+    }
+
+    auto blockArg = dyn_cast<mlir::BlockArgument>(operand);
+    if (!blockArg)
+      continue;
+
+    mlir::Block *owner = blockArg.getOwner();
+    auto func = dyn_cast<mlir::func::FuncOp>(owner->getParentOp());
+    // Only arguments of a function's entry block count as function arguments.
+    const bool isFunctionArgument = func && &func.front() == owner;
+    const std::string prefix = isFunctionArgument ? "funcArg" : "blockArg";
+    Operands.push_back(prefix + std::to_string(blockArg.getArgNumber()));
+  }
+
+  if (op->hasTrait<OpTrait::IsCommutative>())
+    llvm::sort(Operands.begin(), Operands.end());
+
+  SmallString<512> Name;
+  Name.append(opName.substr(1, 7));
+
+  Name.append("$");
+  bool needSeparator = false;
+  for (const std::string &operandName : Operands) {
+    if (needSeparator)
+      Name.append("-");
+    Name.append(operandName);
+    needSeparator = true;
+  }
+  Name.append("$");
+
+  mlir::OpBuilder builder(op->getContext());
+  mlir::StringAttr nameAttr = builder.getStringAttr(Name);
+  op->setLoc(mlir::NameLoc::get(nameAttr, op->getLoc()));
+}
+
+/// Re-assigns the operands of \p op; for commutative operations they are
+/// first sorted case-insensitively by their printed operand names.
+void NormalizePass::reorderOperationOperandsByName(mlir::Operation *op) {
+  if (op->getNumOperands() == 0)
+    return;
+
+  using NamedOperand = std::pair<std::string, mlir::Value>;
+  SmallVector<NamedOperand, 4> Operands;
+
+  for (mlir::Value operand : op->getOperands()) {
+    std::string printed;
+    llvm::raw_string_ostream os(printed);
+    operand.printAsOperand(os, flags);
+    Operands.push_back({os.str(), operand});
+  }
+
+  if (op->hasTrait<OpTrait::IsCommutative>()) {
+    auto byNameIgnoringCase = [](const auto &lhs, const auto &rhs) {
+      return llvm::StringRef(lhs.first).compare_insensitive(rhs.first) < 0;
+    };
+    llvm::sort(Operands.begin(), Operands.end(), byNameIgnoringCase);
+  }
+
+  // Operands are written back unconditionally, also when nothing was sorted.
+  for (size_t idx = 0, count = Operands.size(); idx != count; ++idx)
+    op->setOperand(idx, Operands[idx].second);
+}
+
+/// For every output operation, moves the operations defining its operands
+/// (and, recursively, their own producers) close to their users.
+void NormalizePass::reorderOperations(
+    const SmallVector<Operation *, 16> &Outputs) {
+  // Shared across all outputs so that each producer is moved at most once.
+  llvm::SmallPtrSet<const mlir::Operation *, 32> moved;
+  for (mlir::Operation *output : Outputs) {
+    for (mlir::Value input : output->getOperands()) {
+      mlir::Operation *producer = input.getDefiningOp();
+      if (producer)
+        reorderOperation(producer, output, moved);
+    }
+  }
+}
+
+/// Moves \p used directly in front of \p user when both live in the same
+/// block; otherwise moves it in front of the last operation of its own block.
+/// Then does the same for the producers of \p used. Each operation is moved
+/// at most once, tracked through \p visited.
+void NormalizePass::reorderOperation(
+    mlir::Operation *used, mlir::Operation *user,
+    llvm::SmallPtrSet<const mlir::Operation *, 32> &visited) {
+  if (!visited.insert(used).second)
+    return;
+
+  mlir::Block *usedBlock = used->getBlock();
+  mlir::Operation *anchor =
+      usedBlock == user->getBlock() ? user : &usedBlock->back();
+  used->moveBefore(anchor);
+
+  for (mlir::Value input : used->getOperands()) {
+    mlir::Operation *producer = input.getDefiningOp();
+    if (producer)
+      reorderOperation(producer, used, visited);
+  }
+}
+
+/// Appends to \p Outputs, in block order, a pointer to every operation of
+/// \p block that isOutput() accepts.
+void NormalizePass::collectOutputOperations(
+    Block &block, SmallVector<Operation *, 16> &Outputs) const noexcept {
+  for (Operation &candidate : block) {
+    if (!isOutput(candidate))
+      continue;
+    Outputs.push_back(&candidate);
+  }
+}
+
+bool NormalizePass::isOutput(Operation &op) const noexcept {
+ if (op.hasTrait<OpTrait::IsTerminator>())
+ return true;
+
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(&op)) {
----------------
Sh0g0-1758 wrote:
Can you check the updated description once more and let me know if it's still not clear?
https://github.com/llvm/llvm-project/pull/162266
More information about the llvm-commits
mailing list