[Mlir-commits] [mlir] [mlir][ods] Verify access to operands in inferReturnTypes (PR #112574)
Nikolay Panchenko
llvmlistbot at llvm.org
Wed Oct 16 16:48:12 PDT 2024
https://github.com/npanchen updated https://github.com/llvm/llvm-project/pull/112574
>From e686f5aa3e1119d53ec77abef401be1a0364e739 Mon Sep 17 00:00:00 2001
From: Kolya Panchenko <npanchen at modular.com>
Date: Wed, 16 Oct 2024 11:51:27 -0400
Subject: [PATCH 1/3] [mlir][ods] Verify access to operands in
inferReturnTypes
The patch adds graceful handling of an incorrectly constructed MLIR
operation with fewer operands than expected.
---
mlir/test/mlir-tblgen/op-result.td | 7 +++++++
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 16 ++++++++++++++++
2 files changed, 23 insertions(+)
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 0ca570cf8cafba..51f8b0671a328d 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -130,6 +130,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
// CHECK-NOT: }
+// CHECK: if (operands.size() <= 0)
+// CHECK-NEXT: return ::mlir::failure();
// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
@@ -141,6 +143,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
// CHECK-NOT: }
+// CHECK: if (operands.size() <= 2)
+// CHECK-NEXT: return ::mlir::failure();
+// CHECK-NOT: if (operands.size() <= 0)
// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
@@ -166,6 +171,8 @@ def OpL4 : NS_Op<"two_inference_edges", [
}
// CHECK-LABEL: LogicalResult OpL4::inferReturnTypes
+// CHECK: if (operands.size() <= 0)
+// CHECK-NEXT: return ::mlir::failure();
// CHECK: odsInferredType0 = fromInput(operands[0].getType())
// CHECK: odsInferredType1 = infer0(odsInferredType0)
// CHECK: odsInferredType2 = infer1(odsInferredType1)
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ce2b6ed94c3949..c55a00cf08a7cc 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3583,6 +3583,22 @@ void OpEmitter::genTypeInterfaceMethods() {
fctx.addSubst("_ctxt", "context");
body << " ::mlir::Builder odsBuilder(context);\n";
+ // Preprocessing stage to verify all accesses to operands are valid.
+ int maxAccessedIndex = -1;
+ for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+ const InferredResultType &infer = op.getInferredResultType(i);
+ if (!infer.isArg())
+ continue;
+ auto arg = op.getArgToOperandOrAttribute(infer.getIndex());
+ if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
+ maxAccessedIndex =
+ std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
+ }
+ if (maxAccessedIndex != -1) {
+ body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
+ body << " return ::mlir::failure();\n";
+ }
+
// Process the type inference graph in topological order, starting from types
// that are always fully-inferred: operands and results with constructible
// types. The type inference graph here will always be a DAG, so this gives
>From f6232ef4771daf5fb26a1efe170e58502af709bf Mon Sep 17 00:00:00 2001
From: Kolya Panchenko <npanchen at modular.com>
Date: Wed, 16 Oct 2024 18:28:35 -0400
Subject: [PATCH 2/3] [NFC] added {} for 2-line if-statement
---
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index c55a00cf08a7cc..d1767f09086711 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3590,9 +3590,10 @@ void OpEmitter::genTypeInterfaceMethods() {
if (!infer.isArg())
continue;
auto arg = op.getArgToOperandOrAttribute(infer.getIndex());
- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
+ if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
maxAccessedIndex =
std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
+ }
}
if (maxAccessedIndex != -1) {
body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
>From fdfb7647ee2bb52db18b8ac31d5bb42d4c1e3b60 Mon Sep 17 00:00:00 2001
From: Kolya Panchenko <npanchen at modular.com>
Date: Wed, 16 Oct 2024 19:47:28 -0400
Subject: [PATCH 3/3] replaced auto with Operator::OperandOrAttribute
---
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index d1767f09086711..d466c4d47ee6f1 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3589,7 +3589,8 @@ void OpEmitter::genTypeInterfaceMethods() {
const InferredResultType &infer = op.getInferredResultType(i);
if (!infer.isArg())
continue;
- auto arg = op.getArgToOperandOrAttribute(infer.getIndex());
+ Operator::OperandOrAttribute arg =
+ op.getArgToOperandOrAttribute(infer.getIndex());
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
maxAccessedIndex =
std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
@@ -3616,7 +3617,8 @@ void OpEmitter::genTypeInterfaceMethods() {
if (infer.isArg()) {
// If this is an operand, just index into operand list to access the
// type.
- auto arg = op.getArgToOperandOrAttribute(infer.getIndex());
+ Operator::OperandOrAttribute arg =
+ op.getArgToOperandOrAttribute(infer.getIndex());
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
"].getType()")
More information about the Mlir-commits
mailing list