[Mlir-commits] [mlir] [MLIR][DRR] Fix inconsistent operand and arg index usage (PR #139816)
Xiaomin Liu
llvmlistbot at llvm.org
Mon May 19 00:08:04 PDT 2025
https://github.com/xl4624 updated https://github.com/llvm/llvm-project/pull/139816
From 09316e0bd8cadfb3d3feef3ce932d1c38cd9043e Mon Sep 17 00:00:00 2001
From: Xiaomin Liu <xl4624 at nyu.edu>
Date: Tue, 13 May 2025 20:59:57 -0400
Subject: [PATCH 1/2] [MLIR][DRR] Fix inconsistent operand and arg index usage
---
mlir/test/lib/Dialect/Test/TestOps.td | 8 ++++++++
mlir/tools/mlir-tblgen/RewriterGen.cpp | 14 +++++++-------
2 files changed, 15 insertions(+), 7 deletions(-)
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 43a0bdaf86cf3..1d0849e479d37 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1872,6 +1872,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
let results = (outs I32:$output);
}
+def TestEitherOpC : TEST_Op<"either_op_c"> {
+ let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
+ let results = (outs I32:$output);
+}
+
def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
(TestEitherOpB $arg2, $x)>;
@@ -1883,6 +1888,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
$x),
(TestEitherOpB $arg2, $x)>;
+def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
+ (TestEitherOpB $arg1, $arg2)>;
+
def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
let arguments = (ins I32:$arg0);
let results = (outs I32:$output);
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 58abcc2bee895..75721c89793b5 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
if (isa<NamedTypeConstraint *>(opArg)) {
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
- emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
+ emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
/*operandMatcher=*/tree.getArgAsLeaf(i),
/*argName=*/tree.getArgName(i), opArgIdx,
/*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
- auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
+ NamedTypeConstraint operand = op.getOperand(operandIndex);
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Only need to verify if the matcher's type is different from the one
// of op definition.
Constraint constraint = operandMatcher.getAsConstraint();
- if (operand->constraint != constraint) {
- if (operand->isVariableLength()) {
+ if (operand.constraint != constraint) {
+ if (operand.isVariableLength()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
verifier, opName, self.str(),
formatv(
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
- operand - op.operand_begin(), op.getOperationName(),
+ operandIndex, op.getOperationName(),
escapeString(constraint.getSummary()))
.str());
}
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Capture the value
// `$_` is a special symbol to ignore op argument matching.
if (!argName.empty() && argName != "_") {
- auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
+ auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
variadicSubIndex);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
StringRef variadicTreeName = variadicArgTree.getSymbol();
if (!variadicTreeName.empty()) {
auto res =
- symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
+ symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
/*variadicSubIndex=*/std::nullopt);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
From b17a250a05cd278fe44230a08112aee3183ab94e Mon Sep 17 00:00:00 2001
From: Xiaomin Liu <xl4624 at nyu.edu>
Date: Mon, 19 May 2025 03:07:54 -0400
Subject: [PATCH 2/2] amend! [MLIR][DRR] Fix inconsistent operand and arg index usage
[MLIR][DRR] Fix inconsistent operand and arg index usage
---
mlir/test/mlir-tblgen/pattern.mlir | 31 ++++++++++++++++++++++++------
1 file changed, 25 insertions(+), 6 deletions(-)
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 90905280c0796..27598fb63a6c8 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -609,8 +609,8 @@ func.func @redundantTest(%arg0: i32) -> i32 {
// Test either directive
//===----------------------------------------------------------------------===//
-// CHECK: @either_dag_leaf_only
-func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+// CHECK-LABEL: @eitherDagLeafOnly
+func.func @eitherDagLeafOnly(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
@@ -618,8 +618,8 @@ func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
return
}
-// CHECK: @either_dag_leaf_dag_node
-func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+// CHECK-LABEL: @eitherDagLeafDagNode
+func.func @eitherDagLeafDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
%0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
%1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
@@ -628,8 +628,8 @@ func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> ()
return
}
-// CHECK: @either_dag_node_dag_node
-func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+// CHECK-LABEL: @eitherDagNodeDagNode
+func.func @eitherDagNodeDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
%0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
%1 = "test.either_op_b"(%arg1, %arg1) : (i16, i16) -> i32
// CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
@@ -639,10 +639,22 @@ func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> ()
return
}
+// CHECK-LABEL: @testEitherOpWithAttr
+func.func @testEitherOpWithAttr(%arg0 : i32, %arg1 : i16) -> () {
+ // CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32
+ %0 = "test.either_op_c"(%arg0, %arg1) {attr = 0 : i32} : (i32, i16) -> i32
+ // CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32
+ %1 = "test.either_op_c"(%arg1, %arg0) {attr = 0 : i32} : (i16, i32) -> i32
+ // CHECK: "test.either_op_c"(%arg0, %arg1) <{attr = 1 : i32}> : (i32, i16) -> i32
+ %2 = "test.either_op_c"(%arg0, %arg1) {attr = 1 : i32} : (i32, i16) -> i32
+ return
+}
+
//===----------------------------------------------------------------------===//
// Test that ops without type deduction can be created with type builders.
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @explicitReturnTypeTest
func.func @explicitReturnTypeTest(%arg0 : i64) -> i8 {
%0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8
// CHECK: "test.op_x"(%arg0) : (i64) -> i32
@@ -650,6 +662,7 @@ func.func @explicitReturnTypeTest(%arg0 : i64) -> i8 {
return %0 : i8
}
+// CHECK-LABEL: @returnTypeBuilderTest
func.func @returnTypeBuilderTest(%arg0 : i1) -> i8 {
%0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8
// CHECK: "test.op_x"(%arg0) : (i1) -> i1
@@ -657,6 +670,7 @@ func.func @returnTypeBuilderTest(%arg0 : i1) -> i8 {
return %0 : i8
}
+// CHECK-LABEL: @multipleReturnTypeBuildTest
func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
%0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1
// CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32)
@@ -666,6 +680,7 @@ func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
return %0 : i1
}
+// CHECK-LABEL: @copyValueType
func.func @copyValueType(%arg0 : i8) -> i32 {
%0 = "test.source_op"(%arg0) {tag = 44 : i32} : (i8) -> i32
// CHECK: "test.op_x"(%arg0) : (i8) -> i8
@@ -673,6 +688,7 @@ func.func @copyValueType(%arg0 : i8) -> i32 {
return %0 : i32
}
+// CHECK-LABEL: @multipleReturnTypeDifferent
func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
%0 = "test.source_op"(%arg0) {tag = 55 : i32} : (i1) -> i64
// CHECK: "test.one_to_two"(%arg0) : (i1) -> (i1, i64)
@@ -684,6 +700,7 @@ func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
// Test that multiple trailing directives can be mixed in patterns.
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @returnTypeAndLocation
func.func @returnTypeAndLocation(%arg0 : i32) -> i1 {
%0 = "test.source_op"(%arg0) {tag = 66 : i32} : (i32) -> i1
// CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc1")
@@ -696,6 +713,7 @@ func.func @returnTypeAndLocation(%arg0 : i32) -> i1 {
// Test that patterns can create ConstantStrAttr
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @testConstantStrAttr
func.func @testConstantStrAttr() -> () {
// CHECK: test.has_str_value {value = "foo"}
test.no_str_value {value = "bar"}
@@ -706,6 +724,7 @@ func.func @testConstantStrAttr() -> () {
// Test that patterns with variadics propagate sizes
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @testVariadic
func.func @testVariadic(%arg_0: i32, %arg_1: i32, %brg: i64,
%crg_0: f32, %crg_1: f32, %crg_2: f32, %crg_3: f32) -> () {
// CHECK: "test.variadic_rewrite_dst_op"(%arg2, %arg3, %arg4, %arg5, %arg6, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 4, 2>}> : (i64, f32, f32, f32, f32, i32, i32) -> ()
More information about the Mlir-commits mailing list