[Mlir-commits] [mlir] 59f59d1 - [mlir] Allow to override type/attr aliases from various hooks
Vladislav Vinogradov
llvmlistbot at llvm.org
Fri Aug 6 02:04:51 PDT 2021
Author: Vladislav Vinogradov
Date: 2021-08-06T12:05:31+03:00
New Revision: 59f59d1c621cf6844c41fd92ad32a897fc9d10bd
URL: https://github.com/llvm/llvm-project/commit/59f59d1c621cf6844c41fd92ad32a897fc9d10bd
DIFF: https://github.com/llvm/llvm-project/commit/59f59d1c621cf6844c41fd92ad32a897fc9d10bd.diff
LOG: [mlir] Allow to override type/attr aliases from various hooks
Use new return type for `OpAsmDialectInterface::getAlias`:
* `AliasResult::NoAlias` if an alias was not provided.
* `AliasResult::OverridableAlias` if an alias was provided, but it might be overridden by another hook.
* `AliasResult::FinalAlias` if an alias was provided and it should be used (no other hooks will be checked).
With this scheme, `AsmPrinter` uses the first alias returned with a `FinalAlias` result or, if no hook
returns `FinalAlias`, the last alias returned with an `OverridableAlias` result (which hook is last depends on the dialect array order).
Used `OverridableAlias` result for `BuiltinOpAsmDialectInterface`.
Use case: provide a more informative alias for built-in attributes like `AffineMapAttr`
instead of the generic "map<N>".
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D107437
Added:
Modified:
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/test/IR/print-attr-type-aliases.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a9c96473a3a1e..39108dd91c5cd 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -928,18 +928,29 @@ using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
+ /// Holds the result of `getAlias` hook call.
+ enum class AliasResult {
+ /// The object (type or attribute) is not supported by the hook
+ /// and an alias was not provided.
+ NoAlias,
+ /// An alias was provided, but it might be overriden by other hook.
+ OverridableAlias,
+ /// An alias was provided and it should be used
+ /// (no other hooks will be checked).
+ FinalAlias
+ };
+
OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
/// Hooks for getting an alias identifier alias for a given symbol, that is
/// not necessarily a part of this dialect. The identifier is used in place of
/// the symbol when printing textual IR. These aliases must not contain `.` or
- /// end with a numeric digit([0-9]+). Returns success if an alias was
- /// provided, failure otherwise.
- virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const {
- return failure();
+ /// end with a numeric digit([0-9]+).
+ virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
+ return AliasResult::NoAlias;
}
- virtual LogicalResult getAlias(Type type, raw_ostream &os) const {
- return failure();
+ virtual AliasResult getAlias(Type type, raw_ostream &os) const {
+ return AliasResult::NoAlias;
}
/// Get a special name to use when printing the given operation. See
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index d73ca948538ec..b689225925dbe 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -652,21 +652,28 @@ void AliasInitializer::visit(Type type) {
template <typename T>
LogicalResult AliasInitializer::generateAlias(
T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
- SmallString<16> tempBuffer;
+ SmallString<32> nameBuffer;
for (const auto &interface : interfaces) {
- if (failed(interface.getAlias(symbol, aliasOS)))
+ OpAsmDialectInterface::AliasResult result =
+ interface.getAlias(symbol, aliasOS);
+ if (result == OpAsmDialectInterface::AliasResult::NoAlias)
continue;
- StringRef name = aliasOS.str();
- assert(!name.empty() && "expected valid alias name");
- name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
- /*allowTrailingDigit=*/false);
- name = name.copy(aliasAllocator);
-
- aliasToSymbol[name].push_back(symbol);
- aliasBuffer.clear();
- return success();
+ nameBuffer = std::move(aliasBuffer);
+ assert(!nameBuffer.empty() && "expected valid alias name");
+ if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
+ break;
}
- return failure();
+
+ if (nameBuffer.empty())
+ return failure();
+
+ SmallString<16> tempBuffer;
+ StringRef name =
+ sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
+ /*allowTrailingDigit=*/false);
+ name = name.copy(aliasAllocator);
+ aliasToSymbol[name].push_back(symbol);
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 7e79c1b7c7c49..20514957b5fa6 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -33,30 +33,30 @@ namespace {
struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
- LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
+ AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (attr.isa<AffineMapAttr>()) {
os << "map";
- return success();
+ return AliasResult::OverridableAlias;
}
if (attr.isa<IntegerSetAttr>()) {
os << "set";
- return success();
+ return AliasResult::OverridableAlias;
}
if (attr.isa<LocationAttr>()) {
os << "loc";
- return success();
+ return AliasResult::OverridableAlias;
}
- return failure();
+ return AliasResult::NoAlias;
}
- LogicalResult getAlias(Type type, raw_ostream &os) const final {
+ AliasResult getAlias(Type type, raw_ostream &os) const final {
if (auto tupleType = type.dyn_cast<TupleType>()) {
if (tupleType.size() > 16) {
os << "tuple";
- return success();
+ return AliasResult::OverridableAlias;
}
}
- return failure();
+ return AliasResult::NoAlias;
}
};
} // end anonymous namespace.
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 8b4016d49ad56..5bb408d776800 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -18,6 +18,9 @@
// CHECK-DAG: !tuple = type tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
+// CHECK-DAG: !test_tuple = type tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
+"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
+
// CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 76a7f41fd186d..e56c2d1a92d0f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -51,10 +51,10 @@ static_assert(OpTrait::hasSingleBlockImplicitTerminator<
struct TestOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
- LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
+ AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
StringAttr strAttr = attr.dyn_cast<StringAttr>();
if (!strAttr)
- return failure();
+ return AliasResult::NoAlias;
// Check the contents of the string attribute to see what the test alias
// should be named.
@@ -70,10 +70,23 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Default(llvm::None);
if (!aliasName)
- return failure();
+ return AliasResult::NoAlias;
os << *aliasName;
- return success();
+ return AliasResult::FinalAlias;
+ }
+
+ AliasResult getAlias(Type type, raw_ostream &os) const final {
+ if (auto tupleType = type.dyn_cast<TupleType>()) {
+ if (tupleType.size() > 0 &&
+ llvm::all_of(tupleType.getTypes(), [](Type elemType) {
+ return elemType.isa<SimpleAType>();
+ })) {
+ os << "test_tuple";
+ return AliasResult::FinalAlias;
+ }
+ }
+ return AliasResult::NoAlias;
}
void getAsmResultNames(Operation *op,
More information about the Mlir-commits
mailing list