[Mlir-commits] [mlir] bae8e1f - [MLIR][DRR] Fix inconsistent operand and arg index usage (#139816)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 23 03:03:26 PDT 2025


Author: Xiaomin Liu
Date: 2025-05-23T12:03:23+02:00
New Revision: bae8e1f99e0b245ec31912e29b4c80e823f635c6

URL: https://github.com/llvm/llvm-project/commit/bae8e1f99e0b245ec31912e29b4c80e823f635c6
DIFF: https://github.com/llvm/llvm-project/commit/bae8e1f99e0b245ec31912e29b4c80e823f635c6.diff

LOG: [MLIR][DRR] Fix inconsistent operand and arg index usage (#139816)

Background issue: #139813

In
[emitEitherOperandMatch()](https://github.com/llvm/llvm-project/blob/e62fc14a5d214f801758b35bdcad0c8efc65e8b8/mlir/tools/mlir-tblgen/RewriterGen.cpp#L774)
we check if `op.getArg(argIndex)` is a `NamedTypeConstraint`:

```cpp
} else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
      emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                       operandIndex,
                       /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
                       /*argName=*/eitherArgTree.getArgName(i), argIndex,
                       /*variadicSubIndex=*/std::nullopt);
      ++operandIndex;
}
```

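but `emitOperandMatch()` then casts `op.getArg(operandIndex)`, i.e. it
indexes the argument list with an operand index. This is incorrect if the
operation has attributes or other non-operand arguments before its operands,
because the operand index then no longer equals the argument index.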

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index cdc0f393b4761..3161d2d37e090 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1873,6 +1873,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
   let results = (outs I32:$output);
 }
 
+def TestEitherOpC : TEST_Op<"either_op_c"> {
+  let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
+  let results = (outs I32:$output);
+}
+
 def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
           (TestEitherOpB $arg2, $x)>;
 
@@ -1884,6 +1889,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
                           $x),
           (TestEitherOpB $arg2, $x)>;
 
+def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
+          (TestEitherOpB $arg1, $arg2)>;
+
 def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
   let arguments = (ins I32:$arg0);
   let results = (outs I32:$output);

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 90905280c0796..27598fb63a6c8 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -609,8 +609,8 @@ func.func @redundantTest(%arg0: i32) -> i32 {
 // Test either directive
 //===----------------------------------------------------------------------===//
 
-// CHECK: @either_dag_leaf_only
-func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+// CHECK-LABEL: @eitherDagLeafOnly
+func.func @eitherDagLeafOnly(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
   // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
   %0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
   // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
@@ -618,8 +618,8 @@ func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
   return
 }
 
-// CHECK: @either_dag_leaf_dag_node
-func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+// CHECK-LABEL: @eitherDagLeafDagNode
+func.func @eitherDagLeafDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
   %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
   // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
   %1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
@@ -628,8 +628,8 @@ func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> ()
   return
 }
 
-// CHECK: @either_dag_node_dag_node
-func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
+// CHECK-LABEL: @eitherDagNodeDagNode
+func.func @eitherDagNodeDagNode(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
   %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32
   %1 = "test.either_op_b"(%arg1, %arg1) : (i16, i16) -> i32
   // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32
@@ -639,10 +639,22 @@ func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> ()
   return
 }
 
+// CHECK-LABEL: @testEitherOpWithAttr
+func.func @testEitherOpWithAttr(%arg0 : i32, %arg1 : i16) -> () {
+  // CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32
+  %0 = "test.either_op_c"(%arg0, %arg1) {attr = 0 : i32} : (i32, i16) -> i32
+  // CHECK: "test.either_op_b"(%arg1, %arg0) : (i16, i32) -> i32
+  %1 = "test.either_op_c"(%arg1, %arg0) {attr = 0 : i32} : (i16, i32) -> i32
+  // CHECK: "test.either_op_c"(%arg0, %arg1) <{attr = 1 : i32}> : (i32, i16) -> i32
+  %2 = "test.either_op_c"(%arg0, %arg1) {attr = 1 : i32} : (i32, i16) -> i32
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test that ops without type deduction can be created with type builders.
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @explicitReturnTypeTest
 func.func @explicitReturnTypeTest(%arg0 : i64) -> i8 {
   %0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8
   // CHECK: "test.op_x"(%arg0) : (i64) -> i32
@@ -650,6 +662,7 @@ func.func @explicitReturnTypeTest(%arg0 : i64) -> i8 {
   return %0 : i8
 }
 
+// CHECK-LABEL: @returnTypeBuilderTest
 func.func @returnTypeBuilderTest(%arg0 : i1) -> i8 {
   %0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8
   // CHECK: "test.op_x"(%arg0) : (i1) -> i1
@@ -657,6 +670,7 @@ func.func @returnTypeBuilderTest(%arg0 : i1) -> i8 {
   return %0 : i8
 }
 
+// CHECK-LABEL: @multipleReturnTypeBuildTest
 func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
   %0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1
   // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32)
@@ -666,6 +680,7 @@ func.func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 {
   return %0 : i1
 }
 
+// CHECK-LABEL: @copyValueType
 func.func @copyValueType(%arg0 : i8) -> i32 {
   %0 = "test.source_op"(%arg0) {tag = 44 : i32} : (i8) -> i32
   // CHECK: "test.op_x"(%arg0) : (i8) -> i8
@@ -673,6 +688,7 @@ func.func @copyValueType(%arg0 : i8) -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: @multipleReturnTypeDifferent
 func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
   %0 = "test.source_op"(%arg0) {tag = 55 : i32} : (i1) -> i64
   // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i1, i64)
@@ -684,6 +700,7 @@ func.func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 {
 // Test that multiple trailing directives can be mixed in patterns.
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @returnTypeAndLocation
 func.func @returnTypeAndLocation(%arg0 : i32) -> i1 {
   %0 = "test.source_op"(%arg0) {tag = 66 : i32} : (i32) -> i1
   // CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc1")
@@ -696,6 +713,7 @@ func.func @returnTypeAndLocation(%arg0 : i32) -> i1 {
 // Test that patterns can create ConstantStrAttr
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @testConstantStrAttr
 func.func @testConstantStrAttr() -> () {
   // CHECK: test.has_str_value {value = "foo"}
   test.no_str_value {value = "bar"}
@@ -706,6 +724,7 @@ func.func @testConstantStrAttr() -> () {
 // Test that patterns with variadics propagate sizes
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @testVariadic
 func.func @testVariadic(%arg_0: i32, %arg_1: i32, %brg: i64,
     %crg_0: f32, %crg_1: f32, %crg_2: f32, %crg_3: f32) -> () {
   // CHECK: "test.variadic_rewrite_dst_op"(%arg2, %arg3, %arg4, %arg5, %arg6, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 4, 2>}> : (i64, f32, f32, f32, f32, i32, i32) -> ()

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 58abcc2bee895..75721c89793b5 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     if (isa<NamedTypeConstraint *>(opArg)) {
       auto operandName =
           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
-      emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
+      emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
                        /*operandMatcher=*/tree.getArgAsLeaf(i),
                        /*argName=*/tree.getArgName(i), opArgIdx,
                        /*variadicSubIndex=*/std::nullopt);
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
                                       int argIndex,
                                       std::optional<int> variadicSubIndex) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
+  NamedTypeConstraint operand = op.getOperand(operandIndex);
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
@@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
     // Only need to verify if the matcher's type is different from the one
     // of op definition.
     Constraint constraint = operandMatcher.getAsConstraint();
-    if (operand->constraint != constraint) {
-      if (operand->isVariableLength()) {
+    if (operand.constraint != constraint) {
+      if (operand.isVariableLength()) {
         auto error = formatv(
             "further constrain op {0}'s variadic operand #{1} unsupported now",
             op.getOperationName(), argIndex);
@@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
           verifier, opName, self.str(),
           formatv(
               "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
-              operand - op.operand_begin(), op.getOperationName(),
+              operandIndex, op.getOperationName(),
               escapeString(constraint.getSummary()))
               .str());
     }
@@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
   // Capture the value
   // `$_` is a special symbol to ignore op argument matching.
   if (!argName.empty() && argName != "_") {
-    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
+    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
                                              variadicSubIndex);
     if (res == symbolInfoMap.end())
       PrintFatalError(loc, formatv("symbol not found: {0}", argName));
@@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
   StringRef variadicTreeName = variadicArgTree.getSymbol();
   if (!variadicTreeName.empty()) {
     auto res =
-        symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
+        symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
                                       /*variadicSubIndex=*/std::nullopt);
     if (res == symbolInfoMap.end())
       PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));


        


More information about the Mlir-commits mailing list