[Mlir-commits] [mlir] db9df43 - [mlir-tblgen] Avoid ODS verifier duplication
Chia-hung Duan
llvmlistbot at llvm.org
Sun Jul 4 19:13:44 PDT 2021
Author: Chia-hung Duan
Date: 2021-07-05T10:09:41+08:00
New Revision: db9df434fae905fba6e02becdb266ae5c143540c
URL: https://github.com/llvm/llvm-project/commit/db9df434fae905fba6e02becdb266ae5c143540c
DIFF: https://github.com/llvm/llvm-project/commit/db9df434fae905fba6e02becdb266ae5c143540c.diff
LOG: [mlir-tblgen] Avoid ODS verifier duplication
Different constraints may share the same predicate; in that case, we
would generate duplicate ODS verification functions.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D104369
Added:
Modified:
mlir/include/mlir/TableGen/Predicate.h
mlir/test/mlir-tblgen/predicate.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h
index 7caea7c561f8b..e23c751e52d09 100644
--- a/mlir/include/mlir/TableGen/Predicate.h
+++ b/mlir/include/mlir/TableGen/Predicate.h
@@ -14,6 +14,7 @@
#define MLIR_TABLEGEN_PREDICATE_H_
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Hashing.h"
#include <string>
#include <vector>
@@ -59,6 +60,8 @@ class Pred {
ArrayRef<llvm::SMLoc> getLoc() const;
protected:
+ friend llvm::DenseMapInfo<Pred>;
+
// The TableGen definition of this predicate.
const llvm::Record *def;
};
@@ -116,4 +119,18 @@ class ConcatPred : public CombinedPred {
} // end namespace tblgen
} // end namespace mlir
+namespace llvm {
+/// DenseMap key traits for `mlir::tblgen::Pred`. A `Pred` is keyed by the
+/// TableGen record pointer it wraps (`def`); hashing and equality both use
+/// only that pointer so they always agree.
+template <>
+struct DenseMapInfo<mlir::tblgen::Pred> {
+  /// DenseMap requires the empty and tombstone keys to be distinct from each
+  /// other and from every real key. A default-constructed `Pred()` cannot
+  /// serve as both (and may itself be a key a caller inserts), so reuse the
+  /// reserved pointer sentinels for `const llvm::Record *` instead.
+  static mlir::tblgen::Pred getEmptyKey() {
+    return makeKey(DenseMapInfo<const llvm::Record *>::getEmptyKey());
+  }
+  static mlir::tblgen::Pred getTombstoneKey() {
+    return makeKey(DenseMapInfo<const llvm::Record *>::getTombstoneKey());
+  }
+  static unsigned getHashValue(mlir::tblgen::Pred pred) {
+    return llvm::hash_value(pred.def);
+  }
+  /// Compare the raw record pointers: this matches the hash exactly and never
+  /// looks through a sentinel key.
+  static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) {
+    return lhs.def == rhs.def;
+  }
+
+private:
+  /// Wraps an arbitrary record pointer (including the sentinel values) in a
+  /// `Pred` by setting `def` directly, without relying on how any `Pred`
+  /// constructor treats its argument. Requires this struct to be a friend of
+  /// `Pred`.
+  static mlir::tblgen::Pred makeKey(const llvm::Record *def) {
+    mlir::tblgen::Pred pred;
+    pred.def = def;
+    return pred;
+  }
+};
+} // end namespace llvm
+
#endif // MLIR_TABLEGEN_PREDICATE_H_
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 386c61319b79e..f8c1b1d90b3a3 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -13,19 +13,24 @@ def I32OrF32 : Type<CPred<"$_self.isInteger(32) || $_self.isF32()">,
def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
let arguments = (ins I32OrF32:$x);
+ let results = (outs Variadic<I32OrF32>:$y);
}
// CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!((type.isInteger(32) || type.isF32()))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+
+// Check that no second verifier with the same predicate is generated.
+// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
+// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
+// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
// CHECK-LABEL: OpA::verify
// CHECK: auto valueGroup0 = getODSOperands(0);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 664fc2de788ea..2bc9ea465f5e0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -216,19 +216,50 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
typeConstraints.insert(result.constraint.getAsOpaquePointer());
}
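+ // Record the mapping from predicate to constraint. If two constraints have the
+ // same predicate and constraint summary, they can share the same verification
+ // function. Note that DenseMap keys are unique, so only the first constraint
+ // emitted for each predicate is recorded here.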
+ llvm::DenseMap<Pred, const void *> predToConstraint;
FmtContext fctx;
for (auto it : llvm::enumerate(typeConstraints)) {
+ std::string name;
+ Constraint constraint = Constraint::getFromOpaquePointer(it.value());
+ Pred pred = constraint.getPredicate();
+ auto iter = predToConstraint.find(pred);
+ if (iter != predToConstraint.end()) {
+ do {
+ Constraint built = Constraint::getFromOpaquePointer(iter->second);
+ // Different constraints may share the same predicate, for example
+ // ConstraintA and Variadic<ConstraintA> (Variadic<> doesn't introduce
+ // a new predicate). In this case, they can share the same verification
+ // function only if they also have the same summary; otherwise we may
+ // report the wrong message when verification fails.
+ if (constraint.getSummary() == built.getSummary()) {
+ name = getTypeConstraintFn(built).str();
+ break;
+ }
+ ++iter;
+ } while (iter != predToConstraint.end() && iter->first == pred);
+ }
+
+ if (!name.empty()) {
+ localTypeConstraints.try_emplace(it.value(), name);
+ continue;
+ }
+
// Generate an obscure and unique name for this type constraint.
- std::string name = (Twine("__mlir_ods_local_type_constraint_") +
- uniqueOutputLabel + Twine(it.index()))
- .str();
+ name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
+ Twine(it.index()))
+ .str();
+ predToConstraint.insert(
+ std::make_pair(constraint.getPredicate(), it.value()));
localTypeConstraints.try_emplace(it.value(), name);
// Only generate the methods if we are generating definitions.
if (emitDecl)
continue;
- Constraint constraint = Constraint::getFromOpaquePointer(it.value());
os << "static ::mlir::LogicalResult " << name
<< "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
"valueKind, unsigned valueGroupStartIndex) {\n";
More information about the Mlir-commits
mailing list