[Mlir-commits] [mlir] [mlir] Add the ability to define dialect-specific location attrs. (PR #105584)

Aman LaChapelle llvmlistbot at llvm.org
Wed Aug 21 14:03:58 PDT 2024


https://github.com/bzcheeseman updated https://github.com/llvm/llvm-project/pull/105584

>From f5ceac4c61c7d72d3ecd04d1e656250c29cc9c35 Mon Sep 17 00:00:00 2001
From: Aman LaChapelle <alachapelle at apple.com>
Date: Wed, 21 Aug 2024 13:49:44 -0700
Subject: [PATCH] [mlir] Add the ability to define dialect-specific location
 attrs.

This patch adds the capability to define dialect-specific location attrs. This is particularly useful for defining location structures that don't fit within the core MLIR location hierarchy but don't make sense to upstream (e.g. a project-specific use case).

This patch adds an AttributeTrait, `IsLocation`, which is tagged onto all the builtin location attrs, as well as onto the test location attribute. This is necessary because LocationAttr::classof previously returned true only for the builtin location attributes, whereas the point of this patch is to allow dialects to define their own location attributes.

I also considered an alternate implementation in which LocationAttr becomes an AttrInterface, but discarded it because there are likely to be *many* locations in a single program, and forcing every MLIR user to pay the cost of the additional lookup/dispatch seemed unacceptable. It would also have been a *much* more invasive change. It would have allowed more flexibility in pretty printing, but it's unclear how useful or necessary that flexibility would be, given how much customizability attribute definitions already offer.
---
 mlir/include/mlir/IR/Attributes.h             | 13 +++++++---
 .../mlir/IR/BuiltinLocationAttributes.td      |  3 ++-
 mlir/lib/AsmParser/LocationParser.cpp         | 26 +++++++++++++++++++
 mlir/lib/AsmParser/Parser.h                   |  3 +++
 mlir/lib/IR/AsmPrinter.cpp                    |  6 +++++
 mlir/lib/IR/Location.cpp                      |  3 +--
 mlir/test/IR/locations.mlir                   |  7 +++++
 mlir/test/IR/pretty-locations.mlir            |  3 +++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    | 11 ++++++++
 9 files changed, 69 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 8a077865b51b5f..d347013295d5fc 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -322,12 +322,19 @@ class AttributeInterface
 // Core AttributeTrait
 //===----------------------------------------------------------------------===//
 
-/// This trait is used to determine if an attribute is mutable or not. It is
-/// attached on an attribute if the corresponding ImplType defines a `mutate`
-/// function with proper signature.
 namespace AttributeTrait {
+/// This trait is used to determine if an attribute is mutable or not. It is
+/// attached on an attribute if the corresponding ImplType (storage class)
+/// defines a `mutate` function with proper signature.
 template <typename ConcreteType>
 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+
+/// This trait is used to determine if an attribute is a location or not. It is
+/// attached to an attribute by the user if they intend the attribute to be used
+/// as a location.
+template <typename ConcreteType>
+struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> {
+};
 } // namespace AttributeTrait
 
 } // namespace mlir.
diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index 5a72404dea15bb..3137a3089a0fc5 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -18,7 +18,8 @@ include "mlir/IR/BuiltinDialect.td"
 
 // Base class for Builtin dialect location attributes.
 class Builtin_LocationAttr<string name, list<Trait> traits = []>
-    : AttrDef<Builtin_Dialect, name, traits, "::mlir::LocationAttr"> {
+    : AttrDef<Builtin_Dialect, name, traits # [NativeAttrTrait<"IsLocation">],
+              "::mlir::LocationAttr"> {
   let cppClassName = name;
   let mnemonic = ?;
 }
diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp
index 1365da03c7c3d6..f66e67de1f8385 100644
--- a/mlir/lib/AsmParser/LocationParser.cpp
+++ b/mlir/lib/AsmParser/LocationParser.cpp
@@ -153,6 +153,29 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
   return success();
 }
 
+ParseResult Parser::parseDialectLocation(LocationAttr &loc) {
+  consumeToken(Token::bare_identifier);
+
+  if (parseToken(Token::less,
+                 "expected `<` to start dialect location attribute"))
+    return failure();
+
+  Attribute locAttr = parseAttribute(Type{});
+  // No attribute was parsed; an error has already been emitted.
+  if (!locAttr)
+    return failure();
+
+  loc = llvm::dyn_cast<LocationAttr>(locAttr);
+  if (!loc)
+    return emitError() << "expected a location attribute";
+
+  if (parseToken(Token::greater,
+                 "expected `>` to end dialect location attribute"))
+    return failure();
+
+  return success();
+}
+
 ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
   // Handle aliases.
   if (getToken().is(Token::hash_identifier)) {
@@ -187,5 +210,8 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
     return success();
   }
 
+  if (getToken().getSpelling() == "dialect")
+    return parseDialectLocation(loc);
+
   return emitWrongTokenError("expected location instance");
 }
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..ee4c90b9e1caf1 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -303,6 +303,9 @@ class Parser {
   /// Parse a name or FileLineCol location instance.
   ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
 
+  /// Parse a dialect-specific location.
+  ParseResult parseDialectLocation(LocationAttr &loc);
+
   //===--------------------------------------------------------------------===//
   // Affine Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..68e76e863d63a0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2061,6 +2061,12 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
             [&](Location loc) { printLocationInternal(loc, pretty); },
             [&]() { os << ", "; });
         os << ']';
+      })
+      .Default([&](LocationAttr loc) {
+        // Any location not matched above is assumed to be dialect-specific.
+        os << "dialect<";
+        printAttribute(loc);
+        os << ">";
       });
 }
 
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index c548bbe4b6c860..fc6163884b021b 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -64,8 +64,7 @@ WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
 
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
 bool LocationAttr::classof(Attribute attr) {
-  return llvm::isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
-                   UnknownLoc>(attr);
+  return attr.hasTrait<AttributeTrait::IsLocation>();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index 8d7c7e4f13ed49..335afe00abfc49 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -89,3 +89,10 @@ func.func @optional_location_specifier() {
   test.attr_with_loc("foo" loc("foo_loc"))
   return
 }
+
+// CHECK-LABEL: @dialect_location
+// CHECK: test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir" * 32>>))
+func.func @dialect_location() {
+  test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir"*32>>))
+  return
+}
diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir
index e9337b5bef37b1..f9f6a365d1b3ae 100644
--- a/mlir/test/IR/pretty-locations.mlir
+++ b/mlir/test/IR/pretty-locations.mlir
@@ -24,6 +24,9 @@ func.func @inline_notation() -> i32 {
   affine.if #set0(%2) {
   } loc(fused<"myPass">["foo", "foo2"])
 
+  // CHECK: "foo.op"() : () -> () dialect<#test.custom_location<"foo.mlir" * 1234>>
+  "foo.op"() : () -> () loc(dialect<#test.custom_location<"foo.mlir" * 1234>>)
+
   // CHECK: return %0 : i32 [unknown]
   return %1 : i32 loc(unknown)
 }
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index b3b94bd0ffea31..5177075a34c8f5 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -377,4 +377,15 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
 }
 
 
+// Test custom location handling.
+def TestCustomLocationAttr
+    : Test_Attr<"TestCustomLocation", [NativeAttrTrait<"IsLocation">]> {
+  let mnemonic = "custom_location";
+  let parameters = (ins "mlir::StringAttr":$file, "unsigned":$line);
+
+  // Use an unusual separator token (`*`) so the tests can tell that this
+  // custom parser/printer was exercised rather than another code path.
+  let assemblyFormat = "`<` $file `*` $line `>`";
+}
+
 #endif // TEST_ATTRDEFS



More information about the Mlir-commits mailing list