[Mlir-commits] [mlir] [mlir] Add normalize pass to MLIR (PR #186647)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 15 01:31:17 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

Add a normalize pass to MLIR. It is similar to llvm-canon (https://llvm.org/devmtg/2019-10/slides/Paszkowski-LLVMCanon.pdf) and LLVM's normalize pass (https://llvm.org/docs/Passes.html#normalize-transforms-ir-into-a-normal-form-that-s-easier-to-diff). However, it does not implement renaming of values, as MLIR has its own canonical naming conventions.

---
Full diff: https://github.com/llvm/llvm-project/pull/186647.diff


5 Files Affected:

- (modified) mlir/include/mlir/Transforms/Passes.h (+1) 
- (modified) mlir/include/mlir/Transforms/Passes.td (+8) 
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Transforms/Normalize.cpp (+112) 
- (added) mlir/test/Transforms/normalize.mlir (+84) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 4804b023a8f79..e962a5659f710 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -41,6 +41,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_LOOPINVARIANTSUBSETHOISTINGPASS
 #define GEN_PASS_DECL_INLINERPASS
 #define GEN_PASS_DECL_MEM2REG
+#define GEN_PASS_DECL_NORMALIZEPASS
 #define GEN_PASS_DECL_PRINTIRPASS
 #define GEN_PASS_DECL_PRINTOPSTATSPASS
 #define GEN_PASS_DECL_REMOVEDEADVALUESPASS
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 20af90e56ee67..4dfad8dcaa703 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -630,4 +630,12 @@ def BubbleDownMemorySpaceCasts :
   }];
 }
 
+def NormalizePass : InterfacePass<"normalize", "FunctionOpInterface"> {
+  let summary = "Transforms IR into a normal form that's easier to diff.";
+  let description = [{
+    This pass attempts to relocate the defining ops of operands for any
+    side-effecting or terminator operation to their nearest dominating positions.
+  }];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 8907724627386..e305b934c74cb 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_library(MLIRTransforms
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
   Mem2Reg.cpp
+  Normalize.cpp
   OpStats.cpp
   PrintIR.cpp
   RemoveDeadValues.cpp
diff --git a/mlir/lib/Transforms/Normalize.cpp b/mlir/lib/Transforms/Normalize.cpp
new file mode 100644
index 0000000000000..1904d25d36674
--- /dev/null
+++ b/mlir/lib/Transforms/Normalize.cpp
@@ -0,0 +1,112 @@
+//===- Normalize.cpp - Transforms IR into a normal form ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_NORMALIZEPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "normalize"
+
+namespace {
+
+bool isOutput(Operation *op) { // Not memory-effect-free, or a terminator.
+  if (!op) // A null op is never an output.
+    return false;
+  return !isMemoryEffectFree(op) || op->hasTrait<OpTrait::IsTerminator>();
+}
+
+/// Walks `root` and returns, in visitation order, every op for which
+/// isOutput() holds (not memory-effect-free, or a terminator).
+SmallVector<Operation *> collectOutputs(Operation *root) {
+  SmallVector<Operation *> outputs;
+  root->walk([&](Operation *op) {
+    if (isOutput(op))
+      outputs.push_back(op);
+  });
+  return outputs;
+}
+
+/// Returns the op in `ops` that dominates every other op in the list, or
+/// null if `ops` is empty or a compared pair has no dominance relation.
+Operation *getDominateOp(SmallVectorImpl<Operation *> &ops) {
+  if (ops.empty())
+    return {}; // Nothing to choose from.
+  Operation *curDomOp = ops.front();
+  DominanceInfo domInfo(curDomOp);
+  for (size_t i = 1, e = ops.size(); i < e; ++i) {
+    bool dominateA = domInfo.dominates(ops[i], curDomOp);
+    bool dominateB = domInfo.dominates(curDomOp, ops[i]);
+    if (dominateA) {
+      LDBG() << OpWithFlags(ops[i], OpPrintingFlags().skipRegions())
+             << "\ndominate\n"
+             << OpWithFlags(curDomOp, OpPrintingFlags().skipRegions());
+      curDomOp = ops[i]; // ops[i] becomes the new candidate.
+    }
+    if (!dominateA && !dominateB) {
+      LDBG() << OpWithFlags(ops[i], OpPrintingFlags().skipRegions())
+             << "\nand\n"
+             << OpWithFlags(curDomOp, OpPrintingFlags().skipRegions())
+             << "\ndo not dominate each other";
+      return {}; // Neither dominates the other: no single dominating op.
+    }
+  }
+  return curDomOp;
+}
+
+/// If `used` is pure, moves it right before the user that dominates all its
+/// users (if one exists), then recurses into its operands' defining ops.
+void reorderOutput(IRRewriter &rewriter, Operation *used) {
+  if (!isPure(used))
+    return; // Non-pure ops are left where they are.
+  SmallVector<Operation *> users(used->getUsers());
+  if (Operation *domOp = getDominateOp(users)) {
+    rewriter.moveOpBefore(used, domOp);
+    for (Value operand : used->getOperands())
+      if (Operation *defineOp = operand.getDefiningOp())
+        reorderOutput(rewriter, defineOp);
+  }
+}
+
+/// Reorders ops by walking up from each operand of an output op and moving the
+/// operand's defining op next to its dominating user; see reorderOutput().
+/// Assumes `outputs` was collected top-down. A defining op reached directly
+/// from several outputs is processed only once, tracked by `visited`.
+void reorderOutputs(IRRewriter &rewriter,
+                    SmallVectorImpl<Operation *> &outputs) {
+  SmallPtrSet<Operation *, 16> visited;
+  for (Operation *output : outputs) {
+    for (Value operand : output->getOperands()) {
+      if (Operation *defineOp = operand.getDefiningOp();
+          defineOp && visited.insert(defineOp).second) {
+        reorderOutput(rewriter, defineOp);
+      }
+    }
+  }
+}
+
+struct NormalizePass : public impl::NormalizePassBase<NormalizePass> {
+  using impl::NormalizePassBase<NormalizePass>::NormalizePassBase;
+  void runOnOperation() override; // Pass entry point; defined out of line.
+};
+} // namespace
+
+void NormalizePass::runOnOperation() { // Collect output ops, then reorder.
+  IRRewriter rewriter(&getContext());
+  SmallVector<Operation *> outputs = collectOutputs(getOperation());
+  reorderOutputs(rewriter, outputs);
+}
diff --git a/mlir/test/Transforms/normalize.mlir b/mlir/test/Transforms/normalize.mlir
new file mode 100644
index 0000000000000..241e28c3c385d
--- /dev/null
+++ b/mlir/test/Transforms/normalize.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(normalize))" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @multiple_memref_store
+//  CHECK-SAME:   %[[ARG0:.*]]: index,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<?xf32>
+func.func @multiple_memref_store(%arg0: index, %arg1 : memref<?xf32>) {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %add = arith.addi %arg0, %arg0 : index
+  %sub = arith.subi %arg0, %arg0 : index
+  memref.store %f0, %arg1[%add] : memref<?xf32>
+  memref.store %f1, %arg1[%sub] : memref<?xf32>
+  return
+}
+
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[ARG0]] : index
+// CHECK-NEXT: memref.store %[[C0]], %[[ARG1]]{{\[}}%[[ADD]]] : memref<?xf32>
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[ARG0]], %[[ARG0]] : index
+// CHECK-NEXT: memref.store %[[C1]], %[[ARG1]]{{\[}}%[[SUB]]] : memref<?xf32>
+
+// -----
+
+// CHECK-LABEL: func @normalize_return
+// CHECK-SAME:      %[[ARG0:.*]]: index,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<?xf32>
+func.func @normalize_return(%arg0: index, %arg1 : memref<?xf32>) -> index {
+  %f0 = arith.constant 0.0 : f32
+  %add = arith.addi %arg0, %arg0 : index
+  %sub = arith.subi %arg0, %arg0 : index
+  memref.store %f0, %arg1[%add] : memref<?xf32>
+  return %sub : index
+}
+
+//      CHECK: memref.store
+// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[ARG0]], %[[ARG0]] : index
+// CHECK-NEXT: return %[[SUB]] : index
+
+// -----
+
+// CHECK-LABEL: func @cross_region
+//  CHECK-SAME:   %[[ARG0:.*]]: f32,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<10xf32>
+func.func @cross_region(%arg0: f32, %arg1 : memref<10xf32>) {
+  %add = arith.addf %arg0, %arg0 : f32
+  affine.for %i = 0 to 5 {
+    memref.store %add, %arg1[%i] : memref<10xf32>
+  }
+  %exp = math.log2 %add : f32
+  affine.for %i = 6 to 10 {
+    memref.store %exp, %arg1[%i] : memref<10xf32>
+  } 
+  return
+}
+
+//      CHECK: affine.for %[[IV:.*]] = 6 to 10 {
+// CHECK-NEXT:   %[[LOG:.*]] = math.log2
+// CHECK-NEXT:   memref.store %[[LOG]], %[[ARG1]]{{\[}}%[[IV]]] : memref<10xf32>
+// CHECK-NEXT: }
+
+// -----
+
+
+// CHECK-LABEL: func @side_effect_for_op
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<?xf32>
+func.func @side_effect_for_op(%arg1 : memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %upper = memref.dim %arg1, %c0 : memref<?xf32>
+  %f1 = arith.constant 1.0 : f32
+  scf.for %i = %c0 to %upper step %c1 {
+    memref.store %f1, %arg1[%i] : memref<?xf32>
+  }
+  return
+}
+
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
+// CHECK-NEXT:   %[[F1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT:   memref.store %[[F1]], %[[ARG0]]{{\[}}%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: }

``````````

</details>


https://github.com/llvm/llvm-project/pull/186647


More information about the Mlir-commits mailing list