[Mlir-commits] [mlir] 076d3e2 - [mlir][ods] Verify access to operands in inferReturnTypes (#112574)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 22 21:20:56 PDT 2024


Author: Nikolay Panchenko
Date: 2024-10-23T00:20:53-04:00
New Revision: 076d3e232681d50aca96eaeabebd17e68ff6f7e7

URL: https://github.com/llvm/llvm-project/commit/076d3e232681d50aca96eaeabebd17e68ff6f7e7
DIFF: https://github.com/llvm/llvm-project/commit/076d3e232681d50aca96eaeabebd17e68ff6f7e7.diff

LOG: [mlir][ods] Verify access to operands in inferReturnTypes (#112574)

The patch adds graceful handling of incorrectly constructed MLIR
operations with fewer operands than expected.

Added: 
    

Modified: 
    mlir/test/mlir-tblgen/op-result.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 0ca570cf8cafba..51f8b0671a328d 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -130,6 +130,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
 
 // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
 // CHECK-NOT: }
+// CHECK: if (operands.size() <= 0)
+// CHECK-NEXT: return ::mlir::failure();
 // CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
 // CHECK: inferredReturnTypes[0] = odsInferredType0;
 
@@ -141,6 +143,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
 
 // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
 // CHECK-NOT: }
+// CHECK: if (operands.size() <= 2)
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK-NOT: if (operands.size() <= 0)
 // CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
 // CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
 // CHECK: inferredReturnTypes[0] = odsInferredType0;
@@ -166,6 +171,8 @@ def OpL4 : NS_Op<"two_inference_edges", [
 }
 
 // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes
+// CHECK: if (operands.size() <= 0)
+// CHECK-NEXT: return ::mlir::failure();
 // CHECK: odsInferredType0 = fromInput(operands[0].getType())
 // CHECK: odsInferredType1 = infer0(odsInferredType0)
 // CHECK: odsInferredType2 = infer1(odsInferredType1)

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index dea6fb209863ce..9badb7aa163a60 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3584,6 +3584,24 @@ void OpEmitter::genTypeInterfaceMethods() {
   fctx.addSubst("_ctxt", "context");
   body << "  ::mlir::Builder odsBuilder(context);\n";
 
+  // Preprocessing stage to verify all accesses to operands are valid.
+  int maxAccessedIndex = -1;
+  for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+    const InferredResultType &infer = op.getInferredResultType(i);
+    if (!infer.isArg())
+      continue;
+    Operator::OperandOrAttribute arg =
+        op.getArgToOperandOrAttribute(infer.getIndex());
+    if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
+      maxAccessedIndex =
+          std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
+    }
+  }
+  if (maxAccessedIndex != -1) {
+    body << "  if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
+    body << "    return ::mlir::failure();\n";
+  }
+
   // Process the type inference graph in topological order, starting from types
   // that are always fully-inferred: operands and results with constructible
   // types. The type inference graph here will always be a DAG, so this gives
@@ -3600,7 +3618,8 @@ void OpEmitter::genTypeInterfaceMethods() {
       if (infer.isArg()) {
         // If this is an operand, just index into operand list to access the
         // type.
-        auto arg = op.getArgToOperandOrAttribute(infer.getIndex());
+        Operator::OperandOrAttribute arg =
+            op.getArgToOperandOrAttribute(infer.getIndex());
         if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
           typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
                      "].getType()")


        


More information about the Mlir-commits mailing list