[Mlir-commits] [mlir] 1f8a33c - [mlir][bytecodegen] Add list helper methods.
Jacques Pienaar
llvmlistbot at llvm.org
Sun Jun 4 15:52:44 PDT 2023
Author: Jacques Pienaar
Date: 2023-06-04T15:52:37-07:00
New Revision: 1f8a33c19c79fd4649a07eb70ea394c60a8ce316
URL: https://github.com/llvm/llvm-project/commit/1f8a33c19c79fd4649a07eb70ea394c60a8ce316
DIFF: https://github.com/llvm/llvm-project/commit/1f8a33c19c79fd4649a07eb70ea394c60a8ce316.diff
LOG: [mlir][bytecodegen] Add list helper methods.
Previously, SignedVarInt was incorrectly defined. Follow-up work is
needed to improve Array printing/parsing; for now, this commit corrects
the definitions.
Added:
Modified:
mlir/include/mlir/IR/BuiltinDialectBytecode.td
mlir/include/mlir/IR/BytecodeBase.td
mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 8bae2537f41c4..47d6c0df55485 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -134,7 +134,7 @@ def DenseResourceElementsAttr : DialectAttribute<(attr
let cType = "RankedTensorType" in {
def RankedTensorType : DialectType<(type
- Array<SignedVarInt>:$shape,
+ Array<SignedVarIntList>:$shape,
Type:$elementType
)> {
let printerPredicate = "!$_val.getEncoding()";
@@ -142,7 +142,7 @@ def RankedTensorType : DialectType<(type
def RankedTensorTypeWithEncoding : DialectType<(type
Attribute:$encoding,
- Array<SignedVarInt>:$shape,
+ Array<SignedVarIntList>:$shape,
Type:$elementType
)> {
let printerPredicate = "$_val.getEncoding()";
@@ -223,7 +223,7 @@ def MemRefLayout: WithType<"MemRefLayoutAttrInterface", Attribute>;
let cType = "MemRefType" in {
def MemRefType : DialectType<(type
- Array<SignedVarInt>:$shape,
+ Array<SignedVarIntList>:$shape,
Type:$elementType,
MemRefLayout:$layout
)> {
@@ -232,7 +232,7 @@ def MemRefType : DialectType<(type
def MemRefTypeWithMemSpace : DialectType<(type
Attribute:$memorySpace,
- Array<SignedVarInt>:$shape,
+ Array<SignedVarIntList>:$shape,
Type:$elementType,
MemRefLayout:$layout
)> {
@@ -272,7 +272,7 @@ def UnrankedTensorType : DialectType<(type
let cType = "VectorType" in {
def VectorType : DialectType<(type
- Array<SignedVarInt>:$shape,
+ Array<SignedVarIntList>:$shape,
Type:$elementType
)> {
let printerPredicate = "!$_val.getNumScalableDims()";
@@ -280,7 +280,7 @@ def VectorType : DialectType<(type
def VectorTypeWithScalableDims : DialectType<(type
VarInt:$numScalableDims,
- Array<SignedVarInt>:$shape,
+ Array<SignedVarIntList>:$shape,
Type:$elementType
)> {
let printerPredicate = "$_val.getNumScalableDims()";
diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
index 3164bcad2d6c7..c7ec563b9f14d 100644
--- a/mlir/include/mlir/IR/BytecodeBase.td
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -73,6 +73,7 @@ class TypeKind :
WithBuilder<"$_args",
WithPrinter<"$_writer.writeType($_getter)">>>;
def Type : TypeKind;
+
def VarInt :
WithParser <"succeeded($_reader.readVarInt($_var))",
WithBuilder<"$_args",
@@ -82,14 +83,12 @@ def SignedVarInt :
WithParser <"succeeded($_reader.readSignedVarInt($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeSignedVarInt($_getter)",
- WithGetter<"$_attrType",
- WithType <"int64_t">>>>>;
+ WithType <"int64_t">>>>;
def Blob :
WithParser <"succeeded($_reader.readBlob($_var))",
WithBuilder<"$_args",
WithPrinter<"$_writer.writeOwnedBlob($_getter)",
WithType <"ArrayRef<char>">>>>;
-
class KnownWidthAPInt<string s> :
WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))",
WithBuilder<"$_args",
@@ -119,6 +118,10 @@ class Array<Bytecode t> {
string cBuilder = "$_args";
}
+// - Array elements currently needs a different bytecode type to accommodate
+// for the list print/parsing.
+class List<Bytecode t> : WithGetter<"$_member", t>;
+def SignedVarIntList : List<SignedVarInt>;
// Define dialect attribute or type.
class DialectAttrOrType<dag d> {
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index f4e3e4f1c1605..d78bd5f44f214 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -134,7 +134,7 @@ void Generator::emitParse(StringRef kind, Record &x) {
llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
return init->getAsUnquotedString();
}));
- StringRef builder = x.getValueAsString("cBuilder");
+ StringRef builder = x.getValueAsString("cBuilder").trim();
emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
returnType + "()", os);
os << "\n\n";
@@ -368,10 +368,11 @@ void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
}
}
std::string returnType = getCType(def);
+ std::string nestedName = kind.str();
ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
- << kind << ") ";
+ << nestedName << ") ";
auto lambdaScope = ios.scope("{\n", "});\n");
- return emitPrintHelper(def, kind, kind, kind, ios);
+ return emitPrintHelper(def, kind, nestedName, nestedName, ios);
}
if (memberRec->isSubClassOf("CompositeBytecode")) {
auto *members = memberRec->getValueAsDag("members");
More information about the Mlir-commits
mailing list