[Mlir-commits] [mlir] [mlir][emitc] Fix recurring operands in expression (PR #178382)

Gil Rapaport llvmlistbot at llvm.org
Wed Jan 28 01:11:37 PST 2026


https://github.com/aniragil created https://github.com/llvm/llvm-project/pull/178382

Relanding #175535, which was reverted for failing the buildbot.
The new canonicalization pattern has been moved into the dialect code.

>From bd9d77af92574e9cfc70e80a60dae062fabda394 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Tue, 27 Jan 2026 10:12:08 +0200
Subject: [PATCH] [mlir][emitc] Fix recurring operands in expression

Relanding #175535 which got reverted for failing the buildbot.
New canonicalization pattern moved to dialect code.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |  1 +
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 95 +++++++++++++++++--
 .../Dialect/EmitC/Transforms/Transforms.cpp   |  1 +
 mlir/test/Dialect/EmitC/form-expressions.mlir | 19 ++++
 mlir/test/Dialect/EmitC/invalid_ops.mlir      | 13 +++
 mlir/test/Dialect/EmitC/ops.mlir              | 24 ++++-
 6 files changed, 142 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index caed3233f62e9..7d9bb8907eb8b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -566,6 +566,7 @@ def EmitC_ExpressionOp
 
   let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
+  let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
     bool hasSideEffects() {
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 1d4b748b2a88a..64a475c2e42ab 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -8,10 +8,12 @@
 
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Support/LLVM.h"
@@ -412,6 +414,61 @@ LogicalResult DereferenceOp::verify() {
 // ExpressionOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+struct RemoveRecurringExpressionOperands
+    : public OpRewritePattern<ExpressionOp> {
+  using OpRewritePattern<ExpressionOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ExpressionOp expressionOp,
+                                PatternRewriter &rewriter) const override {
+    SetVector<Value> uniqueOperands;
+    DenseMap<Value, int> firstIndexOf;
+
+    // Collect the unique operands and record where each one first occurs.
+    for (auto [i, operand] : llvm::enumerate(expressionOp.getDefs())) {
+      if (uniqueOperands.contains(operand))
+        continue;
+      uniqueOperands.insert(operand);
+      firstIndexOf[operand] = i;
+    }
+
+    // If every operand is unique, bail out.
+    if (uniqueOperands.size() == expressionOp.getDefs().size())
+      return failure();
+
+    // Create a new expression with unique operands.
+    rewriter.setInsertionPointAfter(expressionOp);
+    auto uniqueExpression = emitc::ExpressionOp::create(
+        rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
+        uniqueOperands.getArrayRef(), expressionOp.getDoNotInline());
+    Block &uniqueExpressionBody = uniqueExpression.createBody();
+
+    // Map each original block argument to the unique block argument taking
+    // the same operand.
+    IRMapping mapper;
+    Block *expressionBody = expressionOp.getBody();
+    for (auto [operand, arg] :
+         llvm::zip(expressionOp.getOperands(), expressionBody->getArguments()))
+      mapper.map(arg, uniqueExpressionBody.getArgument(firstIndexOf[operand]));
+
+    rewriter.setInsertionPointToStart(&uniqueExpressionBody);
+    for (Operation &opToClone : *expressionOp.getBody())
+      rewriter.clone(opToClone, mapper);
+
+    // Complete the rewrite.
+    rewriter.replaceOp(expressionOp, uniqueExpression);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void ExpressionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<RemoveRecurringExpressionOperands>(context);
+}
+
 ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
   if (parser.parseOperandList(operands))
@@ -435,27 +492,45 @@ ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
                             "expected single return type");
   result.addTypes(fnType.getResults());
   Region *body = result.addRegion();
+  DenseSet<Value> uniqueOperands(result.operands.begin(),
+                                 result.operands.end());
+  bool enableNameShadowing = uniqueOperands.size() == result.operands.size();
   SmallVector<OpAsmParser::Argument> argsInfo;
-  for (auto [unresolvedOperand, operandType] :
-       llvm::zip(operands, fnType.getInputs())) {
-    OpAsmParser::Argument argInfo;
-    argInfo.ssaName = unresolvedOperand;
-    argInfo.type = operandType;
-    argsInfo.push_back(argInfo);
+  if (enableNameShadowing) {
+    for (auto [unresolvedOperand, operandType] :
+         llvm::zip(operands, fnType.getInputs())) {
+      OpAsmParser::Argument argInfo;
+      argInfo.ssaName = unresolvedOperand;
+      argInfo.type = operandType;
+      argsInfo.push_back(argInfo);
+    }
   }
-  if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
+  SMLoc beforeRegionLoc = parser.getCurrentLocation();
+  if (parser.parseRegion(*body, argsInfo, enableNameShadowing))
     return failure();
+  if (!enableNameShadowing) {
+    if (body->front().getArguments().size() < result.operands.size()) {
+      return parser.emitError(
+          beforeRegionLoc, "with recurring operands expected block arguments");
+    }
+  }
   return success();
 }
 
 void emitc::ExpressionOp::print(OpAsmPrinter &p) {
   p << ' ';
-  p.printOperands(getDefs());
+  auto operands = getDefs();
+  p.printOperands(operands);
   p << " : ";
   p.printFunctionalType(getOperation());
-  p.shadowRegionArgs(getRegion(), getDefs());
+  DenseSet<Value> uniqueOperands(operands.begin(), operands.end());
+  bool printEntryBlockArgs = true;
+  if (uniqueOperands.size() == operands.size()) {
+    p.shadowRegionArgs(getRegion(), getDefs());
+    printEntryBlockArgs = false;
+  }
   p << ' ';
-  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+  p.printRegion(getRegion(), printEntryBlockArgs);
 }
 
 Operation *ExpressionOp::getRootOp() {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index f8469b8f0ed67..4bcdb285d6a16 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -152,5 +152,6 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
 } // namespace
 
 void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
+  ExpressionOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<FoldExpressionOp>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/EmitC/form-expressions.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir
index 7b6723989e260..58eac4381ccb7 100644
--- a/mlir/test/Dialect/EmitC/form-expressions.mlir
+++ b/mlir/test/Dialect/EmitC/form-expressions.mlir
@@ -20,6 +20,25 @@ func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
   return %c : i1
 }
 
+// CHECK-LABEL:   func.func @expression_recurring_args(
+// CHECK-SAME:      %[[ARG0:.*]]: i32,
+// CHECK-SAME:      %[[ARG1:.*]]: i32) -> i1 {
+// CHECK:           %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG1]], %[[ARG0]] : (i32, i32) -> i1 {
+// CHECK:             %[[VAL_0:.*]] = mul %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CHECK:             %[[VAL_1:.*]] = sub %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CHECK:             %[[VAL_2:.*]] = cmp lt, %[[VAL_1]], %[[ARG1]] : (i32, i32) -> i1
+// CHECK:             yield %[[VAL_2]] : i1
+// CHECK:           }
+// CHECK:           return %[[EXPRESSION_0]] : i1
+// CHECK:         }
+
+func.func @expression_recurring_args(%arg0: i32, %arg1: i32) -> i1 {
+  %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+  %b = emitc.sub %a, %arg0 : (i32, i32) -> i32
+  %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1
+  return %c : i1
+}
+
 // CHECK-LABEL: func.func @multiple_expressions(
 // CHECK-SAME:      %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) {
 // CHECK:         %[[VAL_4:.*]] = emitc.expression %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : (i32, i32, i32) -> i32 {
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index d1601bed29ca9..0d878e90cdf0c 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -379,6 +379,19 @@ emitc.func @test_expression_op_outside_expression() {
 
 // -----
 
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+  // expected-error @+1 {{'emitc.expression' with recurring operands expected block arguments}}
+  %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+    %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+    %b = emitc.add %a, %arg0 : (i32, i32) -> i32
+    %c = emitc.mul %b, %a : (i32, i32) -> i32
+    emitc.yield %c : i32
+  }
+  return %r : i32
+}
+
+// -----
+
 // expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}}
 emitc.func @multiple_results(%0: i32) -> (i32, i32) {
   emitc.return %0 : i32
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index b2c8b843ec14b..2f7544b5db096 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | FileCheck -check-prefix=CANON %s
 
 // CHECK: emitc.include <"test.h">
 // CHECK: emitc.include "test.h"
@@ -213,6 +213,28 @@ func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
   return %r : i32
 }
 
+// CANON-LABEL:   func.func @test_expression_recurring_operands(
+// CANON-SAME:      %[[ARG0:.*]]: i32,
+// CANON-SAME:      %[[ARG1:.*]]: i32) -> i32 {
+// CANON:           %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 {
+// CANON:             %[[VAL_0:.*]] = rem %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CANON:             %[[VAL_1:.*]] = add %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CANON:             %[[VAL_2:.*]] = mul %[[VAL_1]], %[[VAL_0]] : (i32, i32) -> i32
+// CANON:             yield %[[VAL_2]] : i32
+// CANON:           }
+// CANON:           return %[[EXPRESSION_0]] : i32
+// CANON:         }
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+  %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+  ^bb0(%x: i32, %y: i32, %z: i32):
+    %a = emitc.rem %x, %y : (i32, i32) -> i32
+    %b = emitc.add %a, %z : (i32, i32) -> i32
+    %c = emitc.mul %b, %a : (i32, i32) -> i32
+    emitc.yield %c : i32
+  }
+  return %r : i32
+}
+
 func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
   emitc.for %i0 = %arg0 to %arg1 step %arg2 {
     %0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32



More information about the Mlir-commits mailing list