[Mlir-commits] [mlir] db9df43 - [mlir-tblgen] Avoid ODS verifier duplication

Chia-hung Duan llvmlistbot at llvm.org
Sun Jul 4 19:13:44 PDT 2021


Author: Chia-hung Duan
Date: 2021-07-05T10:09:41+08:00
New Revision: db9df434fae905fba6e02becdb266ae5c143540c

URL: https://github.com/llvm/llvm-project/commit/db9df434fae905fba6e02becdb266ae5c143540c
DIFF: https://github.com/llvm/llvm-project/commit/db9df434fae905fba6e02becdb266ae5c143540c.diff

LOG: [mlir-tblgen] Avoid ODS verifier duplication

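Different constraints may share the same predicate; in that case, we
would generate duplicate ODS verification functions. With this change,
constraints that have the same predicate and the same summary share a
single verification function.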

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D104369

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/Predicate.h
    mlir/test/mlir-tblgen/predicate.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h
index 7caea7c561f8b..e23c751e52d09 100644
--- a/mlir/include/mlir/TableGen/Predicate.h
+++ b/mlir/include/mlir/TableGen/Predicate.h
@@ -14,6 +14,7 @@
 #define MLIR_TABLEGEN_PREDICATE_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/Hashing.h"
 
 #include <string>
 #include <vector>
@@ -59,6 +60,8 @@ class Pred {
   ArrayRef<llvm::SMLoc> getLoc() const;
 
 protected:
+  friend llvm::DenseMapInfo<Pred>;
+
   // The TableGen definition of this predicate.
   const llvm::Record *def;
 };
@@ -116,4 +119,18 @@ class ConcatPred : public CombinedPred {
 } // end namespace tblgen
 } // end namespace mlir
 
+namespace llvm {
+template <>
+struct DenseMapInfo<mlir::tblgen::Pred> {
+  static mlir::tblgen::Pred getEmptyKey() { return mlir::tblgen::Pred(); }
+  static mlir::tblgen::Pred getTombstoneKey() { return mlir::tblgen::Pred(); }
+  static unsigned getHashValue(mlir::tblgen::Pred pred) {
+    return llvm::hash_value(pred.def);
+  }
+  static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) {
+    return lhs == rhs;
+  }
+};
+} // end namespace llvm
+
 #endif // MLIR_TABLEGEN_PREDICATE_H_

diff  --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 386c61319b79e..f8c1b1d90b3a3 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -13,19 +13,24 @@ def I32OrF32 : Type<CPred<"$_self.isInteger(32) || $_self.isF32()">,
 
 def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
   let arguments = (ins I32OrF32:$x);
+  let results = (outs Variadic<I32OrF32>:$y);
 }
 
 // CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK:  if (!((type.isInteger(32) || type.isF32()))) {
-// CHECK:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+// CHECK-NEXT:  if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+
+// Check there is no verifier with same predicate generated.
+// CHECK-NOT:  if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NOT:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
 
 // CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK:  if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
-// CHECK:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
+// CHECK-NEXT:  if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
+// CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
 
 // CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK:  if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
-// CHECK:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
+// CHECK-NEXT:  if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
+// CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
 
 // CHECK-LABEL: OpA::verify
 // CHECK: auto valueGroup0 = getODSOperands(0);

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 664fc2de788ea..2bc9ea465f5e0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -216,19 +216,50 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
         typeConstraints.insert(result.constraint.getAsOpaquePointer());
   }
 
+  // Record the mapping from predicate to constraint. If two constraints have
+  // the same predicate and constraint summary, they can share the same
+  // verification function.
+  llvm::DenseMap<Pred, const void *> predToConstraint;
   FmtContext fctx;
   for (auto it : llvm::enumerate(typeConstraints)) {
+    std::string name;
+    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
+    Pred pred = constraint.getPredicate();
+    auto iter = predToConstraint.find(pred);
+    if (iter != predToConstraint.end()) {
+      do {
+        Constraint built = Constraint::getFromOpaquePointer(iter->second);
+        // Different constraints may have the same predicate,
+        // for example, ConstraintA and Variadic<ConstraintA>, note that
+        // Variadic<> doesn't introduce new predicate. In this case, we can
+        // share the same predicate function if they also have consistent
+        // summary, otherwise we may report the wrong message while verification
+        // fails.
+        if (constraint.getSummary() == built.getSummary()) {
+          name = getTypeConstraintFn(built).str();
+          break;
+        }
+        ++iter;
+      } while (iter != predToConstraint.end() && iter->first == pred);
+    }
+
+    if (!name.empty()) {
+      localTypeConstraints.try_emplace(it.value(), name);
+      continue;
+    }
+
     // Generate an obscure and unique name for this type constraint.
-    std::string name = (Twine("__mlir_ods_local_type_constraint_") +
-                        uniqueOutputLabel + Twine(it.index()))
-                           .str();
+    name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
+            Twine(it.index()))
+               .str();
+    predToConstraint.insert(
+        std::make_pair(constraint.getPredicate(), it.value()));
     localTypeConstraints.try_emplace(it.value(), name);
 
     // Only generate the methods if we are generating definitions.
     if (emitDecl)
       continue;
 
-    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
     os << "static ::mlir::LogicalResult " << name
        << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
           "valueKind, unsigned valueGroupStartIndex) {\n";


        


More information about the Mlir-commits mailing list