[Mlir-commits] [mlir] 1f8a33c - [mlir][bytecodegen] Add list helper methods.

Jacques Pienaar llvmlistbot at llvm.org
Sun Jun 4 15:52:44 PDT 2023


Author: Jacques Pienaar
Date: 2023-06-04T15:52:37-07:00
New Revision: 1f8a33c19c79fd4649a07eb70ea394c60a8ce316

URL: https://github.com/llvm/llvm-project/commit/1f8a33c19c79fd4649a07eb70ea394c60a8ce316
DIFF: https://github.com/llvm/llvm-project/commit/1f8a33c19c79fd4649a07eb70ea394c60a8ce316.diff

LOG: [mlir][bytecodegen] Add list helper methods.

Previously, SignedVarInt was defined incorrectly. Follow-up work is
needed to improve Array printing/parsing; for now, this change only
corrects the definitions.

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinDialectBytecode.td
    mlir/include/mlir/IR/BytecodeBase.td
    mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 8bae2537f41c4..47d6c0df55485 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -134,7 +134,7 @@ def DenseResourceElementsAttr : DialectAttribute<(attr
 
 let cType = "RankedTensorType" in {
 def RankedTensorType : DialectType<(type
-  Array<SignedVarInt>:$shape,
+  Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.getEncoding()";
@@ -142,7 +142,7 @@ def RankedTensorType : DialectType<(type
 
 def RankedTensorTypeWithEncoding : DialectType<(type
   Attribute:$encoding,
-  Array<SignedVarInt>:$shape,
+  Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
   let printerPredicate = "$_val.getEncoding()";
@@ -223,7 +223,7 @@ def MemRefLayout: WithType<"MemRefLayoutAttrInterface", Attribute>;
 
 let cType = "MemRefType" in {
 def MemRefType : DialectType<(type
-  Array<SignedVarInt>:$shape,
+  Array<SignedVarIntList>:$shape,
   Type:$elementType,
   MemRefLayout:$layout
 )> {
@@ -232,7 +232,7 @@ def MemRefType : DialectType<(type
 
 def MemRefTypeWithMemSpace : DialectType<(type
   Attribute:$memorySpace,
-  Array<SignedVarInt>:$shape,
+  Array<SignedVarIntList>:$shape,
   Type:$elementType,
   MemRefLayout:$layout
 )> {
@@ -272,7 +272,7 @@ def UnrankedTensorType : DialectType<(type
 
 let cType = "VectorType" in {
 def VectorType : DialectType<(type
-  Array<SignedVarInt>:$shape,
+  Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.getNumScalableDims()";
@@ -280,7 +280,7 @@ def VectorType : DialectType<(type
 
 def VectorTypeWithScalableDims : DialectType<(type
   VarInt:$numScalableDims,
-  Array<SignedVarInt>:$shape,
+  Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
   let printerPredicate = "$_val.getNumScalableDims()";

diff  --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
index 3164bcad2d6c7..c7ec563b9f14d 100644
--- a/mlir/include/mlir/IR/BytecodeBase.td
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -73,6 +73,7 @@ class TypeKind :
   WithBuilder<"$_args",
   WithPrinter<"$_writer.writeType($_getter)">>>;
 def Type : TypeKind;
+
 def VarInt :
   WithParser <"succeeded($_reader.readVarInt($_var))",
   WithBuilder<"$_args",
@@ -82,14 +83,12 @@ def SignedVarInt :
   WithParser <"succeeded($_reader.readSignedVarInt($_var))",
   WithBuilder<"$_args",
   WithPrinter<"$_writer.writeSignedVarInt($_getter)",
-  WithGetter<"$_attrType",
-  WithType   <"int64_t">>>>>;
+  WithType   <"int64_t">>>>;
 def Blob :
   WithParser <"succeeded($_reader.readBlob($_var))",
   WithBuilder<"$_args",
   WithPrinter<"$_writer.writeOwnedBlob($_getter)",
   WithType   <"ArrayRef<char>">>>>;
-
 class KnownWidthAPInt<string s> :
   WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))",
   WithBuilder<"$_args",
@@ -119,6 +118,10 @@ class Array<Bytecode t> {
 
   string cBuilder = "$_args";
 }
+// - Array elements currently need a different bytecode type to accommodate
+//   list printing/parsing.
+class List<Bytecode t> : WithGetter<"$_member", t>;
+def SignedVarIntList : List<SignedVarInt>;
 
 // Define dialect attribute or type.
 class DialectAttrOrType<dag d> {

diff  --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index f4e3e4f1c1605..d78bd5f44f214 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -134,7 +134,7 @@ void Generator::emitParse(StringRef kind, Record &x) {
       llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
         return init->getAsUnquotedString();
       }));
-  StringRef builder = x.getValueAsString("cBuilder");
+  StringRef builder = x.getValueAsString("cBuilder").trim();
   emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
                   returnType + "()", os);
   os << "\n\n";
@@ -368,10 +368,11 @@ void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
       }
     }
     std::string returnType = getCType(def);
+    std::string nestedName = kind.str();
     ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
-        << kind << ") ";
+        << nestedName << ") ";
     auto lambdaScope = ios.scope("{\n", "});\n");
-    return emitPrintHelper(def, kind, kind, kind, ios);
+    return emitPrintHelper(def, kind, nestedName, nestedName, ios);
   }
   if (memberRec->isSubClassOf("CompositeBytecode")) {
     auto *members = memberRec->getValueAsDag("members");


        


More information about the Mlir-commits mailing list