[Mlir-commits] [mlir] b3fc0fa - [mlir][PDLL] Don't use the result of `Constraint::getDefName()` when uniquing

River Riddle llvmlistbot at llvm.org
Tue Apr 26 18:34:11 PDT 2022


Author: River Riddle
Date: 2022-04-26T18:33:16-07:00
New Revision: b3fc0fa84a09540d8fc7214899021acbf2fd6ff8

URL: https://github.com/llvm/llvm-project/commit/b3fc0fa84a09540d8fc7214899021acbf2fd6ff8
DIFF: https://github.com/llvm/llvm-project/commit/b3fc0fa84a09540d8fc7214899021acbf2fd6ff8.diff

LOG: [mlir][PDLL] Don't use the result of `Constraint::getDefName()` when uniquing

For anonymous defs, `getDefName()` may return the name of the base def class,
which can lead to two different defs with the same name (hitting an assert).
This commit adds a new `getUniqueDefName` method that returns a unique name
for the constraint.

Differential Revision: https://reviews.llvm.org/D124074

Added: 
    mlir/lib/Tools/PDLL/ODS/Constraint.cpp

Modified: 
    mlir/include/mlir/TableGen/Constraint.h
    mlir/include/mlir/Tools/PDLL/ODS/Constraint.h
    mlir/lib/TableGen/Constraint.cpp
    mlir/lib/Tools/PDLL/ODS/CMakeLists.txt
    mlir/lib/Tools/PDLL/ODS/Context.cpp
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/test/mlir-pdll/Parser/include/ops.td
    mlir/test/mlir-pdll/Parser/include_td.pdll

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index b24b9b7459ee4..0f6c2b58faaff 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -59,6 +59,12 @@ class Constraint {
   /// `Optional<>`/`Variadic<>` type constraints).
   StringRef getDefName() const;
 
+  /// Returns a unique name for the TableGen def of this constraint. This is
+  /// generally just the name of the def, but in some cases where the current
+  /// def is anonymous, the name of the base def is attached (to provide more
+  /// context on the def).
+  std::string getUniqueDefName() const;
+
   Kind getKind() const { return kind; }
 
 protected:
@@ -66,6 +72,9 @@ class Constraint {
   const llvm::Record *def;
 
 private:
+  /// Return the name of the base def if there is one, or None otherwise.
+  Optional<StringRef> getBaseDefName() const;
+
   // What kind of constraint this is.
   Kind kind;
 };

diff  --git a/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h b/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h
index 270330966de41..da73e00e2dd4b 100644
--- a/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h
+++ b/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h
@@ -31,9 +31,15 @@ namespace ods {
 /// This class represents a generic ODS constraint.
 class Constraint {
 public:
-  /// Return the name of this constraint.
+  /// Return the unique name of this constraint.
   StringRef getName() const { return name; }
 
+  /// Return the demangled name of this constraint. This tries to strip out bits
+  /// of the name that are purely for uniquing, and show the underlying name. As
+  /// such, this name does not guarantee uniqueness and should only be used for
+  /// logging or other lossy, human-friendly ("pretty") output.
+  StringRef getDemangledName() const;
+
   /// Return the summary of this constraint.
   StringRef getSummary() const { return summary; }
 

diff  --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 249c22eebbfb6..0c5e034a8ee6b 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -58,25 +58,48 @@ StringRef Constraint::getSummary() const {
 }
 
 StringRef Constraint::getDefName() const {
+  if (Optional<StringRef> baseDefName = getBaseDefName())
+    return *baseDefName;
+  return def->getName();
+}
+
+std::string Constraint::getUniqueDefName() const {
+  std::string defName = def->getName().str();
+
+  // Non-anonymous defs already have a unique name from the def.
+  if (!def->isAnonymous())
+    return defName;
+
+  // Otherwise, this is an anonymous def. In these cases we still use the def
+  // name, but we also try to attach the name of the base def when present to
+  // make the name more obvious.
+  if (Optional<StringRef> baseDefName = getBaseDefName())
+    return (*baseDefName + "(" + defName + ")").str();
+  return defName;
+}
+
+Optional<StringRef> Constraint::getBaseDefName() const {
   // Functor used to check a base def in the case where the current def is
   // anonymous.
-  auto checkBaseDefFn = [&](StringRef baseName) {
-    if (const auto *init = dyn_cast<llvm::DefInit>(def->getValueInit(baseName)))
-      return Constraint(init->getDef(), kind).getDefName();
-    return def->getName();
+  auto checkBaseDefFn = [&](StringRef baseName) -> Optional<StringRef> {
+    if (const auto *defValue = def->getValue(baseName)) {
+      if (const auto *defInit = dyn_cast<llvm::DefInit>(defValue->getValue()))
+        return Constraint(defInit->getDef(), kind).getDefName();
+    }
+    return llvm::None;
   };
 
   switch (kind) {
   case CK_Attr:
     if (def->isAnonymous())
       return checkBaseDefFn("baseAttr");
-    return def->getName();
+    return llvm::None;
   case CK_Type:
     if (def->isAnonymous())
       return checkBaseDefFn("baseType");
-    return def->getName();
+    return llvm::None;
   default:
-    return def->getName();
+    return llvm::None;
   }
 }
 

diff  --git a/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt b/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt
index 3abbaab33ab3f..6a9abc903418a 100644
--- a/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt
+++ b/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRPDLLODS
+  Constraint.cpp
   Context.cpp
   Dialect.cpp
   Operation.cpp

diff  --git a/mlir/lib/Tools/PDLL/ODS/Constraint.cpp b/mlir/lib/Tools/PDLL/ODS/Constraint.cpp
new file mode 100644
index 0000000000000..c9471f984e414
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/ODS/Constraint.cpp
@@ -0,0 +1,26 @@
+//===- Constraint.cpp -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/ODS/Constraint.h"
+
+using namespace mlir;
+using namespace mlir::pdll::ods;
+
+//===----------------------------------------------------------------------===//
+// Constraint
+//===----------------------------------------------------------------------===//
+
+StringRef Constraint::getDemangledName() const {
+  StringRef demangledName = name;
+
+  // Drop the "(anonymous_...)" uniquing suffix if present.
+  size_t anonymousSuffix = demangledName.find("(anonymous_");
+  if (anonymousSuffix != StringRef::npos)
+    demangledName = demangledName.take_front(anonymousSuffix);
+  return demangledName;
+}

diff  --git a/mlir/lib/Tools/PDLL/ODS/Context.cpp b/mlir/lib/Tools/PDLL/ODS/Context.cpp
index 7684da5e05fae..8186af9fe1b26 100644
--- a/mlir/lib/Tools/PDLL/ODS/Context.cpp
+++ b/mlir/lib/Tools/PDLL/ODS/Context.cpp
@@ -120,7 +120,7 @@ void Context::print(raw_ostream &os) const {
 
           auto kind = attr.isOptional() ? VariableLengthKind::Optional
                                         : VariableLengthKind::Single;
-          printVariableLengthCst(attr.getConstraint().getName(), kind);
+          printVariableLengthCst(attr.getConstraint().getDemangledName(), kind);
         });
         os << " }\n";
       }
@@ -132,7 +132,7 @@ void Context::print(raw_ostream &os) const {
         llvm::interleaveComma(
             operands, os, [&](const OperandOrResult &operand) {
               os << operand.getName() << " : ";
-              printVariableLengthCst(operand.getConstraint().getName(),
+              printVariableLengthCst(operand.getConstraint().getDemangledName(),
                                      operand.getVariableLengthKind());
             });
         os << " }\n";
@@ -144,7 +144,7 @@ void Context::print(raw_ostream &os) const {
         printer.startLine() << "Results { ";
         llvm::interleaveComma(results, os, [&](const OperandOrResult &result) {
           os << result.getName() << " : ";
-          printVariableLengthCst(result.getConstraint().getName(),
+          printVariableLengthCst(result.getConstraint().getDemangledName(),
                                  result.getVariableLengthKind());
         });
         os << " }\n";
@@ -155,7 +155,8 @@ void Context::print(raw_ostream &os) const {
     printer.objectEnd();
   }
   for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) {
-    printer.startLine() << "AttributeConstraint `" << cst->getName() << "` {\n";
+    printer.startLine() << "AttributeConstraint `" << cst->getDemangledName()
+                        << "` {\n";
     printer.indent();
 
     printer.startLine() << "Summary: " << cst->getSummary() << "\n";
@@ -163,7 +164,8 @@ void Context::print(raw_ostream &os) const {
     printer.objectEnd();
   }
   for (const TypeConstraint *cst : sortMapByName(typeConstraints)) {
-    printer.startLine() << "TypeConstraint `" << cst->getName() << "` {\n";
+    printer.startLine() << "TypeConstraint `" << cst->getDemangledName()
+                        << "` {\n";
     printer.indent();
 
     printer.startLine() << "Summary: " << cst->getSummary() << "\n";

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 1bb23da08b4ad..108634bad82b4 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -774,7 +774,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
   ods::Context &odsContext = ctx.getODSContext();
   auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst)
       -> const ods::TypeConstraint & {
-    return odsContext.insertTypeConstraint(cst.constraint.getDefName(),
+    return odsContext.insertTypeConstraint(cst.constraint.getUniqueDefName(),
                                            cst.constraint.getSummary(),
                                            cst.constraint.getCPPClassName());
   };
@@ -800,7 +800,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
     for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
       odsOp->appendAttribute(
           attr.name, attr.attr.isOptional(),
-          odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(),
+          odsContext.insertAttributeConstraint(attr.attr.getUniqueDefName(),
                                                attr.attr.getSummary(),
                                                attr.attr.getStorageType()));
     }
@@ -891,8 +891,8 @@ Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
   std::string codeBlock =
       tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext);
 
-  return createODSNativePDLLConstraintDecl<ConstraintT>(constraint.getDefName(),
-                                                        codeBlock, loc, type);
+  return createODSNativePDLLConstraintDecl<ConstraintT>(
+      constraint.getUniqueDefName(), codeBlock, loc, type);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-pdll/Parser/include/ops.td b/mlir/test/mlir-pdll/Parser/include/ops.td
index 1727d1a5444bc..ecbd44762c84c 100644
--- a/mlir/test/mlir-pdll/Parser/include/ops.td
+++ b/mlir/test/mlir-pdll/Parser/include/ops.td
@@ -7,7 +7,11 @@ def Test_Dialect : Dialect {
 def OpAllEmpty : Op<Test_Dialect, "all_empty">;
 
 def OpAllSingle : Op<Test_Dialect, "all_single"> {
-  let arguments = (ins I64:$operand, I64Attr:$attr);
+  let arguments = (ins
+    I64:$operand,
+    I64Attr:$attr,
+    Confined<I64Attr, [IntNonNegative]>:$nonNegativeAttr
+  );
   let results = (outs I64:$result);
 }
 

diff  --git a/mlir/test/mlir-pdll/Parser/include_td.pdll b/mlir/test/mlir-pdll/Parser/include_td.pdll
index c55ed1d0f154b..a7a02e9323e36 100644
--- a/mlir/test/mlir-pdll/Parser/include_td.pdll
+++ b/mlir/test/mlir-pdll/Parser/include_td.pdll
@@ -12,7 +12,7 @@
 // CHECK-NEXT: }
 
 // CHECK:      Operation `test.all_single` {
-// CHECK-NEXT:   Attributes { attr : I64Attr }
+// CHECK-NEXT:   Attributes { attr : I64Attr, nonNegativeAttr : I64Attr }
 // CHECK-NEXT:   Operands { operand : I64 }
 // CHECK-NEXT:   Results { result : I64 }
 // CHECK-NEXT: }


        


More information about the Mlir-commits mailing list