[Mlir-commits] [mlir] 759a7b5 - [mlir] Add the ability to define dialect-specific location attrs. (#105584)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 3 10:25:48 PDT 2024


Author: Aman LaChapelle
Date: 2024-10-03T10:25:44-07:00
New Revision: 759a7b5933654b67b9b7089d9aef2ca287cc38fa

URL: https://github.com/llvm/llvm-project/commit/759a7b5933654b67b9b7089d9aef2ca287cc38fa
DIFF: https://github.com/llvm/llvm-project/commit/759a7b5933654b67b9b7089d9aef2ca287cc38fa.diff

LOG: [mlir] Add the ability to define dialect-specific location attrs. (#105584)

This patch adds the capability to define dialect-specific location
attrs. This is useful in particular for defining location structure that
doesn't necessarily fit within the core MLIR location hierarchy, but
doesn't make sense to push upstream (i.e. a custom use case).

This patch adds an AttributeTrait, `IsLocation`, which is tagged onto
all the builtin location attrs, as well as the test location attribute.
This is necessary because previously LocationAttr::classof only returned
true if the attribute was one of the builtin location attributes, and
well, the point of this patch is to allow dialects to define their own
location attributes.

There was an alternate implementation I considered wherein LocationAttr
becomes an AttrInterface, but that was discarded because there are
likely to be *many* locations in a single program, and I was concerned
that forcing every MLIR user to pay the cost of the additional
lookup/dispatch was unacceptable. It also would have been a *much* more
invasive change. It would have allowed for more flexibility in terms of
pretty printing, but it's unclear how useful/necessary that flexibility
would be given how much customizability there already is for attribute
definitions.

Added: 
    mlir/unittests/IR/LocationTest.cpp

Modified: 
    mlir/include/mlir/IR/AttrTypeBase.td
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/BuiltinLocationAttributes.td
    mlir/include/mlir/IR/Location.h
    mlir/lib/AsmParser/Parser.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Location.cpp
    mlir/test/IR/invalid-locations.mlir
    mlir/test/IR/locations.mlir
    mlir/test/IR/pretty-locations.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/unittests/IR/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index d176b36068f7a5..cbe4f0d67574b3 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -281,6 +281,12 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
   let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
 }
 
+// Provide a LocationAttrDef for dialects to provide their own locations
+// that subclass LocationAttr.
+class LocationAttrDef<Dialect dialect, string name, list<Trait> traits = []>
+    : AttrDef<dialect, name, traits # [NativeAttrTrait<"IsLocation">],
+              "::mlir::LocationAttr">;
+
 // Define a new type, named `name`, belonging to `dialect` that inherits from
 // the given C++ base class.
 class TypeDef<Dialect dialect, string name, list<Trait> traits = [],

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 8a077865b51b5f..d347013295d5fc 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -322,12 +322,19 @@ class AttributeInterface
 // Core AttributeTrait
 //===----------------------------------------------------------------------===//
 
-/// This trait is used to determine if an attribute is mutable or not. It is
-/// attached on an attribute if the corresponding ImplType defines a `mutate`
-/// function with proper signature.
 namespace AttributeTrait {
+/// This trait is used to determine if an attribute is mutable or not. It is
+/// attached on an attribute if the corresponding ConcreteType defines a
+/// `mutate` function with proper signature.
 template <typename ConcreteType>
 using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+
+/// This trait is used to determine if an attribute is a location or not. It is
+/// attached to an attribute by the user if they intend the attribute to be used
+/// as a location.
+template <typename ConcreteType>
+struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> {
+};
 } // namespace AttributeTrait
 
 } // namespace mlir.

diff  --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index 5a72404dea15bb..bbe566ce977775 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -18,7 +18,7 @@ include "mlir/IR/BuiltinDialect.td"
 
 // Base class for Builtin dialect location attributes.
 class Builtin_LocationAttr<string name, list<Trait> traits = []>
-    : AttrDef<Builtin_Dialect, name, traits, "::mlir::LocationAttr"> {
+    : LocationAttrDef<Builtin_Dialect, name, traits> {
   let cppClassName = name;
   let mnemonic = ?;
 }

diff  --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index 423b4d19b5b944..5eb1bfaf4afcdc 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -32,7 +32,10 @@ class LocationAttr : public Attribute {
 public:
   using Attribute::Attribute;
 
-  /// Walk all of the locations nested under, and including, the current.
+  /// Walk all of the locations nested directly under, and including, the
+  /// current. This means that if a location is nested under a non-location
+  /// attribute, it will *not* be walked by this method. This walk is performed
+  /// in pre-order to get this behavior.
   WalkResult walk(function_ref<WalkResult(Location)> walkFn);
 
   /// Return an instance of the given location type if one is nested under the

diff  --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 83eec3244009d8..8f19487d80fa39 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -631,7 +631,8 @@ class OperationParser : public Parser {
 
   /// Parse a location alias, that is a sequence looking like: #loc42
   /// The alias may have already be defined or may be defined later, in which
-  /// case an OpaqueLoc is used a placeholder.
+  /// case an OpaqueLoc is used as a placeholder. The caller must ensure that
+  /// the token is actually an alias, which means it must not contain a dot.
   ParseResult parseLocationAlias(LocationAttr &loc);
 
   /// This is the structure of a result specifier in the assembly syntax,
@@ -1917,9 +1918,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
 
     Token tok = parser.getToken();
 
-    // Check to see if we are parsing a location alias.
-    // Otherwise, we parse the location directly.
-    if (tok.is(Token::hash_identifier)) {
+    // Check to see if we are parsing a location alias. We are parsing a
+    // location alias if the token is a hash identifier *without* a dot in it -
+    // the dot signifies a dialect attribute. Otherwise, we parse the location
+    // directly.
+    if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
       if (parser.parseLocationAlias(directLoc))
         return failure();
     } else if (parser.parseLocationInstance(directLoc)) {
@@ -2086,11 +2089,9 @@ ParseResult OperationParser::parseLocationAlias(LocationAttr &loc) {
   Token tok = getToken();
   consumeToken(Token::hash_identifier);
   StringRef identifier = tok.getSpelling().drop_front();
-  if (identifier.contains('.')) {
-    return emitError(tok.getLoc())
-           << "expected location, but found dialect attribute: '#" << identifier
-           << "'";
-  }
+  assert(!identifier.contains('.') &&
+         "unexpected dialect attribute token, expected alias");
+
   if (state.asmState)
     state.asmState->addAttrAliasUses(identifier, tok.getLocRange());
 
@@ -2120,10 +2121,11 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
     return failure();
   Token tok = getToken();
 
-  // Check to see if we are parsing a location alias.
-  // Otherwise, we parse the location directly.
+  // Check to see if we are parsing a location alias. We are parsing a location
+  // alias if the token is a hash identifier *without* a dot in it - the dot
+  // signifies a dialect attribute. Otherwise, we parse the location directly.
   LocationAttr directLoc;
-  if (tok.is(Token::hash_identifier)) {
+  if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
     if (parseLocationAlias(directLoc))
       return failure();
   } else if (parseLocationInstance(directLoc)) {

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7f95f5ace8c00f..5e9ea7bc088bfe 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2064,6 +2064,11 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
             [&](Location loc) { printLocationInternal(loc, pretty); },
             [&]() { os << ", "; });
         os << ']';
+      })
+      .Default([&](LocationAttr loc) {
+        // Assumes that this is a dialect-specific attribute and prints it
+        // directly.
+        printAttribute(loc);
       });
 }
 

diff  --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index c548bbe4b6c860..dbd84912a8657d 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -38,34 +38,20 @@ void BuiltinDialect::registerLocationAttributes() {
 //===----------------------------------------------------------------------===//
 
 WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
-  if (walkFn(*this).wasInterrupted())
-    return WalkResult::interrupt();
-
-  return TypeSwitch<LocationAttr, WalkResult>(*this)
-      .Case([&](CallSiteLoc callLoc) -> WalkResult {
-        if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
-          return WalkResult::interrupt();
-        return callLoc.getCaller()->walk(walkFn);
-      })
-      .Case([&](FusedLoc fusedLoc) -> WalkResult {
-        for (Location subLoc : fusedLoc.getLocations())
-          if (subLoc->walk(walkFn).wasInterrupted())
-            return WalkResult::interrupt();
-        return WalkResult::advance();
-      })
-      .Case([&](NameLoc nameLoc) -> WalkResult {
-        return nameLoc.getChildLoc()->walk(walkFn);
-      })
-      .Case([&](OpaqueLoc opaqueLoc) -> WalkResult {
-        return opaqueLoc.getFallbackLocation()->walk(walkFn);
-      })
-      .Default(WalkResult::advance());
+  AttrTypeWalker walker;
+  // Walk locations, but skip any other attribute.
+  walker.addWalk([&](Attribute attr) {
+    if (auto loc = llvm::dyn_cast<LocationAttr>(attr))
+      return walkFn(loc);
+
+    return WalkResult::skip();
+  });
+  return walker.walk<WalkOrder::PreOrder>(*this);
 }
 
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
 bool LocationAttr::classof(Attribute attr) {
-  return llvm::isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
-                   UnknownLoc>(attr);
+  return attr.hasTrait<AttributeTrait::IsLocation>();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/IR/invalid-locations.mlir b/mlir/test/IR/invalid-locations.mlir
index bb8006485c5be5..78f3def9380ae8 100644
--- a/mlir/test/IR/invalid-locations.mlir
+++ b/mlir/test/IR/invalid-locations.mlir
@@ -94,13 +94,6 @@ func.func @location_fused_missing_r_square() {
 
 // -----
 
-func.func @location_invalid_alias() {
-  // expected-error @+1 {{expected location, but found dialect attribute: '#foo.loc'}}
-  return loc(#foo.loc)
-}
-
-// -----
-
 func.func @location_invalid_alias() {
   // expected-error @+1 {{operation location alias was never defined}}
   return loc(#invalid_alias)

diff  --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index 8d7c7e4f13ed49..0c6426ebec8746 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -89,3 +89,10 @@ func.func @optional_location_specifier() {
   test.attr_with_loc("foo" loc("foo_loc"))
   return
 }
+
+// CHECK-LABEL: @dialect_location
+// CHECK: test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir" * 32>))
+func.func @dialect_location() {
+  test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir"*32>))
+  return
+}

diff  --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir
index e9337b5bef37b1..598bebeb83aebd 100644
--- a/mlir/test/IR/pretty-locations.mlir
+++ b/mlir/test/IR/pretty-locations.mlir
@@ -24,6 +24,9 @@ func.func @inline_notation() -> i32 {
   affine.if #set0(%2) {
   } loc(fused<"myPass">["foo", "foo2"])
 
+  // CHECK: "foo.op"() : () -> () #test.custom_location<"foo.mlir" * 1234>
+  "foo.op"() : () -> () loc(#test.custom_location<"foo.mlir" * 1234>)
+
   // CHECK: return %0 : i32 [unknown]
   return %1 : i32 loc(unknown)
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index b3b94bd0ffea31..907184b2e1ce4c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -27,6 +27,8 @@ include "mlir/IR/OpAsmInterface.td"
 class Test_Attr<string name, list<Trait> traits = []>
     : AttrDef<Test_Dialect, name, traits>;
 
+class Test_LocAttr<string name> : LocationAttrDef<Test_Dialect, name, []>;
+
 def SimpleAttrA : Test_Attr<"SimpleA"> {
   let mnemonic = "smpla";
 }
@@ -377,4 +379,14 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
 }
 
 
+// Test custom location handling.
+def TestCustomLocationAttr : Test_LocAttr<"TestCustomLocation"> {
+  let mnemonic = "custom_location";
+  let parameters = (ins "mlir::StringAttr":$file, "unsigned":$line);
+
+  // Choose a silly separator token so we know it's hitting this code path
+  // and not another.
+  let assemblyFormat = "`<` $file `*` $line `>`";
+}
+
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 05cb36e1903163..547e536dd9cbbf 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
   InterfaceTest.cpp
   IRMapping.cpp
   InterfaceAttachmentTest.cpp
+  LocationTest.cpp
   OperationSupportTest.cpp
   PatternMatchTest.cpp
   ShapedTypeTest.cpp

diff  --git a/mlir/unittests/IR/LocationTest.cpp b/mlir/unittests/IR/LocationTest.cpp
new file mode 100644
index 00000000000000..03374ee0b8467b
--- /dev/null
+++ b/mlir/unittests/IR/LocationTest.cpp
@@ -0,0 +1,53 @@
+//===- LocationTest.cpp - unit tests for the Location API -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Builders.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+// Check that we only walk *locations* and not non-location attributes.
+TEST(LocationTest, Walk) {
+  MLIRContext ctx;
+  Builder builder(&ctx);
+  BoolAttr trueAttr = builder.getBoolAttr(true);
+
+  Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
+  Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
+  Location fused = builder.getFusedLoc({loc1, loc2}, trueAttr);
+
+  SmallVector<Attribute> visited;
+  fused->walk([&](Location l) {
+    visited.push_back(LocationAttr(l));
+    return WalkResult::advance();
+  });
+
+  EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
+}
+
+// Check that we skip location attrs nested under a non-location attr.
+TEST(LocationTest, SkipNested) {
+  MLIRContext ctx;
+  Builder builder(&ctx);
+
+  Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
+  Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
+  Location loc3 = FileLineColLoc::get(builder.getStringAttr("bar"), 1, 2);
+  Location loc4 = FileLineColLoc::get(builder.getStringAttr("bar"), 3, 4);
+  ArrayAttr arr = builder.getArrayAttr({loc3, loc4});
+  Location fused = builder.getFusedLoc({loc1, loc2}, arr);
+
+  SmallVector<Attribute> visited;
+  fused->walk([&](Location l) {
+    visited.push_back(LocationAttr(l));
+    return WalkResult::advance();
+  });
+
+  EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
+}


        


More information about the Mlir-commits mailing list