[Mlir-commits] [mlir] [mlir] Add the ability to define dialect-specific location attrs. (PR #105584)
Aman LaChapelle
llvmlistbot at llvm.org
Tue Sep 17 21:27:22 PDT 2024
https://github.com/bzcheeseman updated https://github.com/llvm/llvm-project/pull/105584
>From f5ceac4c61c7d72d3ecd04d1e656250c29cc9c35 Mon Sep 17 00:00:00 2001
From: Aman LaChapelle <alachapelle at apple.com>
Date: Wed, 21 Aug 2024 13:49:44 -0700
Subject: [PATCH 1/2] [mlir] Add the ability to define dialect-specific
location attrs.
This patch adds the capability to define dialect-specific location attrs. This is useful in particular for defining location structure that doesn't necessarily fit within the core MLIR location hierarchy, but doesn't make sense to push upstream (e.g. a custom use case).
This patch adds an AttributeTrait, `IsLocation`, which is tagged onto all the builtin location attrs, as well as the test location attribute. This is necessary because previously LocationAttr::classof only returned true if the attribute was one of the builtin location attributes, whereas the point of this patch is to allow dialects to define their own location attributes.
I considered an alternative implementation in which LocationAttr becomes an AttrInterface, but discarded it because there are likely to be *many* locations in a single program, and forcing every MLIR user to pay the cost of the additional lookup/dispatch seemed unacceptable. It would also have been a *much* more invasive change. It would have allowed for more flexibility in terms of pretty printing, but it's unclear how useful/necessary that flexibility would be given how much customizability there already is for attribute definitions.
---
mlir/include/mlir/IR/Attributes.h | 13 +++++++---
.../mlir/IR/BuiltinLocationAttributes.td | 3 ++-
mlir/lib/AsmParser/LocationParser.cpp | 26 +++++++++++++++++++
mlir/lib/AsmParser/Parser.h | 3 +++
mlir/lib/IR/AsmPrinter.cpp | 6 +++++
mlir/lib/IR/Location.cpp | 3 +--
mlir/test/IR/locations.mlir | 7 +++++
mlir/test/IR/pretty-locations.mlir | 3 +++
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 11 ++++++++
9 files changed, 69 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 8a077865b51b5f..d347013295d5fc 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -322,12 +322,19 @@ class AttributeInterface
// Core AttributeTrait
//===----------------------------------------------------------------------===//
-/// This trait is used to determine if an attribute is mutable or not. It is
-/// attached on an attribute if the corresponding ImplType defines a `mutate`
-/// function with proper signature.
namespace AttributeTrait {
+/// This trait is used to determine if an attribute is mutable or not. It is
+/// attached to an attribute if the corresponding storage class (ImplType)
+/// defines a `mutate` function with proper signature.
template <typename ConcreteType>
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+
+/// This trait is used to determine if an attribute is a location or not. It is
+/// attached to an attribute by the user if they intend the attribute to be used
+/// as a location.
+template <typename ConcreteType>
+struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> {
+};
} // namespace AttributeTrait
} // namespace mlir.
diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index 5a72404dea15bb..3137a3089a0fc5 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -18,7 +18,8 @@ include "mlir/IR/BuiltinDialect.td"
// Base class for Builtin dialect location attributes.
class Builtin_LocationAttr<string name, list<Trait> traits = []>
- : AttrDef<Builtin_Dialect, name, traits, "::mlir::LocationAttr"> {
+ : AttrDef<Builtin_Dialect, name, traits # [NativeAttrTrait<"IsLocation">],
+ "::mlir::LocationAttr"> {
let cppClassName = name;
let mnemonic = ?;
}
diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp
index 1365da03c7c3d6..f66e67de1f8385 100644
--- a/mlir/lib/AsmParser/LocationParser.cpp
+++ b/mlir/lib/AsmParser/LocationParser.cpp
@@ -153,6 +153,29 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
return success();
}
+ParseResult Parser::parseDialectLocation(LocationAttr &loc) {
+ consumeToken(Token::bare_identifier);
+
+ if (parseToken(Token::less,
+ "expected `<` to start dialect location attribute"))
+ return failure();
+
+ Attribute locAttr = parseAttribute(Type{});
+ // No attribute parsed, someone else has returned an error already.
+ if (!locAttr)
+ return failure();
+
+ loc = llvm::dyn_cast<LocationAttr>(locAttr);
+ if (!loc)
+ return emitError() << "expected a location attribute";
+
+ if (parseToken(Token::greater,
+ "expected `>` to end dialect location attribute"))
+ return failure();
+
+ return success();
+}
+
ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
// Handle aliases.
if (getToken().is(Token::hash_identifier)) {
@@ -187,5 +210,8 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
return success();
}
+ if (getToken().getSpelling() == "dialect")
+ return parseDialectLocation(loc);
+
return emitWrongTokenError("expected location instance");
}
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..ee4c90b9e1caf1 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -303,6 +303,9 @@ class Parser {
/// Parse a name or FileLineCol location instance.
ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
+ /// Parse a dialect-specific location.
+ ParseResult parseDialectLocation(LocationAttr &loc);
+
//===--------------------------------------------------------------------===//
// Affine Parsing
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..68e76e863d63a0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2061,6 +2061,12 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
[&](Location loc) { printLocationInternal(loc, pretty); },
[&]() { os << ", "; });
os << ']';
+ })
+ .Default([&](LocationAttr loc) {
+ // Assumes that this is a dialect-specific attribute.
+ os << "dialect<";
+ printAttribute(loc);
+ os << ">";
});
}
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index c548bbe4b6c860..fc6163884b021b 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -64,8 +64,7 @@ WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
/// Methods for support type inquiry through isa, cast, and dyn_cast.
bool LocationAttr::classof(Attribute attr) {
- return llvm::isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
- UnknownLoc>(attr);
+ return attr.hasTrait<AttributeTrait::IsLocation>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index 8d7c7e4f13ed49..335afe00abfc49 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -89,3 +89,10 @@ func.func @optional_location_specifier() {
test.attr_with_loc("foo" loc("foo_loc"))
return
}
+
+// CHECK-LABEL: @dialect_location
+// CHECK: test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir" * 32>>))
+func.func @dialect_location() {
+ test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir"*32>>))
+ return
+}
diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir
index e9337b5bef37b1..f9f6a365d1b3ae 100644
--- a/mlir/test/IR/pretty-locations.mlir
+++ b/mlir/test/IR/pretty-locations.mlir
@@ -24,6 +24,9 @@ func.func @inline_notation() -> i32 {
affine.if #set0(%2) {
} loc(fused<"myPass">["foo", "foo2"])
+ // CHECK: "foo.op"() : () -> () dialect<#test.custom_location<"foo.mlir" * 1234>>
+ "foo.op"() : () -> () loc(dialect<#test.custom_location<"foo.mlir" * 1234>>)
+
// CHECK: return %0 : i32 [unknown]
return %1 : i32 loc(unknown)
}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index b3b94bd0ffea31..5177075a34c8f5 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -377,4 +377,15 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
}
+// Test custom location handling.
+def TestCustomLocationAttr
+ : Test_Attr<"TestCustomLocation", [NativeAttrTrait<"IsLocation">]> {
+ let mnemonic = "custom_location";
+ let parameters = (ins "mlir::StringAttr":$file, "unsigned":$line);
+
+ // Choose a silly separator token so we know it's hitting this code path
+ // and not another.
+ let assemblyFormat = "`<` $file `*` $line `>`";
+}
+
#endif // TEST_ATTRDEFS
>From 3b5a1ba012a8cb951380c217eb36afb4ef5a913b Mon Sep 17 00:00:00 2001
From: Aman LaChapelle <alachapelle at apple.com>
Date: Tue, 17 Sep 2024 21:26:48 -0700
Subject: [PATCH 2/2] Take suggestions from code review, and add a sanity-check
test to ensure LocationAttr::walk does what we expect it to do.
---
mlir/lib/AsmParser/LocationParser.cpp | 26 ----------------------
mlir/lib/AsmParser/Parser.cpp | 25 +++++++++++----------
mlir/lib/IR/AsmPrinter.cpp | 5 ++---
mlir/lib/IR/Location.cpp | 25 +++------------------
mlir/test/IR/invalid-locations.mlir | 7 ------
mlir/test/IR/locations.mlir | 4 ++--
mlir/test/IR/pretty-locations.mlir | 4 ++--
mlir/unittests/IR/CMakeLists.txt | 1 +
mlir/unittests/IR/LocationTest.cpp | 32 +++++++++++++++++++++++++++
9 files changed, 55 insertions(+), 74 deletions(-)
create mode 100644 mlir/unittests/IR/LocationTest.cpp
diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp
index f66e67de1f8385..1365da03c7c3d6 100644
--- a/mlir/lib/AsmParser/LocationParser.cpp
+++ b/mlir/lib/AsmParser/LocationParser.cpp
@@ -153,29 +153,6 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
return success();
}
-ParseResult Parser::parseDialectLocation(LocationAttr &loc) {
- consumeToken(Token::bare_identifier);
-
- if (parseToken(Token::less,
- "expected `<` to start dialect location attribute"))
- return failure();
-
- Attribute locAttr = parseAttribute(Type{});
- // No attribute parsed, someone else has returned an error already.
- if (!locAttr)
- return failure();
-
- loc = llvm::dyn_cast<LocationAttr>(locAttr);
- if (!loc)
- return emitError() << "expected a location attribute";
-
- if (parseToken(Token::greater,
- "expected `>` to end dialect location attribute"))
- return failure();
-
- return success();
-}
-
ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
// Handle aliases.
if (getToken().is(Token::hash_identifier)) {
@@ -210,8 +187,5 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
return success();
}
- if (getToken().getSpelling() == "dialect")
- return parseDialectLocation(loc);
-
return emitWrongTokenError("expected location instance");
}
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 2e4c4a36d46b9b..71c17c46314fb5 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -631,7 +631,8 @@ class OperationParser : public Parser {
/// Parse a location alias, that is a sequence looking like: #loc42
/// The alias may have already be defined or may be defined later, in which
- /// case an OpaqueLoc is used a placeholder.
+ /// case an OpaqueLoc is used as a placeholder. The caller must ensure that the
+ /// token is actually an alias, which means it must not contain a dot.
ParseResult parseLocationAlias(LocationAttr &loc);
/// This is the structure of a result specifier in the assembly syntax,
@@ -1917,9 +1918,10 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
Token tok = parser.getToken();
- // Check to see if we are parsing a location alias.
- // Otherwise, we parse the location directly.
- if (tok.is(Token::hash_identifier)) {
+ // Check to see if we are parsing a location alias. We are parsing a location
+ // alias if the token is a hash identifier *without* a dot in it - the dot
+ // signifies a dialect attribute. Otherwise, we parse the location directly.
+ if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
if (parser.parseLocationAlias(directLoc))
return failure();
} else if (parser.parseLocationInstance(directLoc)) {
@@ -2086,11 +2088,9 @@ ParseResult OperationParser::parseLocationAlias(LocationAttr &loc) {
Token tok = getToken();
consumeToken(Token::hash_identifier);
StringRef identifier = tok.getSpelling().drop_front();
- if (identifier.contains('.')) {
- return emitError(tok.getLoc())
- << "expected location, but found dialect attribute: '#" << identifier
- << "'";
- }
+ assert(!identifier.contains('.') &&
+ "we should not be getting hash identifiers with dots here");
+
if (state.asmState)
state.asmState->addAttrAliasUses(identifier, tok.getLocRange());
@@ -2120,10 +2120,11 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
return failure();
Token tok = getToken();
- // Check to see if we are parsing a location alias.
- // Otherwise, we parse the location directly.
+ // Check to see if we are parsing a location alias. We are parsing a location
+ // alias if the token is a hash identifier *without* a dot in it - the dot
+ // signifies a dialect attribute. Otherwise, we parse the location directly.
LocationAttr directLoc;
- if (tok.is(Token::hash_identifier)) {
+ if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
if (parseLocationAlias(directLoc))
return failure();
} else if (parseLocationInstance(directLoc)) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 68e76e863d63a0..1dd489efc3e908 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2063,10 +2063,9 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
os << ']';
})
.Default([&](LocationAttr loc) {
- // Assumes that this is a dialect-specific attribute.
- os << "dialect<";
+ // Assumes that this is a dialect-specific attribute and prints it
+ // directly.
printAttribute(loc);
- os << ">";
});
}
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index fc6163884b021b..eda9a390c5518d 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -38,28 +38,16 @@ void BuiltinDialect::registerLocationAttributes() {
//===----------------------------------------------------------------------===//
WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
- if (walkFn(*this).wasInterrupted())
- return WalkResult::interrupt();
-
- return TypeSwitch<LocationAttr, WalkResult>(*this)
- .Case([&](CallSiteLoc callLoc) -> WalkResult {
- if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
- return WalkResult::interrupt();
- return callLoc.getCaller()->walk(walkFn);
- })
- .Case([&](FusedLoc fusedLoc) -> WalkResult {
- for (Location subLoc : fusedLoc.getLocations())
- if (subLoc->walk(walkFn).wasInterrupted())
- return WalkResult::interrupt();
- return WalkResult::advance();
- })
- .Case([&](NameLoc nameLoc) -> WalkResult {
- return nameLoc.getChildLoc()->walk(walkFn);
- })
- .Case([&](OpaqueLoc opaqueLoc) -> WalkResult {
- return opaqueLoc.getFallbackLocation()->walk(walkFn);
- })
- .Default(WalkResult::advance());
+  AttrTypeWalker walker;
+  // Walk in pre-order so a location is visited before the locations nested
+  // under it, and skip the sub-elements of any non-location attribute (e.g.
+  // FusedLoc metadata) so that only directly nested locations are walked.
+  walker.addWalk([&](Attribute attr) -> WalkResult {
+    if (auto loc = llvm::dyn_cast<LocationAttr>(attr))
+      return walkFn(loc);
+    return WalkResult::skip();
+  });
+  return walker.walk<WalkOrder::PreOrder>(*this);
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
diff --git a/mlir/test/IR/invalid-locations.mlir b/mlir/test/IR/invalid-locations.mlir
index bb8006485c5be5..78f3def9380ae8 100644
--- a/mlir/test/IR/invalid-locations.mlir
+++ b/mlir/test/IR/invalid-locations.mlir
@@ -94,13 +94,6 @@ func.func @location_fused_missing_r_square() {
// -----
-func.func @location_invalid_alias() {
- // expected-error at +1 {{expected location, but found dialect attribute: '#foo.loc'}}
- return loc(#foo.loc)
-}
-
-// -----
-
func.func @location_invalid_alias() {
// expected-error at +1 {{operation location alias was never defined}}
return loc(#invalid_alias)
diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index 335afe00abfc49..0c6426ebec8746 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -91,8 +91,8 @@ func.func @optional_location_specifier() {
}
// CHECK-LABEL: @dialect_location
-// CHECK: test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir" * 32>>))
+// CHECK: test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir" * 32>))
func.func @dialect_location() {
- test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir"*32>>))
+ test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir"*32>))
return
}
diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir
index f9f6a365d1b3ae..598bebeb83aebd 100644
--- a/mlir/test/IR/pretty-locations.mlir
+++ b/mlir/test/IR/pretty-locations.mlir
@@ -24,8 +24,8 @@ func.func @inline_notation() -> i32 {
affine.if #set0(%2) {
} loc(fused<"myPass">["foo", "foo2"])
- // CHECK: "foo.op"() : () -> () dialect<#test.custom_location<"foo.mlir" * 1234>>
- "foo.op"() : () -> () loc(dialect<#test.custom_location<"foo.mlir" * 1234>>)
+ // CHECK: "foo.op"() : () -> () #test.custom_location<"foo.mlir" * 1234>
+ "foo.op"() : () -> () loc(#test.custom_location<"foo.mlir" * 1234>)
// CHECK: return %0 : i32 [unknown]
return %1 : i32 loc(unknown)
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 05cb36e1903163..547e536dd9cbbf 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
InterfaceTest.cpp
IRMapping.cpp
InterfaceAttachmentTest.cpp
+ LocationTest.cpp
OperationSupportTest.cpp
PatternMatchTest.cpp
ShapedTypeTest.cpp
diff --git a/mlir/unittests/IR/LocationTest.cpp b/mlir/unittests/IR/LocationTest.cpp
new file mode 100644
index 00000000000000..376afa1aa8e881
--- /dev/null
+++ b/mlir/unittests/IR/LocationTest.cpp
@@ -0,0 +1,53 @@
+//===- LocationTest.cpp - unit tests for Location API ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Builders.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+// Check that we only walk *locations* and not non-location attributes, and
+// that a location is visited before the locations nested under it.
+TEST(LocationTest, Walk) {
+  MLIRContext ctx;
+  Builder builder(&ctx);
+  BoolAttr trueAttr = builder.getBoolAttr(true);
+
+  Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
+  Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
+  Location fused = builder.getFusedLoc({loc1, loc2}, trueAttr);
+
+  SmallVector<Attribute> visited;
+  fused->walk([&](Location l) {
+    visited.push_back(LocationAttr(l));
+    return WalkResult::advance();
+  });
+
+  EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
+}
+
+// Check that locations nested under a non-location attribute (here, the
+// metadata of a FusedLoc) are not walked.
+TEST(LocationTest, WalkSkipsNonLocationAttributes) {
+  MLIRContext ctx;
+  Builder builder(&ctx);
+
+  Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
+  Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
+  Location fused =
+      builder.getFusedLoc({loc1}, builder.getArrayAttr({LocationAttr(loc2)}));
+
+  SmallVector<Attribute> visited;
+  fused->walk([&](Location l) {
+    visited.push_back(LocationAttr(l));
+    return WalkResult::advance();
+  });
+
+  EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1}));
+}
More information about the Mlir-commits
mailing list