[Mlir-commits] [mlir] [mlir][emitc] Fix recurring operands in expression (PR #175535)
Gil Rapaport
llvmlistbot at llvm.org
Mon Jan 12 05:26:11 PST 2026
https://github.com/aniragil created https://github.com/llvm/llvm-project/pull/175535
The pretty-printing for `emitc.expression` breaks for expressions taking the same value as operand multiple times.
Passing the same value as an operand more than once is redundant, and is therefore not the canonical form of `emitc.expression`. However, since transformations affecting `emitc.expression` operands may cause this to happen, `emitc.expression` must retain its support for recurring operands.
This PR fixes this issue by shadowing the region arguments only when the operands are unique, printing and parsing an explicit basic block otherwise. In addition, a canonicalization pattern removing recurring operands is added.
Fixes issue https://github.com/llvm/llvm-project/issues/172952.
>From 4b09dfc75827795c47df46f7b81f2d2024c8edfb Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Mon, 12 Jan 2026 12:25:57 +0200
Subject: [PATCH] [mlir][emitc] Fix recurring operands in expression
The pretty-printing for `emitc.expression` breaks for expressions taking
the same value as operand multiple times.
Passing the same value as operand more than once is redundant, and is
therefore not the canonical form of `emitc.expression`. However, since
transformations affecting `emitc.expression` operands may cause this
to happen, `emitc.expression` must retain its support for recurring
operands.
This PR fixes this issue by shadowing the region arguments only when the
operands are unique, printing and parsing an explicit basic block
otherwise. In addition, a canonicalization pattern removing recurring
operands is added.
Fixes issue https://github.com/llvm/llvm-project/issues/172952.
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 1 +
.../Dialect/EmitC/Transforms/Transforms.h | 4 ++
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 44 ++++++++++++----
.../Dialect/EmitC/Transforms/Transforms.cpp | 52 +++++++++++++++++++
mlir/test/Dialect/EmitC/form-expressions.mlir | 19 +++++++
mlir/test/Dialect/EmitC/invalid_ops.mlir | 13 +++++
mlir/test/Dialect/EmitC/ops.mlir | 24 ++++++++-
7 files changed, 146 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index c1820904f2665..b638130e24b24 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -566,6 +566,7 @@ def EmitC_ExpressionOp
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
+ let hasCanonicalizer = 1;
let extraClassDeclaration = [{
bool hasSideEffects() {
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index bdf6d0985e6db..67b17d3c0d573 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -25,6 +25,10 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
// Populate functions
//===----------------------------------------------------------------------===//
+/// Populates `patterns` with the `emitc.expression` canonicalization patterns,
+/// i.e. the patterns that `ExpressionOp::getCanonicalizationPatterns`
+/// registers (currently: removal of operands passed more than once).
+void populateExpressionCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context);
+
/// Populates `patterns` with expression-related patterns.
void populateExpressionPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b0566dd10f490..3d49ec0d3a78a 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -412,6 +413,11 @@ LogicalResult DereferenceOp::verify() {
// ExpressionOp
//===----------------------------------------------------------------------===//
+/// The patterns live in the EmitC transforms library; this hook only forwards
+/// to `populateExpressionCanonicalizationPatterns`, which
+/// `populateExpressionPatterns` reuses as well.
+void ExpressionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ populateExpressionCanonicalizationPatterns(results, context);
+}
+
ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand> operands;
if (parser.parseOperandList(operands))
@@ -435,27 +441,45 @@ ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
"expected single return type");
result.addTypes(fnType.getResults());
Region *body = result.addRegion();
+  // Reusing the operand names as region argument names (name shadowing) is
+  // only possible when every operand value is distinct: a recurring value
+  // would define the same argument name twice. With recurring operands the
+  // body must declare its entry block arguments explicitly (`^bb0(...)`).
+  DenseSet<Value> uniqueOperands(result.operands.begin(),
+                                 result.operands.end());
+  bool enableNameShadowing = uniqueOperands.size() == result.operands.size();
SmallVector<OpAsmParser::Argument> argsInfo;
- for (auto [unresolvedOperand, operandType] :
- llvm::zip(operands, fnType.getInputs())) {
- OpAsmParser::Argument argInfo;
- argInfo.ssaName = unresolvedOperand;
- argInfo.type = operandType;
- argsInfo.push_back(argInfo);
+  if (enableNameShadowing) {
+    for (auto [unresolvedOperand, operandType] :
+         llvm::zip(operands, fnType.getInputs())) {
+      OpAsmParser::Argument argInfo;
+      argInfo.ssaName = unresolvedOperand;
+      argInfo.type = operandType;
+      argsInfo.push_back(argInfo);
+    }
}
- if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
+  SMLoc beforeRegionLoc = parser.getCurrentLocation();
+  if (parser.parseRegion(*body, argsInfo, enableNameShadowing))
return failure();
+  // Without shadowing, the explicit entry block must provide an argument per
+  // operand. An empty region (`{}`) is parsed without any block, so check for
+  // that before calling `front()`, which would access a non-existent block.
+  if (!enableNameShadowing &&
+      (body->empty() ||
+       body->front().getNumArguments() < result.operands.size()))
+    return parser.emitError(
+        beforeRegionLoc, "with recurring operands expected block arguments");
return success();
}
void emitc::ExpressionOp::print(OpAsmPrinter &p) {
p << ' ';
- p.printOperands(getDefs());
+  auto defs = getDefs();
+  p.printOperands(defs);
p << " : ";
p.printFunctionalType(getOperation());
- p.shadowRegionArgs(getRegion(), getDefs());
+  // Region arguments can only be printed under their operands' names when no
+  // value is passed twice; otherwise fall back to an explicit `^bb0(...)`
+  // entry block header so every argument keeps a distinct name.
+  bool hasRecurringDefs =
+      DenseSet<Value>(defs.begin(), defs.end()).size() != defs.size();
+  if (!hasRecurringDefs)
+    p.shadowRegionArgs(getRegion(), defs);
p << ' ';
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/hasRecurringDefs);
}
Operation *ExpressionOp::getRootOp() {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index f8469b8f0ed67..bfcb4a140ee9f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -149,8 +149,60 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
}
};
+/// Canonicalizes an `emitc.expression` that receives the same value as an
+/// operand more than once. The expression is recreated with each value passed
+/// exactly once (in first-occurrence order), and every entry block argument
+/// of the original body is remapped to the new block argument carrying the
+/// same value.
+struct RemoveRecurringExpressionOperands
+    : public OpRewritePattern<ExpressionOp> {
+  using OpRewritePattern<ExpressionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpressionOp expressionOp,
+                                PatternRewriter &rewriter) const override {
+    SetVector<Value> uniqueOperands;
+    // Maps each distinct operand to its position in `uniqueOperands`, i.e. to
+    // the index of the block argument that carries it in the new expression.
+    // This is NOT its index in the original operand list: for operands
+    // (%a, %a, %b), %b moves from index 2 to index 1.
+    DenseMap<Value, unsigned> uniqueIndexOf;
+
+    // Collect operands in first-occurrence order, skipping repeated copies.
+    for (Value operand : expressionOp.getDefs()) {
+      if (uniqueOperands.insert(operand))
+        uniqueIndexOf[operand] = uniqueOperands.size() - 1;
+    }
+
+    // If every operand is unique, there is nothing to canonicalize.
+    if (uniqueOperands.size() == expressionOp.getDefs().size())
+      return failure();
+
+    // Create a new expression with unique operands.
+    rewriter.setInsertionPointAfter(expressionOp);
+    auto uniqueExpression = emitc::ExpressionOp::create(
+        rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
+        uniqueOperands.getArrayRef(), expressionOp.getDoNotInline());
+    Block &uniqueExpressionBody = uniqueExpression.createBody();
+
+    // Map each original block argument to the new block argument that takes
+    // the same operand.
+    IRMapping mapper;
+    Block *expressionBody = expressionOp.getBody();
+    for (auto [operand, arg] :
+         llvm::zip(expressionOp.getDefs(), expressionBody->getArguments()))
+      mapper.map(arg,
+                 uniqueExpressionBody.getArgument(uniqueIndexOf.lookup(operand)));
+
+    // Clone the original body (including its terminator) into the new one.
+    rewriter.setInsertionPointToStart(&uniqueExpressionBody);
+    for (Operation &opToClone : *expressionBody)
+      rewriter.clone(opToClone, mapper);
+
+    // Complete the rewrite.
+    rewriter.replaceOp(expressionOp, uniqueExpression);
+
+    return success();
+  }
+};
+
} // namespace
+/// Registers the `emitc.expression` canonicalization patterns (currently
+/// only RemoveRecurringExpressionOperands) in `patterns`. The patterns are
+/// created in the caller-supplied `context` rather than silently ignoring it.
+void mlir::emitc::populateExpressionCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<RemoveRecurringExpressionOperands>(context);
+}
+
void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
+ // Also register the canonicalization patterns, so that recurring operands
+ // are removed alongside the other expression rewrites.
+ populateExpressionCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<FoldExpressionOp>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/EmitC/form-expressions.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir
index 7b6723989e260..58eac4381ccb7 100644
--- a/mlir/test/Dialect/EmitC/form-expressions.mlir
+++ b/mlir/test/Dialect/EmitC/form-expressions.mlir
@@ -20,6 +20,25 @@ func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
return %c : i1
}
+// %arg0 and %arg1 are each used by two of the ops that end up in the formed
+// expression; the expression must still take each of them only once.
+// CHECK-LABEL: func.func @expression_recurring_args(
+// CHECK-SAME: %[[ARG0:.*]]: i32,
+// CHECK-SAME: %[[ARG1:.*]]: i32) -> i1 {
+// CHECK: %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG1]], %[[ARG0]] : (i32, i32) -> i1 {
+// CHECK: %[[VAL_0:.*]] = mul %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CHECK: %[[VAL_1:.*]] = sub %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CHECK: %[[VAL_2:.*]] = cmp lt, %[[VAL_1]], %[[ARG1]] : (i32, i32) -> i1
+// CHECK: yield %[[VAL_2]] : i1
+// CHECK: }
+// CHECK: return %[[EXPRESSION_0]] : i1
+// CHECK: }
+
+func.func @expression_recurring_args(%arg0: i32, %arg1: i32) -> i1 {
+ %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.sub %a, %arg0 : (i32, i32) -> i32
+ %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1
+ return %c : i1
+}
+
// CHECK-LABEL: func.func @multiple_expressions(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) {
// CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : (i32, i32, i32) -> i32 {
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index d1601bed29ca9..0d878e90cdf0c 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -379,6 +379,19 @@ emitc.func @test_expression_op_outside_expression() {
// -----
+// A recurring operand (%arg0 below) disables name shadowing of the region
+// arguments, so the body must declare its entry block arguments explicitly;
+// omitting them is a parse error.
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+ // expected-error @+1 {{'emitc.expression' with recurring operands expected block arguments}}
+ %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+ %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.add %a, %arg0 : (i32, i32) -> i32
+ %c = emitc.mul %b, %a : (i32, i32) -> i32
+ emitc.yield %c : i32
+ }
+ return %r : i32
+}
+
+// -----
+
// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}}
emitc.func @multiple_results(%0: i32) -> (i32, i32) {
emitc.return %0 : i32
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index b2c8b843ec14b..2f7544b5db096 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | FileCheck -check-prefix=CANON %s
// CHECK: emitc.include <"test.h">
// CHECK: emitc.include "test.h"
@@ -213,6 +213,28 @@ func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
return %r : i32
}
+// Recurring operands round-trip through an explicit `^bb0(...)` entry block;
+// canonicalization drops the duplicate %arg0 operand.
+// CANON-LABEL: func.func @test_expression_recurring_operands(
+// CANON-SAME: %[[ARG0:.*]]: i32,
+// CANON-SAME: %[[ARG1:.*]]: i32) -> i32 {
+// CANON: %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 {
+// CANON: %[[VAL_0:.*]] = rem %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CANON: %[[VAL_1:.*]] = add %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CANON: %[[VAL_2:.*]] = mul %[[VAL_1]], %[[VAL_0]] : (i32, i32) -> i32
+// CANON: yield %[[VAL_2]] : i32
+// CANON: }
+// CANON: return %[[EXPRESSION_0]] : i32
+// CANON: }
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+ %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+ ^bb0(%x: i32, %y: i32, %z: i32):
+ %a = emitc.rem %x, %y : (i32, i32) -> i32
+ %b = emitc.add %a, %z : (i32, i32) -> i32
+ %c = emitc.mul %b, %a : (i32, i32) -> i32
+ emitc.yield %c : i32
+ }
+ return %r : i32
+}
+
func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
emitc.for %i0 = %arg0 to %arg1 step %arg2 {
%0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32
More information about the Mlir-commits
mailing list