[Mlir-commits] [mlir] [mlir] Add normalize pass to MLIR (PR #186647)

lonely eagle llvmlistbot at llvm.org
Mon Mar 16 00:23:24 PDT 2026


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/186647

>From 5a55524e2df48a309272394cbf28c24c8dc9e7c7 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 13 Mar 2026 02:26:50 +0000
Subject: [PATCH 1/2] complete basic normalize.

---
 mlir/include/mlir/Transforms/Passes.h  |   1 +
 mlir/include/mlir/Transforms/Passes.td |   8 ++
 mlir/lib/Transforms/CMakeLists.txt     |   1 +
 mlir/lib/Transforms/Normalize.cpp      | 112 +++++++++++++++++++++++++
 mlir/test/Transforms/normalize.mlir    |  84 +++++++++++++++++++
 5 files changed, 206 insertions(+)
 create mode 100644 mlir/lib/Transforms/Normalize.cpp
 create mode 100644 mlir/test/Transforms/normalize.mlir

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 4804b023a8f79..e962a5659f710 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -41,6 +41,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_LOOPINVARIANTSUBSETHOISTINGPASS
 #define GEN_PASS_DECL_INLINERPASS
 #define GEN_PASS_DECL_MEM2REG
+#define GEN_PASS_DECL_NORMALIZEPASS
 #define GEN_PASS_DECL_PRINTIRPASS
 #define GEN_PASS_DECL_PRINTOPSTATSPASS
 #define GEN_PASS_DECL_REMOVEDEADVALUESPASS
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 20af90e56ee67..4dfad8dcaa703 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -630,4 +630,12 @@ def BubbleDownMemorySpaceCasts :
   }];
 }
 
+def NormalizePass : InterfacePass<"normalize", "FunctionOpInterface"> {
+  let summary = "Transforms IR into a normal form that's easier to diff.";
+  let description = [{
+    This pass attempts to relocate the defining ops of operands for any
+    side-effecting or terminator operation to their nearest dominating positions.
+  }];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 8907724627386..e305b934c74cb 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_library(MLIRTransforms
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
   Mem2Reg.cpp
+  Normalize.cpp
   OpStats.cpp
   PrintIR.cpp
   RemoveDeadValues.cpp
diff --git a/mlir/lib/Transforms/Normalize.cpp b/mlir/lib/Transforms/Normalize.cpp
new file mode 100644
index 0000000000000..1904d25d36674
--- /dev/null
+++ b/mlir/lib/Transforms/Normalize.cpp
@@ -0,0 +1,112 @@
+//===- Normalize.cpp - Transforms IR into a normal form ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_NORMALIZEPASS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "normalize"
+
+namespace {
+
+bool isOutput(Operation *op) {
+  if (!op)
+    return false;
+  return !isMemoryEffectFree(op) || op->hasTrait<OpTrait::IsTerminator>();
+}
+
+/// Returns all output ops in `root` (including `root`), in post-order walk
+/// order. An output is an op that may have memory effects or is a terminator.
+SmallVector<Operation *> collectOutputs(Operation *root) {
+  SmallVector<Operation *> outputs;
+  root->walk([&](Operation *op) {
+    if (isOutput(op))
+      outputs.push_back(op);
+  });
+  return outputs;
+}
+
+/// The function returns the operation that dominates all other operations in
+/// the given list.
+Operation *getDominateOp(SmallVectorImpl<Operation *> &ops) {
+  if (ops.empty())
+    return {};
+  Operation *curDomOp = ops.front();
+  DominanceInfo domInfo(curDomOp);
+  for (size_t i = 1, e = ops.size(); i < e; ++i) {
+    bool dominateA = domInfo.dominates(ops[i], curDomOp);
+    bool dominateB = domInfo.dominates(curDomOp, ops[i]);
+    if (dominateA) {
+      LDBG() << OpWithFlags(ops[i], OpPrintingFlags().skipRegions())
+             << "\ndominate\n"
+             << OpWithFlags(curDomOp, OpPrintingFlags().skipRegions());
+      curDomOp = ops[i];
+    }
+    if (!dominateA && !dominateB) {
+      LDBG() << OpWithFlags(ops[i], OpPrintingFlags().skipRegions())
+             << "\nand\n"
+             << OpWithFlags(curDomOp, OpPrintingFlags().skipRegions())
+             << "\ndo not dominate each other";
+      return {};
+    }
+  }
+  return curDomOp;
+}
+
+/// Move used to its nearest user and recursively perform the same process on
+/// the defining operations of its operands.
+void reorderOutput(IRRewriter &rewriter, Operation *used) {
+  if (!isPure(used))
+    return;
+  SmallVector<Operation *> users(used->getUsers());
+  if (Operation *domOp = getDominateOp(users)) {
+    rewriter.moveOpBefore(used, domOp);
+    for (Value operand : used->getOperands())
+      if (Operation *defineOp = operand.getDefiningOp())
+        reorderOutput(rewriter, defineOp);
+  }
+}
+
+/// Reorders ops by walking up the tree from each operand of an output op and
+/// reducing the def-use distance. This method assumes that output ops were
+/// collected top-down, otherwise the def-use chain may be broken. This method
+/// is a wrapper for recursive reorderOutput().
+void reorderOutputs(IRRewriter &rewriter,
+                    SmallVectorImpl<Operation *> &outputs) {
+  SmallPtrSet<Operation *, 16> visited;
+  for (Operation *output : outputs) {
+    for (Value operand : output->getOperands()) {
+      if (Operation *defineOp = operand.getDefiningOp();
+          defineOp && !visited.contains(defineOp)) {
+        reorderOutput(rewriter, defineOp);
+      }
+    }
+  }
+}
+
+struct NormalizePass : public impl::NormalizePassBase<NormalizePass> {
+  using impl::NormalizePassBase<NormalizePass>::NormalizePassBase;
+  void runOnOperation() override;
+};
+} // namespace
+
+void NormalizePass::runOnOperation() {
+  IRRewriter rewriter(&getContext());
+  SmallVector<Operation *> outputs = collectOutputs(getOperation());
+  reorderOutputs(rewriter, outputs);
+}
diff --git a/mlir/test/Transforms/normalize.mlir b/mlir/test/Transforms/normalize.mlir
new file mode 100644
index 0000000000000..241e28c3c385d
--- /dev/null
+++ b/mlir/test/Transforms/normalize.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(normalize))" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @multiple_memref_store
+//  CHECK-SAME:   %[[ARG0:.*]]: index,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<?xf32>
+func.func @multiple_memref_store(%arg0: index, %arg1 : memref<?xf32>) {
+  %f0 = arith.constant 0.0 : f32
+  %f1 = arith.constant 1.0 : f32
+  %add = arith.addi %arg0, %arg0 : index
+  %sub = arith.subi %arg0, %arg0 : index
+  memref.store %f0, %arg1[%add] : memref<?xf32>
+  memref.store %f1, %arg1[%sub] : memref<?xf32>
+  return
+}
+
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[ARG0]] : index
+// CHECK-NEXT: memref.store %[[C0]], %[[ARG1]]{{\[}}%[[ADD]]] : memref<?xf32>
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[ARG0]], %[[ARG0]] : index
+// CHECK-NEXT: memref.store %[[C1]], %[[ARG1]]{{\[}}%[[SUB]]] : memref<?xf32>
+
+// -----
+
+// CHECK-LABEL: func @normalize_return
+// CHECK-SAME:      %[[ARG0:.*]]: index,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<?xf32>
+func.func @normalize_return(%arg0: index, %arg1 : memref<?xf32>) -> index {
+  %f0 = arith.constant 0.0 : f32
+  %add = arith.addi %arg0, %arg0 : index
+  %sub = arith.subi %arg0, %arg0 : index
+  memref.store %f0, %arg1[%add] : memref<?xf32>
+  return %sub : index
+}
+
+//      CHECK: memref.store
+// CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[ARG0]], %[[ARG0]] : index
+// CHECK-NEXT: return %[[SUB]] : index
+
+// -----
+
+// CHECK-LABEL: func @cross_region
+//  CHECK-SAME:   %[[ARG0:.*]]: f32,
+//  CHECK-SAME:   %[[ARG1:.*]]: memref<10xf32>
+func.func @cross_region(%arg0: f32, %arg1 : memref<10xf32>) {
+  %add = arith.addf %arg0, %arg0 : f32
+  affine.for %i = 0 to 5 {
+    memref.store %add, %arg1[%i] : memref<10xf32>
+  }
+  %exp = math.log2 %add : f32
+  affine.for %i = 6 to 10 {
+    memref.store %exp, %arg1[%i] : memref<10xf32>
+  } 
+  return
+}
+
+//      CHECK: affine.for %[[IV:.*]] = 6 to 10 {
+// CHECK-NEXT:   %[[LOG:.*]] = math.log2
+// CHECK-NEXT:   memref.store %[[LOG]], %[[ARG1]]{{\[}}%[[IV]]] : memref<10xf32>
+// CHECK-NEXT: }
+
+// -----
+
+
+// CHECK-LABEL: func @side_effect_for_op
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<?xf32>
+func.func @side_effect_for_op(%arg1 : memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %upper = memref.dim %arg1, %c0 : memref<?xf32>
+  %f1 = arith.constant 1.0 : f32
+  scf.for %i = %c0 to %upper step %c1 {
+    memref.store %f1, %arg1[%i] : memref<?xf32>
+  }
+  return
+}
+
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
+// CHECK-NEXT:   %[[F1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT:   memref.store %[[F1]], %[[ARG0]]{{\[}}%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: }

>From d3b9c7253adb4908ea8ce6a8d73440d2f37dabb0 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Mon, 16 Mar 2026 06:58:47 +0000
Subject: [PATCH 2/2] fix nit and add more comment.

---
 mlir/include/mlir/Transforms/Passes.td | 10 ++++++++--
 mlir/lib/Transforms/Normalize.cpp      | 20 ++++++++++---------
 mlir/test/Transforms/normalize.mlir    | 27 +++++++++++++-------------
 3 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 4dfad8dcaa703..c75bc40614425 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -633,8 +633,14 @@ def BubbleDownMemorySpaceCasts :
 def NormalizePass : InterfacePass<"normalize", "FunctionOpInterface"> {
   let summary = "Transforms IR into a normal form that's easier to diff.";
   let description = [{
-    This pass attempts to relocate the defining ops of operands for any
-    side-effecting or terminator operation to their nearest dominating positions.
+    This pass aims to transform function-like operations into a normal form
+    by reordering operations while preserving semantics. For the operands of
+    side-effecting and terminator ops, it recursively moves each pure defining
+    op right before the user that dominates all of its other users, if any.
+
+    Note: The pass tries to increase syntactic equivalence of code to reduce
+    diff size while retaining semantic equivalence. It cannot replace the
+    canonicalization pass, which aims to increase semantic equivalence instead.
   }];
 }
 
diff --git a/mlir/lib/Transforms/Normalize.cpp b/mlir/lib/Transforms/Normalize.cpp
index 1904d25d36674..7e63c42b6d163 100644
--- a/mlir/lib/Transforms/Normalize.cpp
+++ b/mlir/lib/Transforms/Normalize.cpp
@@ -43,21 +43,22 @@ SmallVector<Operation *> collectOutputs(Operation *root) {
 
 /// The function returns the operation that dominates all other operations in
 /// the given list.
-Operation *getDominateOp(SmallVectorImpl<Operation *> &ops) {
+Operation *getDominateOp(const SmallVectorImpl<Operation *> &ops) {
   if (ops.empty())
     return {};
   Operation *curDomOp = ops.front();
   DominanceInfo domInfo(curDomOp);
   for (size_t i = 1, e = ops.size(); i < e; ++i) {
     bool dominateA = domInfo.dominates(ops[i], curDomOp);
-    bool dominateB = domInfo.dominates(curDomOp, ops[i]);
     if (dominateA) {
       LDBG() << OpWithFlags(ops[i], OpPrintingFlags().skipRegions())
              << "\ndominate\n"
              << OpWithFlags(curDomOp, OpPrintingFlags().skipRegions());
       curDomOp = ops[i];
+      continue;
     }
-    if (!dominateA && !dominateB) {
+    bool dominateB = domInfo.dominates(curDomOp, ops[i]);
+    if (!dominateB) {
       LDBG() << OpWithFlags(ops[i], OpPrintingFlags().skipRegions())
              << "\nand\n"
              << OpWithFlags(curDomOp, OpPrintingFlags().skipRegions())
@@ -70,13 +71,13 @@ Operation *getDominateOp(SmallVectorImpl<Operation *> &ops) {
 
 /// Move used to its nearest user and recursively perform the same process on
 /// the defining operations of its operands.
-void reorderOutput(IRRewriter &rewriter, Operation *used) {
-  if (!isPure(used))
+void reorderOutput(IRRewriter &rewriter, Operation *producer) {
+  if (!isPure(producer))
     return;
-  SmallVector<Operation *> users(used->getUsers());
+  SmallVector<Operation *> users(producer->getUsers());
   if (Operation *domOp = getDominateOp(users)) {
-    rewriter.moveOpBefore(used, domOp);
-    for (Value operand : used->getOperands())
+    rewriter.moveOpBefore(producer, domOp);
+    for (Value operand : producer->getOperands())
       if (Operation *defineOp = operand.getDefiningOp())
         reorderOutput(rewriter, defineOp);
   }
@@ -87,12 +88,13 @@ void reorderOutput(IRRewriter &rewriter, Operation *used) {
 /// collected top-down, otherwise the def-use chain may be broken. This method
 /// is a wrapper for recursive reorderOutput().
 void reorderOutputs(IRRewriter &rewriter,
-                    SmallVectorImpl<Operation *> &outputs) {
+                    const SmallVectorImpl<Operation *> &outputs) {
   SmallPtrSet<Operation *, 16> visited;
   for (Operation *output : outputs) {
     for (Value operand : output->getOperands()) {
       if (Operation *defineOp = operand.getDefiningOp();
           defineOp && !visited.contains(defineOp)) {
+        visited.insert(defineOp);
         reorderOutput(rewriter, defineOp);
       }
     }
diff --git a/mlir/test/Transforms/normalize.mlir b/mlir/test/Transforms/normalize.mlir
index 241e28c3c385d..28f0c05263278 100644
--- a/mlir/test/Transforms/normalize.mlir
+++ b/mlir/test/Transforms/normalize.mlir
@@ -22,20 +22,17 @@ func.func @multiple_memref_store(%arg0: index, %arg1 : memref<?xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @normalize_return
-// CHECK-SAME:      %[[ARG0:.*]]: index,
-// CHECK-SAME:      %[[ARG1:.*]]: memref<?xf32>
-func.func @normalize_return(%arg0: index, %arg1 : memref<?xf32>) -> index {
-  %f0 = arith.constant 0.0 : f32
-  %add = arith.addi %arg0, %arg0 : index
-  %sub = arith.subi %arg0, %arg0 : index
-  memref.store %f0, %arg1[%add] : memref<?xf32>
-  return %sub : index
+// CHECK-LABEL: func @return_multiple_operands
+//  CHECK-SAME:   %[[ARG0:.*]]: index
+func.func @return_multiple_operands (%arg0: index) -> (index, index) {
+  %0 = arith.addi %arg0, %arg0 : index
+  %1 = arith.subi %arg0, %arg0 : index
+  return %1, %0 : index, index
 }
 
-//      CHECK: memref.store
 // CHECK-NEXT: %[[SUB:.*]] = arith.subi %[[ARG0]], %[[ARG0]] : index
-// CHECK-NEXT: return %[[SUB]] : index
+// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[ARG0]], %[[ARG0]] : index
+// CHECK-NEXT: return %[[SUB]], %[[ADD]] : index, index
 
 // -----
 
@@ -61,10 +58,14 @@ func.func @cross_region(%arg0: f32, %arg1 : memref<10xf32>) {
 
 // -----
 
+// This test verifies reordering around an scf.for op. The memref.store in
+// its body gives the loop side effects. The lower bound %c0 stays ahead of
+// memref.dim, which uses it for the upper bound; the step %c1 and the store
+// value %f1 are sunk to just before their single users.
 
-// CHECK-LABEL: func @side_effect_for_op
+// CHECK-LABEL: func @side_effect_loop_op
 //  CHECK-SAME:   %[[ARG0:.*]]: memref<?xf32>
-func.func @side_effect_for_op(%arg1 : memref<?xf32>) {
+func.func @side_effect_loop_op(%arg1 : memref<?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %upper = memref.dim %arg1, %c0 : memref<?xf32>



More information about the Mlir-commits mailing list