[Mlir-commits] [mlir] eb64fb6 - [mlir][emitc] Fix recurring operands in expression (#178382)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 28 05:00:28 PST 2026
Author: Gil Rapaport
Date: 2026-01-28T15:00:23+02:00
New Revision: eb64fb6d1208d7a97039f4b042b9b55afa7da0e6
URL: https://github.com/llvm/llvm-project/commit/eb64fb6d1208d7a97039f4b042b9b55afa7da0e6
DIFF: https://github.com/llvm/llvm-project/commit/eb64fb6d1208d7a97039f4b042b9b55afa7da0e6.diff
LOG: [mlir][emitc] Fix recurring operands in expression (#178382)
Relanding #175535, which was reverted because it broke a buildbot.
The new canonicalization pattern has been moved into the dialect code.
Added:
Modified:
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
mlir/test/Dialect/EmitC/form-expressions.mlir
mlir/test/Dialect/EmitC/invalid_ops.mlir
mlir/test/Dialect/EmitC/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index caed3233f62e9..7d9bb8907eb8b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -566,6 +566,7 @@ def EmitC_ExpressionOp
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
+ let hasCanonicalizer = 1;
let extraClassDeclaration = [{
bool hasSideEffects() {
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 1d4b748b2a88a..64a475c2e42ab 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -8,10 +8,12 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Support/LLVM.h"
@@ -412,6 +414,61 @@ LogicalResult DereferenceOp::verify() {
// ExpressionOp
//===----------------------------------------------------------------------===//
+namespace {
+
+/// Canonicalizes an `emitc.expression` whose operand list contains the same
+/// value more than once. The expression is rebuilt so that it takes every
+/// distinct value exactly once (in order of first occurrence), and each block
+/// argument of the original body is remapped to the new block argument that
+/// receives the same value.
+struct RemoveRecurringExpressionOperands
+    : public OpRewritePattern<ExpressionOp> {
+  using OpRewritePattern<ExpressionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpressionOp expressionOp,
+                                PatternRewriter &rewriter) const override {
+    // Collect the distinct operands in order of first occurrence and record,
+    // for each one, its position in that deduplicated list. That position is
+    // the index of the matching argument of the new expression body. (The
+    // position in the *original* operand list must not be used: for operands
+    // (%a, %a, %b) the new body has two arguments and %b is argument #1.)
+    SetVector<Value> uniqueOperands;
+    DenseMap<Value, unsigned> uniqueIndexOf;
+    for (Value operand : expressionOp.getDefs()) {
+      // SetVector::insert returns false for values that are already present.
+      if (uniqueOperands.insert(operand))
+        uniqueIndexOf[operand] = uniqueOperands.size() - 1;
+    }
+
+    // If every operand is unique, there is nothing to canonicalize.
+    if (uniqueOperands.size() == expressionOp.getDefs().size())
+      return failure();
+
+    // Create a new expression with unique operands right after the old one.
+    rewriter.setInsertionPointAfter(expressionOp);
+    auto uniqueExpression = emitc::ExpressionOp::create(
+        rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
+        uniqueOperands.getArrayRef(), expressionOp.getDoNotInline());
+    Block &uniqueExpressionBody = uniqueExpression.createBody();
+
+    // Map each original block argument to the unique block argument taking
+    // the same operand.
+    IRMapping mapper;
+    Block *expressionBody = expressionOp.getBody();
+    for (auto [operand, arg] :
+         llvm::zip(expressionOp.getDefs(), expressionBody->getArguments()))
+      mapper.map(arg, uniqueExpressionBody.getArgument(
+                          uniqueIndexOf.lookup(operand)));
+
+    // Clone every operation of the original body into the new body.
+    rewriter.setInsertionPointToStart(&uniqueExpressionBody);
+    for (Operation &opToClone : *expressionBody)
+      rewriter.clone(opToClone, mapper);
+
+    // Complete the rewrite: redirect uses to the new expression.
+    rewriter.replaceOp(expressionOp, uniqueExpression);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the canonicalization patterns of `emitc.expression`.
+void ExpressionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<RemoveRecurringExpressionOperands>(context);
+}
+
ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand> operands;
if (parser.parseOperandList(operands))
@@ -435,27 +492,45 @@ ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
"expected single return type");
result.addTypes(fnType.getResults());
Region *body = result.addRegion();
+  // The region may refer to the operands by their outer SSA names (name
+  // shadowing) only when all operands are distinct. With recurring operands
+  // that naming would be ambiguous, so the region must then declare its entry
+  // block arguments explicitly.
+  DenseSet<Value> uniqueOperands(result.operands.begin(),
+                                 result.operands.end());
+  bool enableNameShadowing = uniqueOperands.size() == result.operands.size();
SmallVector<OpAsmParser::Argument> argsInfo;
- for (auto [unresolvedOperand, operandType] :
- llvm::zip(operands, fnType.getInputs())) {
- OpAsmParser::Argument argInfo;
- argInfo.ssaName = unresolvedOperand;
- argInfo.type = operandType;
- argsInfo.push_back(argInfo);
+  if (enableNameShadowing) {
+    for (auto [unresolvedOperand, operandType] :
+         llvm::zip(operands, fnType.getInputs())) {
+      OpAsmParser::Argument argInfo;
+      argInfo.ssaName = unresolvedOperand;
+      argInfo.type = operandType;
+      argsInfo.push_back(argInfo);
+    }
}
- if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
+  SMLoc beforeRegionLoc = parser.getCurrentLocation();
+  if (parser.parseRegion(*body, argsInfo, enableNameShadowing))
return failure();
+  // Without name shadowing the entry block must declare an argument for every
+  // operand. No entry arguments were passed to parseRegion in that case, so an
+  // empty `{}` region yields no block at all; check for that before front().
+  if (!enableNameShadowing &&
+      (body->empty() ||
+       body->front().getNumArguments() < result.operands.size()))
+    return parser.emitError(
+        beforeRegionLoc, "with recurring operands expected block arguments");
return success();
}
void emitc::ExpressionOp::print(OpAsmPrinter &p) {
p << ' ';
- p.printOperands(getDefs());
+  auto defs = getDefs();
+  p.printOperands(defs);
p << " : ";
p.printFunctionalType(getOperation());
- p.shadowRegionArgs(getRegion(), getDefs());
+  // Shadowing the region arguments with the operand names only round-trips
+  // when every operand is distinct. With recurring operands, print the entry
+  // block arguments explicitly instead (mirroring what the parser expects).
+  bool allOperandsDistinct =
+      DenseSet<Value>(defs.begin(), defs.end()).size() == defs.size();
+  if (allOperandsDistinct)
+    p.shadowRegionArgs(getRegion(), defs);
p << ' ';
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/!allOperandsDistinct);
}
Operation *ExpressionOp::getRootOp() {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index f8469b8f0ed67..4bcdb285d6a16 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -152,5 +152,6 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
} // namespace
void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
+ // Also register the ExpressionOp canonicalization patterns, which remove
+ // recurring expression operands.
+ ExpressionOp::getCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<FoldExpressionOp>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/EmitC/form-expressions.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir
index 7b6723989e260..58eac4381ccb7 100644
--- a/mlir/test/Dialect/EmitC/form-expressions.mlir
+++ b/mlir/test/Dialect/EmitC/form-expressions.mlir
@@ -20,6 +20,25 @@ func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
return %c : i1
}
+// %arg0 and %arg1 are each used twice by the ops below; the formed
+// expression must take each of them as an operand only once.
+// CHECK-LABEL: func.func @expression_recurring_args(
+// CHECK-SAME: %[[ARG0:.*]]: i32,
+// CHECK-SAME: %[[ARG1:.*]]: i32) -> i1 {
+// CHECK: %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG1]], %[[ARG0]] : (i32, i32) -> i1 {
+// CHECK: %[[VAL_0:.*]] = mul %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CHECK: %[[VAL_1:.*]] = sub %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CHECK: %[[VAL_2:.*]] = cmp lt, %[[VAL_1]], %[[ARG1]] : (i32, i32) -> i1
+// CHECK: yield %[[VAL_2]] : i1
+// CHECK: }
+// CHECK: return %[[EXPRESSION_0]] : i1
+// CHECK: }
+
+func.func @expression_recurring_args(%arg0: i32, %arg1: i32) -> i1 {
+ %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.sub %a, %arg0 : (i32, i32) -> i32
+ %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1
+ return %c : i1
+}
+
// CHECK-LABEL: func.func @multiple_expressions(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) {
// CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : (i32, i32, i32) -> i32 {
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index d1601bed29ca9..0d878e90cdf0c 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -379,6 +379,19 @@ emitc.func @test_expression_op_outside_expression() {
// -----
+// %arg0 appears twice in the operand list, which disables name shadowing in
+// the region; the entry block arguments must then be spelled out, and
+// omitting them is rejected by the parser.
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+ // expected-error @+1 {{'emitc.expression' with recurring operands expected block arguments}}
+ %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+ %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.add %a, %arg0 : (i32, i32) -> i32
+ %c = emitc.mul %b, %a : (i32, i32) -> i32
+ emitc.yield %c : i32
+ }
+ return %r : i32
+}
+
+// -----
+
// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}}
emitc.func @multiple_results(%0: i32) -> (i32, i32) {
emitc.return %0 : i32
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index b2c8b843ec14b..2f7544b5db096 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | FileCheck -check-prefix=CANON %s
// CHECK: emitc.include <"test.h">
// CHECK: emitc.include "test.h"
@@ -213,6 +213,28 @@ func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
return %r : i32
}
+// Block arguments %x and %z are both bound to %arg0. Canonicalization drops
+// the recurring operand, so the resulting expression takes %arg0 and %arg1
+// once each and former uses of %z refer to %arg0.
+// CANON-LABEL: func.func @test_expression_recurring_operands(
+// CANON-SAME: %[[ARG0:.*]]: i32,
+// CANON-SAME: %[[ARG1:.*]]: i32) -> i32 {
+// CANON: %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 {
+// CANON: %[[VAL_0:.*]] = rem %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CANON: %[[VAL_1:.*]] = add %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CANON: %[[VAL_2:.*]] = mul %[[VAL_1]], %[[VAL_0]] : (i32, i32) -> i32
+// CANON: yield %[[VAL_2]] : i32
+// CANON: }
+// CANON: return %[[EXPRESSION_0]] : i32
+// CANON: }
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+ %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+ ^bb0(%x: i32, %y: i32, %z: i32):
+ %a = emitc.rem %x, %y : (i32, i32) -> i32
+ %b = emitc.add %a, %z : (i32, i32) -> i32
+ %c = emitc.mul %b, %a : (i32, i32) -> i32
+ emitc.yield %c : i32
+ }
+ return %r : i32
+}
+
func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
emitc.for %i0 = %arg0 to %arg1 step %arg2 {
%0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32
More information about the Mlir-commits mailing list