[Mlir-commits] [mlir] [mlir] Add the ability to define dialect-specific location attrs. (PR #105584)

Aman LaChapelle llvmlistbot at llvm.org
Wed Sep 18 10:01:54 PDT 2024


https://github.com/bzcheeseman updated https://github.com/llvm/llvm-project/pull/105584

>From f5ceac4c61c7d72d3ecd04d1e656250c29cc9c35 Mon Sep 17 00:00:00 2001
From: Aman LaChapelle <alachapelle at apple.com>
Date: Wed, 21 Aug 2024 13:49:44 -0700
Subject: [PATCH 1/5] [mlir] Add the ability to define dialect-specific
 location attrs.

This patch adds the capability to define dialect-specific location attrs. This is particularly useful for defining location structures that don't fit within the core MLIR location hierarchy but don't make sense to push upstream (i.e. a custom use case).

This patch adds an AttributeTrait, `IsLocation`, which is attached to all the builtin location attrs as well as to the test location attribute. This is necessary because LocationAttr::classof previously returned true only if the attribute was one of the builtin location attributes, whereas the point of this patch is to allow dialects to define their own location attributes.

I also considered an alternative implementation in which LocationAttr becomes an AttrInterface, but discarded it: there are likely to be *many* locations in a single program, and I was concerned that forcing every MLIR user to pay the cost of the additional lookup/dispatch would be unacceptable. It would also have been a *much* more invasive change. It would have allowed more flexibility in pretty printing, but it is unclear how useful or necessary that flexibility would be, given how much customizability attribute definitions already provide.
---
 mlir/include/mlir/IR/Attributes.h             | 13 +++++++---
 .../mlir/IR/BuiltinLocationAttributes.td      |  3 ++-
 mlir/lib/AsmParser/LocationParser.cpp         | 26 +++++++++++++++++++
 mlir/lib/AsmParser/Parser.h                   |  3 +++
 mlir/lib/IR/AsmPrinter.cpp                    |  6 +++++
 mlir/lib/IR/Location.cpp                      |  3 +--
 mlir/test/IR/locations.mlir                   |  7 +++++
 mlir/test/IR/pretty-locations.mlir            |  3 +++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    | 11 ++++++++
 9 files changed, 69 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 8a077865b51b5f..d347013295d5fc 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -322,12 +322,19 @@ class AttributeInterface
 // Core AttributeTrait
 //===----------------------------------------------------------------------===//
 
-/// This trait is used to determine if an attribute is mutable or not. It is
-/// attached on an attribute if the corresponding ImplType defines a `mutate`
-/// function with proper signature.
 namespace AttributeTrait {
+/// This trait is used to determine if an attribute is mutable or not. It is
+/// attached on an attribute if the corresponding ConcreteType defines a
+/// `mutate` function with proper signature.
 template <typename ConcreteType>
 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+
+/// This trait is used to determine if an attribute is a location or not. It is
+/// attached to an attribute by the user if they intend the attribute to be used
+/// as a location.
+template <typename ConcreteType>
+struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> {
+};
 } // namespace AttributeTrait
 
 } // namespace mlir.
diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index 5a72404dea15bb..3137a3089a0fc5 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -18,7 +18,8 @@ include "mlir/IR/BuiltinDialect.td"
 
 // Base class for Builtin dialect location attributes.
 class Builtin_LocationAttr<string name, list<Trait> traits = []>
-    : AttrDef<Builtin_Dialect, name, traits, "::mlir::LocationAttr"> {
+    : AttrDef<Builtin_Dialect, name, traits # [NativeAttrTrait<"IsLocation">],
+              "::mlir::LocationAttr"> {
   let cppClassName = name;
   let mnemonic = ?;
 }
diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp
index 1365da03c7c3d6..f66e67de1f8385 100644
--- a/mlir/lib/AsmParser/LocationParser.cpp
+++ b/mlir/lib/AsmParser/LocationParser.cpp
@@ -153,6 +153,29 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
   return success();
 }
 
+ParseResult Parser::parseDialectLocation(LocationAttr &loc) {
+  consumeToken(Token::bare_identifier);
+
+  if (parseToken(Token::less,
+                 "expected `<` to start dialect location attribute"))
+    return failure();
+
+  Attribute locAttr = parseAttribute(Type{});
+  // No attribute parsed, someone else has returned an error already.
+  if (!locAttr)
+    return failure();
+
+  loc = llvm::dyn_cast<LocationAttr>(locAttr);
+  if (!loc)
+    return emitError() << "expected a location attribute";
+
+  if (parseToken(Token::greater,
+                 "expected `>` to end dialect location attribute"))
+    return failure();
+
+  return success();
+}
+
 ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
   // Handle aliases.
   if (getToken().is(Token::hash_identifier)) {
@@ -187,5 +210,8 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
     return success();
   }
 
+  if (getToken().getSpelling() == "dialect")
+    return parseDialectLocation(loc);
+
   return emitWrongTokenError("expected location instance");
 }
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..ee4c90b9e1caf1 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -303,6 +303,9 @@ class Parser {
   /// Parse a name or FileLineCol location instance.
   ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
 
+  /// Parse a dialect-specific location.
+  ParseResult parseDialectLocation(LocationAttr &loc);
+
   //===--------------------------------------------------------------------===//
   // Affine Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..68e76e863d63a0 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2061,6 +2061,12 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
             [&](Location loc) { printLocationInternal(loc, pretty); },
             [&]() { os << ", "; });
         os << ']';
+      })
+      .Default([&](LocationAttr loc) {
+        // Assumes that this is a dialect-specific attribute.
+        os << "dialect<";
+        printAttribute(loc);
+        os << ">";
       });
 }
 
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index c548bbe4b6c860..fc6163884b021b 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -64,8 +64,7 @@ WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
 
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
 bool LocationAttr::classof(Attribute attr) {
-  return llvm::isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
-                   UnknownLoc>(attr);
+  return attr.hasTrait<AttributeTrait::IsLocation>();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index 8d7c7e4f13ed49..335afe00abfc49 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -89,3 +89,10 @@ func.func @optional_location_specifier() {
   test.attr_with_loc("foo" loc("foo_loc"))
   return
 }
+
+// CHECK-LABEL: @dialect_location
+// CHECK: test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir" * 32>>))
+func.func @dialect_location() {
+  test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir"*32>>))
+  return
+}
diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir
index e9337b5bef37b1..f9f6a365d1b3ae 100644
--- a/mlir/test/IR/pretty-locations.mlir
+++ b/mlir/test/IR/pretty-locations.mlir
@@ -24,6 +24,9 @@ func.func @inline_notation() -> i32 {
   affine.if #set0(%2) {
   } loc(fused<"myPass">["foo", "foo2"])
 
+  // CHECK: "foo.op"() : () -> () dialect<#test.custom_location<"foo.mlir" * 1234>>
+  "foo.op"() : () -> () loc(dialect<#test.custom_location<"foo.mlir" * 1234>>)
+
   // CHECK: return %0 : i32 [unknown]
   return %1 : i32 loc(unknown)
 }
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index b3b94bd0ffea31..5177075a34c8f5 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -377,4 +377,15 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
 }
 
 
+// Test custom location handling.
+def TestCustomLocationAttr
+    : Test_Attr<"TestCustomLocation", [NativeAttrTrait<"IsLocation">]> {
+  let mnemonic = "custom_location";
+  let parameters = (ins "mlir::StringAttr":$file, "unsigned":$line);
+
+  // Choose a silly separator token so we know it's hitting this code path
+  // and not another.
+  let assemblyFormat = "`<` $file `*` $line `>`";
+}
+
 #endif // TEST_ATTRDEFS

>From 3b5a1ba012a8cb951380c217eb36afb4ef5a913b Mon Sep 17 00:00:00 2001
From: Aman LaChapelle <alachapelle at apple.com>
Date: Tue, 17 Sep 2024 21:26:48 -0700
Subject: [PATCH 2/5] Take suggestions from code review, and add a sanity-check
 test to ensure LocationAttr::walk does what we expect it to do.

---
 mlir/lib/AsmParser/LocationParser.cpp | 26 ----------------------
 mlir/lib/AsmParser/Parser.cpp         | 25 +++++++++++----------
 mlir/lib/IR/AsmPrinter.cpp            |  5 ++---
 mlir/lib/IR/Location.cpp              | 25 +++------------------
 mlir/test/IR/invalid-locations.mlir   |  7 ------
 mlir/test/IR/locations.mlir           |  4 ++--
 mlir/test/IR/pretty-locations.mlir    |  4 ++--
 mlir/unittests/IR/CMakeLists.txt      |  1 +
 mlir/unittests/IR/LocationTest.cpp    | 32 +++++++++++++++++++++++++++
 9 files changed, 55 insertions(+), 74 deletions(-)
 create mode 100644 mlir/unittests/IR/LocationTest.cpp

diff --git a/mlir/lib/AsmParser/LocationParser.cpp b/mlir/lib/AsmParser/LocationParser.cpp
index f66e67de1f8385..1365da03c7c3d6 100644
--- a/mlir/lib/AsmParser/LocationParser.cpp
+++ b/mlir/lib/AsmParser/LocationParser.cpp
@@ -153,29 +153,6 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
   return success();
 }
 
-ParseResult Parser::parseDialectLocation(LocationAttr &loc) {
-  consumeToken(Token::bare_identifier);
-
-  if (parseToken(Token::less,
-                 "expected `<` to start dialect location attribute"))
-    return failure();
-
-  Attribute locAttr = parseAttribute(Type{});
-  // No attribute parsed, someone else has returned an error already.
-  if (!locAttr)
-    return failure();
-
-  loc = llvm::dyn_cast<LocationAttr>(locAttr);
-  if (!loc)
-    return emitError() << "expected a location attribute";
-
-  if (parseToken(Token::greater,
-                 "expected `>` to end dialect location attribute"))
-    return failure();
-
-  return success();
-}
-
 ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
   // Handle aliases.
   if (getToken().is(Token::hash_identifier)) {
@@ -210,8 +187,5 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
     return success();
   }
 
-  if (getToken().getSpelling() == "dialect")
-    return parseDialectLocation(loc);
-
   return emitWrongTokenError("expected location instance");
 }
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 2e4c4a36d46b9b..71c17c46314fb5 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -631,7 +631,8 @@ class OperationParser : public Parser {
 
   /// Parse a location alias, that is a sequence looking like: #loc42
   /// The alias may have already be defined or may be defined later, in which
-  /// case an OpaqueLoc is used a placeholder.
+  /// case an OpaqueLoc is used a placeholder. The caller must ensure that the
+  /// token is actually an alias, which means it must not contain a dot.
   ParseResult parseLocationAlias(LocationAttr &loc);
 
   /// This is the structure of a result specifier in the assembly syntax,
@@ -1917,9 +1918,10 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
 
     Token tok = parser.getToken();
 
-    // Check to see if we are parsing a location alias.
-    // Otherwise, we parse the location directly.
-    if (tok.is(Token::hash_identifier)) {
+    // Check to see if we are parsing a location alias. We are parsing a location
+    // alias if the token is a hash identifier *without* a dot in it - the dot
+    // signifies a dialect attribute. Otherwise, we parse the location directly.
+    if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
       if (parser.parseLocationAlias(directLoc))
         return failure();
     } else if (parser.parseLocationInstance(directLoc)) {
@@ -2086,11 +2088,9 @@ ParseResult OperationParser::parseLocationAlias(LocationAttr &loc) {
   Token tok = getToken();
   consumeToken(Token::hash_identifier);
   StringRef identifier = tok.getSpelling().drop_front();
-  if (identifier.contains('.')) {
-    return emitError(tok.getLoc())
-           << "expected location, but found dialect attribute: '#" << identifier
-           << "'";
-  }
+  assert(!identifier.contains('.') &&
+         "we should not be getting hash identifiers with dots here");
+
   if (state.asmState)
     state.asmState->addAttrAliasUses(identifier, tok.getLocRange());
 
@@ -2120,10 +2120,11 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
     return failure();
   Token tok = getToken();
 
-  // Check to see if we are parsing a location alias.
-  // Otherwise, we parse the location directly.
+  // Check to see if we are parsing a location alias. We are parsing a location
+  // alias if the token is a hash identifier *without* a dot in it - the dot
+  // signifies a dialect attribute. Otherwise, we parse the location directly.
   LocationAttr directLoc;
-  if (tok.is(Token::hash_identifier)) {
+  if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
     if (parseLocationAlias(directLoc))
       return failure();
   } else if (parseLocationInstance(directLoc)) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 68e76e863d63a0..1dd489efc3e908 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2063,10 +2063,9 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
         os << ']';
       })
       .Default([&](LocationAttr loc) {
-        // Assumes that this is a dialect-specific attribute.
-        os << "dialect<";
+        // Assumes that this is a dialect-specific attribute and prints it
+        // directly.
         printAttribute(loc);
-        os << ">";
       });
 }
 
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index fc6163884b021b..eda9a390c5518d 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -38,28 +38,9 @@ void BuiltinDialect::registerLocationAttributes() {
 //===----------------------------------------------------------------------===//
 
 WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
-  if (walkFn(*this).wasInterrupted())
-    return WalkResult::interrupt();
-
-  return TypeSwitch<LocationAttr, WalkResult>(*this)
-      .Case([&](CallSiteLoc callLoc) -> WalkResult {
-        if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
-          return WalkResult::interrupt();
-        return callLoc.getCaller()->walk(walkFn);
-      })
-      .Case([&](FusedLoc fusedLoc) -> WalkResult {
-        for (Location subLoc : fusedLoc.getLocations())
-          if (subLoc->walk(walkFn).wasInterrupted())
-            return WalkResult::interrupt();
-        return WalkResult::advance();
-      })
-      .Case([&](NameLoc nameLoc) -> WalkResult {
-        return nameLoc.getChildLoc()->walk(walkFn);
-      })
-      .Case([&](OpaqueLoc opaqueLoc) -> WalkResult {
-        return opaqueLoc.getFallbackLocation()->walk(walkFn);
-      })
-      .Default(WalkResult::advance());
+  AttrTypeWalker walker;
+  walker.addWalk([&](LocationAttr loc) { return walkFn(loc); });
+  return walker.walk(*this);
 }
 
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
diff --git a/mlir/test/IR/invalid-locations.mlir b/mlir/test/IR/invalid-locations.mlir
index bb8006485c5be5..78f3def9380ae8 100644
--- a/mlir/test/IR/invalid-locations.mlir
+++ b/mlir/test/IR/invalid-locations.mlir
@@ -94,13 +94,6 @@ func.func @location_fused_missing_r_square() {
 
 // -----
 
-func.func @location_invalid_alias() {
-  // expected-error at +1 {{expected location, but found dialect attribute: '#foo.loc'}}
-  return loc(#foo.loc)
-}
-
-// -----
-
 func.func @location_invalid_alias() {
   // expected-error at +1 {{operation location alias was never defined}}
   return loc(#invalid_alias)
diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index 335afe00abfc49..0c6426ebec8746 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -91,8 +91,8 @@ func.func @optional_location_specifier() {
 }
 
 // CHECK-LABEL: @dialect_location
-// CHECK: test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir" * 32>>))
+// CHECK: test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir" * 32>))
 func.func @dialect_location() {
-  test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir"*32>>))
+  test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir"*32>))
   return
 }
diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir
index f9f6a365d1b3ae..598bebeb83aebd 100644
--- a/mlir/test/IR/pretty-locations.mlir
+++ b/mlir/test/IR/pretty-locations.mlir
@@ -24,8 +24,8 @@ func.func @inline_notation() -> i32 {
   affine.if #set0(%2) {
   } loc(fused<"myPass">["foo", "foo2"])
 
-  // CHECK: "foo.op"() : () -> () dialect<#test.custom_location<"foo.mlir" * 1234>>
-  "foo.op"() : () -> () loc(dialect<#test.custom_location<"foo.mlir" * 1234>>)
+  // CHECK: "foo.op"() : () -> () #test.custom_location<"foo.mlir" * 1234>
+  "foo.op"() : () -> () loc(#test.custom_location<"foo.mlir" * 1234>)
 
   // CHECK: return %0 : i32 [unknown]
   return %1 : i32 loc(unknown)
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 05cb36e1903163..547e536dd9cbbf 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
   InterfaceTest.cpp
   IRMapping.cpp
   InterfaceAttachmentTest.cpp
+  LocationTest.cpp
   OperationSupportTest.cpp
   PatternMatchTest.cpp
   ShapedTypeTest.cpp
diff --git a/mlir/unittests/IR/LocationTest.cpp b/mlir/unittests/IR/LocationTest.cpp
new file mode 100644
index 00000000000000..376afa1aa8e881
--- /dev/null
+++ b/mlir/unittests/IR/LocationTest.cpp
@@ -0,0 +1,32 @@
+//===- LocationTest.cpp - unit tests for Location API ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Builders.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+// Check that we only walk *locations* and not non-location attributes.
+TEST(LocationTest, Walk) {
+  MLIRContext ctx;
+  Builder builder(&ctx);
+  BoolAttr trueAttr = builder.getBoolAttr(true);
+
+  Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
+  Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
+  Location fused = builder.getFusedLoc({loc1, loc2}, trueAttr);
+
+  SmallVector<Attribute> visited;
+  fused->walk([&](Location l) {
+    visited.push_back(LocationAttr(l));
+    return WalkResult::advance();
+  });
+
+  EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({loc1, loc2, fused}));
+}

>From fa38aca3ff05f0cbc91cdc2d4ad9a41ea436f658 Mon Sep 17 00:00:00 2001
From: Aman LaChapelle <alachapelle at apple.com>
Date: Tue, 17 Sep 2024 21:27:51 -0700
Subject: [PATCH 3/5] Remove leftover parseDialectLocation declaration

---
 mlir/lib/AsmParser/Parser.h | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index ee4c90b9e1caf1..4caab499e1a0e4 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -303,9 +303,6 @@ class Parser {
   /// Parse a name or FileLineCol location instance.
   ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
 
-  /// Parse a dialect-specific location.
-  ParseResult parseDialectLocation(LocationAttr &loc);
-
   //===--------------------------------------------------------------------===//
   // Affine Parsing
   //===--------------------------------------------------------------------===//

>From 2624e45181b5a67564f5f8d744129827f3c78146 Mon Sep 17 00:00:00 2001
From: Aman LaChapelle <alachapelle at apple.com>
Date: Wed, 18 Sep 2024 09:13:58 -0700
Subject: [PATCH 4/5] Apply clang-format

---
 mlir/lib/AsmParser/Parser.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 71c17c46314fb5..d241ba63c3561d 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1918,9 +1918,10 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
 
     Token tok = parser.getToken();
 
-    // Check to see if we are parsing a location alias. We are parsing a location
-    // alias if the token is a hash identifier *without* a dot in it - the dot
-    // signifies a dialect attribute. Otherwise, we parse the location directly.
+    // Check to see if we are parsing a location alias. We are parsing a
+    // location alias if the token is a hash identifier *without* a dot in it -
+    // the dot signifies a dialect attribute. Otherwise, we parse the location
+    // directly.
     if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
       if (parser.parseLocationAlias(directLoc))
         return failure();

>From 6823673a68d0e171c0f833378e9275d2fbfad124 Mon Sep 17 00:00:00 2001
From: Aman LaChapelle <alachapelle at apple.com>
Date: Wed, 18 Sep 2024 10:01:28 -0700
Subject: [PATCH 5/5] Address next round of review comments

---
 mlir/include/mlir/IR/AttrTypeBase.td          |  6 +++++
 .../mlir/IR/BuiltinLocationAttributes.td      |  3 +--
 mlir/include/mlir/IR/Location.h               |  5 +++-
 mlir/lib/AsmParser/Parser.cpp                 |  2 +-
 mlir/lib/IR/Location.cpp                      | 10 ++++++--
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    |  5 ++--
 mlir/unittests/IR/LocationTest.cpp            | 23 ++++++++++++++++++-
 7 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index d176b36068f7a5..cbe4f0d67574b3 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -281,6 +281,12 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
   let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
 }
 
+// Provide a LocationAttrDef for dialects to provide their own locations
+// that subclass LocationAttr.
+class LocationAttrDef<Dialect dialect, string name, list<Trait> traits = []>
+    : AttrDef<dialect, name, traits # [NativeAttrTrait<"IsLocation">],
+              "::mlir::LocationAttr">;
+
 // Define a new type, named `name`, belonging to `dialect` that inherits from
 // the given C++ base class.
 class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index 3137a3089a0fc5..bbe566ce977775 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -18,8 +18,7 @@ include "mlir/IR/BuiltinDialect.td"
 
 // Base class for Builtin dialect location attributes.
 class Builtin_LocationAttr<string name, list<Trait> traits = []>
-    : AttrDef<Builtin_Dialect, name, traits # [NativeAttrTrait<"IsLocation">],
-              "::mlir::LocationAttr"> {
+    : LocationAttrDef<Builtin_Dialect, name, traits> {
   let cppClassName = name;
   let mnemonic = ?;
 }
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index 423b4d19b5b944..5eb1bfaf4afcdc 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -32,7 +32,10 @@ class LocationAttr : public Attribute {
 public:
   using Attribute::Attribute;
 
-  /// Walk all of the locations nested under, and including, the current.
+  /// Walk all of the locations nested directly under, and including, the
+  /// current. This means that if a location is nested under a non-location
+  /// attribute, it will *not* be walked by this method. This walk is performed
+  /// in pre-order to get this behavior.
   WalkResult walk(function_ref<WalkResult(Location)> walkFn);
 
   /// Return an instance of the given location type if one is nested under the
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index d241ba63c3561d..0570f6b8c6a6bd 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -2090,7 +2090,7 @@ ParseResult OperationParser::parseLocationAlias(LocationAttr &loc) {
   consumeToken(Token::hash_identifier);
   StringRef identifier = tok.getSpelling().drop_front();
   assert(!identifier.contains('.') &&
-         "we should not be getting hash identifiers with dots here");
+         "unexpected dialect attribute token, expected alias");
 
   if (state.asmState)
     state.asmState->addAttrAliasUses(identifier, tok.getLocRange());
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index eda9a390c5518d..dbd84912a8657d 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -39,8 +39,14 @@ void BuiltinDialect::registerLocationAttributes() {
 
 WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
   AttrTypeWalker walker;
-  walker.addWalk([&](LocationAttr loc) { return walkFn(loc); });
-  return walker.walk(*this);
+  // Walk locations, but skip any other attribute.
+  walker.addWalk([&](Attribute attr) {
+    if (auto loc = llvm::dyn_cast<LocationAttr>(attr))
+      return walkFn(loc);
+
+    return WalkResult::skip();
+  });
+  return walker.walk<WalkOrder::PreOrder>(*this);
 }
 
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 5177075a34c8f5..907184b2e1ce4c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -27,6 +27,8 @@ include "mlir/IR/OpAsmInterface.td"
 class Test_Attr<string name, list<Trait> traits = []>
     : AttrDef<Test_Dialect, name, traits>;
 
+class Test_LocAttr<string name> : LocationAttrDef<Test_Dialect, name, []>;
+
 def SimpleAttrA : Test_Attr<"SimpleA"> {
   let mnemonic = "smpla";
 }
@@ -378,8 +380,7 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
 
 
 // Test custom location handling.
-def TestCustomLocationAttr
-    : Test_Attr<"TestCustomLocation", [NativeAttrTrait<"IsLocation">]> {
+def TestCustomLocationAttr : Test_LocAttr<"TestCustomLocation"> {
   let mnemonic = "custom_location";
   let parameters = (ins "mlir::StringAttr":$file, "unsigned":$line);
 
diff --git a/mlir/unittests/IR/LocationTest.cpp b/mlir/unittests/IR/LocationTest.cpp
index 376afa1aa8e881..03374ee0b8467b 100644
--- a/mlir/unittests/IR/LocationTest.cpp
+++ b/mlir/unittests/IR/LocationTest.cpp
@@ -28,5 +28,26 @@ TEST(LocationTest, Walk) {
     return WalkResult::advance();
   });
 
-  EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({loc1, loc2, fused}));
+  EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
+}
+
+// Check that we skip location attrs nested under a non-location attr.
+TEST(LocationTest, SkipNested) {
+  MLIRContext ctx;
+  Builder builder(&ctx);
+
+  Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
+  Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
+  Location loc3 = FileLineColLoc::get(builder.getStringAttr("bar"), 1, 2);
+  Location loc4 = FileLineColLoc::get(builder.getStringAttr("bar"), 3, 4);
+  ArrayAttr arr = builder.getArrayAttr({loc3, loc4});
+  Location fused = builder.getFusedLoc({loc1, loc2}, arr);
+
+  SmallVector<Attribute> visited;
+  fused->walk([&](Location l) {
+    visited.push_back(LocationAttr(l));
+    return WalkResult::advance();
+  });
+
+  EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
 }



More information about the Mlir-commits mailing list