[Mlir-commits] [mlir] 44610c0 - [MLIR][ODS] default-valued strings should be in quotes

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 14 20:00:46 PDT 2021


Author: Mogball
Date: 2021-10-15T03:00:41Z
New Revision: 44610c01aeed8402da982ab59c47f45e7b3bc93b

URL: https://github.com/llvm/llvm-project/commit/44610c01aeed8402da982ab59c47f45e7b3bc93b
DIFF: https://github.com/llvm/llvm-project/commit/44610c01aeed8402da982ab59c47f45e7b3bc93b.diff

LOG: [MLIR][ODS] default-valued strings should be in quotes

`DefaultValuedAttr<StrAttr, "">` and `ConstantAttr<StrAttr, "">`
cause bugs: TableGen does not recognize that the attribute has a
default value, because `""` is an empty TableGen string.

Strings no longer receive special treatment. Instead, string values must be
wrapped in escaped quotes, e.g. "\"foo\"". Two helpers, `DefaultValuedStrAttr`
and `ConstantStrAttr`, have been added to keep code clean: they wrap the given
value in escaped quotes so call sites do not have to.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D111855

Added: 
    mlir/test/mlir-tblgen/constant-str-attr-invalid.mlir

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-attribute.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index d92c0a80a54fe..f555329d374d9 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -928,6 +928,11 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.summary> {
   let baseAttr = attr;
 }
 
+// Default-valued string-based attribute. Wraps the default value in escaped
+// quotes.
+class DefaultValuedStrAttr<Attr attr, string val>
+    : DefaultValuedAttr<attr, "\"" # val # "\"">;
+
 //===----------------------------------------------------------------------===//
 // Primitive attribute kinds
 
@@ -1095,7 +1100,7 @@ def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;
 
 // An attribute backed by a string type.
 class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
-  let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
+  let constBuilderCall = "$_builder.getStringAttr($0)";
   let storageType = [{ ::mlir::StringAttr }];
   let returnType = [{ ::llvm::StringRef }];
   let valueType = NoneType;
@@ -1672,6 +1677,10 @@ def ConstBoolAttrFalse : ConstantAttr<BoolAttr, "false">;
 def ConstBoolAttrTrue : ConstantAttr<BoolAttr, "true">;
 def ConstUnitAttr : ConstantAttr<UnitAttr, "unit">;
 
+// Constant string-based attribute. Wraps the desired string in escaped quotes.
+class ConstantStrAttr<Attr attribute, string val>
+    : ConstantAttr<attribute, "\"" # val # "\"">;
+
 //===----------------------------------------------------------------------===//
 // Common attribute constraints
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7968742b61f7d..0433c3075ca60 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2343,4 +2343,21 @@ def TestLinalgConvOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Ops with Default-Valued String Attributes
+//===----------------------------------------------------------------------===//
+
+def TestDefaultStrAttrNoValueOp : TEST_Op<"no_str_value"> {
+  let arguments = (ins DefaultValuedAttr<StrAttr, "">:$value);
+  let assemblyFormat = "attr-dict";
+}
+
+def TestDefaultStrAttrHasValueOp : TEST_Op<"has_str_value"> {
+  let arguments = (ins DefaultValuedStrAttr<StrAttr, "">:$value);
+  let assemblyFormat = "attr-dict";
+}
+
+def : Pat<(TestDefaultStrAttrNoValueOp $value),
+          (TestDefaultStrAttrHasValueOp ConstantStrAttr<StrAttr, "foo">)>;
+
 #endif // TEST_OPS

diff  --git a/mlir/test/mlir-tblgen/constant-str-attr-invalid.mlir b/mlir/test/mlir-tblgen/constant-str-attr-invalid.mlir
new file mode 100644
index 0000000000000..1d0e62d5dec09
--- /dev/null
+++ b/mlir/test/mlir-tblgen/constant-str-attr-invalid.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt -verify-diagnostics %s
+
+// Test DefaultValuedAttr<StrAttr, ""> is recognized as "no default value"
+test.no_str_value {} // expected-error {{'test.no_str_value' op requires attribute 'value'}}

diff  --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 4ed715d5b8e6c..cf5150c067bb5 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -337,7 +337,7 @@ def DOp : NS_Op<"d_op", []> {
     SomeI32Enum:$enum_attr,
     DefaultValuedAttr<I32Attr, "42">:$dv_i32_attr,
     DefaultValuedAttr<F64Attr, "8.">:$dv_f64_attr,
-    DefaultValuedAttr<StrAttr, "abc">:$dv_str_attr,
+    DefaultValuedStrAttr<StrAttr, "abc">:$dv_str_attr,
     DefaultValuedAttr<BoolAttr, "true">:$dv_bool_attr,
     DefaultValuedAttr<SomeI32Enum, "::SomeI32Enum::case5">:$dv_enum_attr
   );
@@ -377,7 +377,7 @@ def EOp : NS_Op<"e_op", []> {
     F64Attr:$f64_attr,
     DefaultValuedAttr<F64Attr, "8.">:$dv_f64_attr,
     StrAttr:$str_attr,
-    DefaultValuedAttr<StrAttr, "abc">:$dv_str_attr,
+    DefaultValuedStrAttr<StrAttr, "abc">:$dv_str_attr,
     BoolAttr:$bool_attr,
     DefaultValuedAttr<BoolAttr, "true">:$dv_bool_attr,
     SomeI32Enum:$enum_attr,

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 4be71a9eee1f7..5108af345e86d 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -361,3 +361,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
 
 // CHECK: test.format_infer_type
 %ignored_res7 = test.format_infer_type
+
+//===----------------------------------------------------------------------===//
+// Check DefaultValuedStrAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.has_str_value
+test.has_str_value {}

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 4a05df4782c76..ab436fa33bd9b 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -583,3 +583,13 @@ func @returnTypeAndLocation(%arg0 : i32) -> i1 {
   // CHECK: "test.two_to_one"(%0, %1) : (i32, i32) -> i1
   return %0 : i1
 }
+
+//===----------------------------------------------------------------------===//
+// Test that patterns can create ConstantStrAttr
+//===----------------------------------------------------------------------===//
+
+func @testConstantStrAttr() -> () {
+  // CHECK: test.has_str_value {value = "foo"}
+  test.no_str_value {value = "bar"}
+  return
+}

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 8c78a3bac7714..8ee897cb921df 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1710,12 +1710,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
     std::string defaultValue;
     if (attrParamKind == AttrParamKind::UnwrappedValue &&
         i >= defaultValuedAttrStartIndex) {
-      bool isString = attr.getReturnType() == "::llvm::StringRef";
-      if (isString)
-        defaultValue.append("\"");
       defaultValue += attr.getDefaultValue();
-      if (isString)
-        defaultValue.append("\"");
     }
     paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
   }

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 823d551e0705a..9ce6eea4259ca 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -50,6 +50,13 @@ struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
 };
 } // end namespace llvm
 
+static std::string escapeString(StringRef value) {
+  std::string ret;
+  llvm::raw_string_ostream os(ret);
+  llvm::printEscapedString(value, os);
+  return os.str();
+}
+
 //===----------------------------------------------------------------------===//
 // PatternEmitter
 //===----------------------------------------------------------------------===//
@@ -189,7 +196,7 @@ class PatternEmitter {
 
   // Returns the C++ expression to construct a constant attribute of the given
   // `value` for the given attribute kind `attr`.
-  std::string handleConstantAttr(Attribute attr, StringRef value);
+  std::string handleConstantAttr(Attribute attr, const Twine &value);
 
   // Returns the C++ expression to build an argument from the given DAG `leaf`.
   // `patArgName` is used to bound the argument to the source pattern.
@@ -313,7 +320,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
 }
 
 std::string PatternEmitter::handleConstantAttr(Attribute attr,
-                                               StringRef value) {
+                                               const Twine &value) {
   if (!attr.isConstBuildable())
     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
                              " does not have the 'constBuilderCall' field");
@@ -492,7 +499,8 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
                 "constraint: "
                 "'{2}'\"",
-                i, tree.getNativeCodeTemplate(), constraint.getSummary()));
+                i, tree.getNativeCodeTemplate(),
+                escapeString(constraint.getSummary())));
   }
 
   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
@@ -630,7 +638,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
                   "'{2}'\"",
                   operand - op.operand_begin(), op.getOperationName(),
-                  constraint.getSummary()));
+                  escapeString(constraint.getSummary())));
     }
   }
 
@@ -694,9 +702,9 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
         opName,
         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
-                "{2}\"",
+                "'{2}'\"",
                 op.getOperationName(), namedAttr->name,
-                matcher.getAsConstraint().getSummary()));
+                escapeString(matcher.getAsConstraint().getSummary())));
   }
 
   // Capture the value
@@ -740,8 +748,8 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
                           symbolInfoMap.getValueAndRangeUse(entities.front()));
       emitMatchCheck(
           opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
-          formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
-                  entities.front(), constraint.getSummary()));
+          formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
+                  entities.front(), escapeString(constraint.getSummary())));
 
     } else if (isa<AttrConstraint>(constraint)) {
       PrintFatalError(
@@ -765,9 +773,9 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
                            names[1], names[2], names[3]),
                      formatv("\"entities '{0}' failed to satisfy constraint: "
-                             "{1}\"",
+                             "'{1}'\"",
                              llvm::join(entities, ", "),
-                             constraint.getSummary()));
+                             escapeString(constraint.getSummary())));
     }
   }
 
@@ -1103,7 +1111,7 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
   if (leaf.isEnumAttrCase()) {
     auto enumCase = leaf.getAsEnumAttrCase();
     if (enumCase.isStrCase())
-      return handleConstantAttr(enumCase, enumCase.getSymbol());
+      return handleConstantAttr(enumCase, "\"" + enumCase.getSymbol() + "\"");
     // This is an enum case backed by an IntegerAttr. We need to get its value
     // to build the constant.
     std::string val = std::to_string(enumCase.getValue());


        


More information about the Mlir-commits mailing list