[Mlir-commits] [mlir] [MLIR][DLTI] Introduce DLTIQueryInterface and impl for DLTI attrs (PR #104595)
Rolf Morel
llvmlistbot at llvm.org
Fri Aug 16 07:43:47 PDT 2024
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/104595
>From ff58929868c75a43beccd1313ed42b9e05366d17 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 16 Aug 2024 05:59:11 -0700
Subject: [PATCH] [MLIR][DLTI] Introduce DLTIQueryInterface and impl for DLTI
attrs
This new interface is supposed to capture the core functionality
of DLTI: querying for values at keys. As such this new interface
unifies the ability to query DLTI attributes in a single method:
query(). All existing DLTI interfaces exposing their own query methods
now 1) extend this new interface and 2) provide a default
implementation for `query()`.
As DLTIQueryInterface::query() returns an attribute, it naturally
enables recursive queries on nested DLTI attrs. A utility function,
`dlti::query()`, implements the logic for nested lookups.
A new `#dlti.map` attribute is introduced to capture the most generic
form of a finite DLTI-mapping. One of the benefits is that it allows
for more easily encoding hierarchical information that is suitably
queryable, i.e. by means of nested attributes.
In line with the above, `transform.dlti.query` is modified so as to
take an arbitrary number of keys and to perform a nested lookup
using the above utility function.
---
mlir/include/mlir/Dialect/DLTI/DLTI.h | 11 ++
mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td | 90 +++++++---
mlir/include/mlir/Dialect/DLTI/DLTIBase.td | 9 +-
.../DLTI/TransformOps/DLTITransformOps.td | 40 +++--
.../Utils/DiagnosedSilenceableFailure.h | 7 +
.../mlir/Interfaces/DataLayoutInterfaces.h | 1 +
.../mlir/Interfaces/DataLayoutInterfaces.td | 60 ++++++-
mlir/lib/Dialect/DLTI/DLTI.cpp | 77 +++++++-
.../DLTI/TransformOps/DLTITransformOps.cpp | 35 +---
mlir/test/Dialect/DLTI/query.mlir | 167 +++++++++++++-----
.../Interfaces/DataLayoutInterfacesTest.cpp | 18 +-
11 files changed, 383 insertions(+), 132 deletions(-)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index 6e2623e084be52..235961aa97aa93 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -33,6 +33,17 @@ DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
/// DataLayoutOpInterface, a method on ModuleOp, or an attribute implementing
/// the interface, on `op` and else on `op`'s ancestors in turn.
TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
+
+/// Retrieve the first `DLTIQueryInterface`-implementing attribute attached to
+/// `op` or, failing that, to the closest ancestor of `op` that has one. The op
+/// the attribute is attached to is returned alongside it; when neither `op`
+/// nor any ancestor carries such an attribute, both pair elements are null.
+std::pair<DLTIQueryInterface, Operation *> getClosestQueryable(Operation *op);
+
+/// Perform a DLTI-query at `op`, recursively querying each key of `keys` on
+/// `DLTIQueryInterface`-implementing attrs, the first of which is obtained from
+/// `op` via `getClosestQueryable`. With an empty `keys`, that first attr itself
+/// is returned. When provided, (nested) lookup failure notes are attached to
+/// `diag`.
+FailureOr<Attribute> query(Operation *op, ArrayRef<StringAttr> keys,
+                           InFlightDiagnostic *diag = nullptr);
} // namespace dlti
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
index 443e3128b4acb3..9c82e2cf111750 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_DLTI_DLTIATTRS_TD
include "mlir/Dialect/DLTI/DLTI.td"
+include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
class DLTIAttr<string name, list<Trait> traits = [],
@@ -20,13 +21,8 @@ class DLTIAttr<string name, list<Trait> traits = [],
// DataLayoutEntryAttr
//===----------------------------------------------------------------------===//
-def DataLayoutEntryTrait
- : NativeAttrTrait<"DataLayoutEntryInterface::Trait"> {
- let cppNamespace = "::mlir";
-}
-
def DLTI_DataLayoutEntryAttr :
- DLTIAttr<"DataLayoutEntry", [DataLayoutEntryTrait]> {
+ DLTIAttr<"DataLayoutEntry", [DataLayoutEntryInterface]> {
let summary = "An attribute to represent an entry of a data layout specification.";
let description = [{
A data layout entry attribute is a key-value pair where the key is a type or
@@ -53,13 +49,9 @@ def DLTI_DataLayoutEntryAttr :
//===----------------------------------------------------------------------===//
// DataLayoutSpecAttr
//===----------------------------------------------------------------------===//
-def DataLayoutSpecTrait
- : NativeAttrTrait<"DataLayoutSpecInterface::Trait"> {
- let cppNamespace = "::mlir";
-}
def DLTI_DataLayoutSpecAttr :
- DLTIAttr<"DataLayoutSpec", [DataLayoutSpecTrait]> {
+ DLTIAttr<"DataLayoutSpec", [DataLayoutSpecInterface]> {
let summary = "An attribute to represent a data layout specification.";
let description = [{
A data layout specification is a list of entries that specify (partial) data
@@ -78,7 +70,7 @@ def DLTI_DataLayoutSpecAttr :
/// same key as the newer entries if the entries are compatible. Returns null
/// if the specifications are not compatible.
DataLayoutSpecAttr combineWith(ArrayRef<DataLayoutSpecInterface> specs) const;
-
+
/// Returns the endiannes identifier.
StringAttr getEndiannessIdentifier(MLIRContext *context) const;
@@ -93,6 +85,54 @@ def DLTI_DataLayoutSpecAttr :
/// Returns the stack alignment identifier.
StringAttr getStackAlignmentIdentifier(MLIRContext *context) const;
+
+    /// Returns the value of the entry whose key equals `key`, or failure if
+    /// this spec has no such entry. `const`, like the other DLTI attrs' query.
+    FailureOr<Attribute> query(DataLayoutEntryKey key) const {
+      return llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
+    }
+ }];
+}
+
+def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
+  let summary = "A mapping of DLTI-information by way of key-value pairs";
+  let description = [{
+    A data layout and target information map is a list of entries that
+    effectively forms a dictionary mapping DLTI-related keys to DLTI-related
+    values.
+
+    Its main purpose is to facilitate querying IR for arbitrary DLTI-related
+    key-value associations. Note that utility functions (e.g. `dlti::query`)
+    exist to perform nested lookups on nested DLTI map attributes.
+
+    Consider the following shallow usage of a DLTI-map
+    ```
+    #dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
+    ```
+    versus nested maps, which make it possible to obtain sub-dictionaries of
+    related information (with the following example making use of other
+    attributes that also implement the `DLTIQueryInterface`):
+    ```
+    #dlti.target_system_spec<"CPU":
+      #dlti.target_device_spec<#dlti.dl_entry<"cache",
+        #dlti.map<#dlti.dl_entry<"L1",
+                    #dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
+                  #dlti.dl_entry<"L1d",
+                    #dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
+    ```
+  }];
+  let parameters = (ins
+    ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
+  );
+  let mnemonic = "map";
+  let genVerifyDecl = 1;
+  let assemblyFormat = "`<` $entries `>`";
+  let extraClassDeclaration = [{
+    /// Returns the value of the first entry whose key equals `key`, or
+    /// failure if the map has no such entry.
+    FailureOr<Attribute> query(DataLayoutEntryKey key) const {
+      for (DataLayoutEntryInterface entry : getEntries())
+        if (entry.getKey() == key)
+          return entry.getValue();
+      return ::mlir::failure();
+    }
  }];
}
@@ -100,13 +140,8 @@ def DLTI_DataLayoutSpecAttr :
// TargetSystemSpecAttr
//===----------------------------------------------------------------------===//
-def TargetSystemSpecTrait
- : NativeAttrTrait<"TargetSystemSpecInterface::Trait"> {
- let cppNamespace = "::mlir";
-}
-
def DLTI_TargetSystemSpecAttr :
- DLTIAttr<"TargetSystemSpec", [TargetSystemSpecTrait]> {
+ DLTIAttr<"TargetSystemSpec", [TargetSystemSpecInterface]> {
let summary = "An attribute to represent target system specification.";
let description = [{
A system specification describes the overall system containing
@@ -136,6 +171,11 @@ def DLTI_TargetSystemSpecAttr :
std::optional<TargetDeviceSpecInterface>
getDeviceSpecForDeviceID(
TargetSystemSpecInterface::DeviceID deviceID);
+
+    /// Returns the attribute associated with the key: for a string key, the
+    /// target device spec registered under that device ID (as resolved by
+    /// `TargetSystemSpecInterface::queryHelper`); otherwise failure.
+    FailureOr<Attribute> query(DataLayoutEntryKey key) const {
+      return llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
+    }
}];
let extraClassDefinition = [{
std::optional<TargetDeviceSpecInterface>
@@ -154,13 +194,8 @@ def DLTI_TargetSystemSpecAttr :
// TargetDeviceSpecAttr
//===----------------------------------------------------------------------===//
-def TargetDeviceSpecTrait
- : NativeAttrTrait<"TargetDeviceSpecInterface::Trait"> {
- let cppNamespace = "::mlir";
-}
-
def DLTI_TargetDeviceSpecAttr :
- DLTIAttr<"TargetDeviceSpec", [TargetDeviceSpecTrait]> {
+ DLTIAttr<"TargetDeviceSpec", [TargetDeviceSpecInterface]> {
let summary = "An attribute to represent target device specification.";
let description = [{
Each device specification describes a single device and its
@@ -179,6 +214,13 @@ def DLTI_TargetDeviceSpecAttr :
let mnemonic = "target_device_spec";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
+
+  let extraClassDeclaration = [{
+    /// Returns the value of the entry whose identifier is `key` (as resolved
+    /// by `TargetDeviceSpecInterface::queryHelper`); non-string keys and
+    /// unknown identifiers yield failure.
+    FailureOr<Attribute> query(DataLayoutEntryKey key) const {
+      return llvm::cast<mlir::TargetDeviceSpecInterface>(*this).queryHelper(key);
+    }
+  }];
}
#endif // MLIR_DIALECT_DLTI_DLTIATTRS_TD
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
index e26fbdb146645c..f84149c43e0fcd 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
@@ -23,14 +23,19 @@ def DLTI_Dialect : Dialect {
}];
let extraClassDeclaration = [{
- // Top level attribute name.
+ // Top-level attribute name for arbitrary description.
+ constexpr const static ::llvm::StringLiteral
+ kMapAttrName = "dlti.map";
+
+ // Top-level attribute name for data layout description.
constexpr const static ::llvm::StringLiteral
kDataLayoutAttrName = "dlti.dl_spec";
- // Top level attribute name for target system description
+ // Top-level attribute name for target system description.
constexpr const static ::llvm::StringLiteral
kTargetSystemDescAttrName = "dlti.target_system_spec";
+ // Top-level attribute name for target device description.
constexpr const static ::llvm::StringLiteral
kTargetDeviceDescAttrName = "dlti.target_device_spec";
diff --git a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
index 69aacac986ad73..1b1bebfaab4e38 100644
--- a/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
+++ b/mlir/include/mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td
@@ -22,32 +22,40 @@ def QueryOp : Op<Transform_Dialect, "dlti.query", [
let summary = "Return attribute (as param) associated to key via DTLI";
let description = [{
This op queries data layout and target information associated to payload
- IR by way of the DLTI dialect. A lookup is performed for the given `key`
- at the `target` op, with the DLTI dialect determining which interfaces and
- attributes are consulted - first checking `target` and then its ancestors.
+ IR by way of the DLTI dialect.
- When only `key` is provided, the lookup occurs with respect to the data
- layout specification of DLTI. When `device` is provided, the lookup occurs
- with respect to DLTI's target device specifications associated to a DLTI
- system device specification.
+  A lookup is performed for the given `keys` at `target` op - or, failing
+  that, at its closest ancestor carrying a `DLTIQueryInterface`-implementing
+  attribute - by way of the `DLTIQueryInterface`, which returns an attribute
+  for a key. If more than one key is provided, the lookup continues
+  recursively, now on the returned attributes, with the condition that these
+  implement the above interface. For example, if the payload IR is
+
+  ```
+  module attributes {dlti.map = #dlti.map<#dlti.dl_entry<"A",
+                                  #dlti.map<#dlti.dl_entry<"B", 42 : i32>>>>} {
+    func.func private @f()
+  }
+  ```
+  and `%func` is a Transform handle to op `@f`, then
+  `transform.dlti.query ["A", "B"] at %func` returns `42 : i32` as a param and
+  `transform.dlti.query ["A"] at %func` returns the `#dlti.map` attribute
+  containing just the key "B" and its value. Using `["B"]` or `["A","C"]` as
+  `keys` will yield an error.
#### Return modes
- When succesful, the result, `associated_attr`, associates one attribute as a
- param for each op in `target`'s payload.
+ When successful, the result, `associated_attr`, associates one attribute as
+ a param for each op in `target`'s payload.
- If the lookup fails - as DLTI specifications or entries with the right
- names are missing (i.e. the values of `device` and `key`) - a definite
- failure is returned.
+ If the lookup fails - as no DLTI attributes/interfaces are found or entries
+ with the right names are missing - a silenceable failure is returned.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- OptionalAttr<StrAttr>:$device,
- StrAttr:$key);
+ StrArrayAttr:$keys);
let results = (outs TransformParamTypeInterface:$associated_attr);
let assemblyFormat =
- "(`:``:` $device^ `:``:`)? $key `at` $target attr-dict `:`"
- "functional-type(operands, results)";
+ "$keys `at` $target attr-dict `:` functional-type(operands, results)";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h b/mlir/include/mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h
index fcf422a0b6aa34..917c3826b24465 100644
--- a/mlir/include/mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h
+++ b/mlir/include/mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h
@@ -66,6 +66,13 @@ class [[nodiscard]] DiagnosedSilenceableFailure {
return DiagnosedSilenceableFailure(
std::forward<SmallVector<Diagnostic>>(diag));
}
+  /// Constructs a silenceable failure that takes over the message (and notes)
+  /// of `diag`, which is abandoned so the message is not reported twice.
+  /// `diag` must still be active, i.e. neither reported nor abandoned.
+  static DiagnosedSilenceableFailure
+  silenceableFailure(InFlightDiagnostic &&diag) {
+    Diagnostic *underlying = diag.getUnderlyingDiagnostic();
+    assert(underlying && "expected an active in-flight diagnostic");
+    auto consumingFailure = DiagnosedSilenceableFailure(std::move(*underlying));
+    diag.abandon(); // consumingFailure takes responsibility for diag's message.
+    return consumingFailure;
+  }
/// Converts all kinds of failure into a LogicalResult failure, emitting the
/// diagnostic if necessary. Must not be called more than once.
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index ab65f92820a6a8..848d2dee4a6309 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -23,6 +23,7 @@
namespace mlir {
class DataLayout;
class DataLayoutEntryInterface;
+class DLTIQueryInterface;
class TargetDeviceSpecInterface;
class TargetSystemSpecInterface;
using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index bc5080c9c6a558..d6e955be4291a3 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -20,6 +20,29 @@ include "mlir/IR/OpBase.td"
// Attribute interfaces
//===----------------------------------------------------------------------===//
+def DLTIQueryInterface : AttrInterface<"DLTIQueryInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    Attribute interface exposing a querying mechanism for key-value
+    associations.
+
+    The central feature of DLTI attributes is to allow looking up values at
+    keys. This interface represents the core functionality to do so - as such,
+    most DLTI attributes should implement this interface.
+
+    Note that as the `query` method returns an attribute, this attribute can
+    be recursively queried when it also implements this interface.
+  }];
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the attribute associated with the key, or failure if there is none.",
+      /*retTy=*/"::mlir::FailureOr<::mlir::Attribute>",
+      /*methodName=*/"query",
+      /*args=*/(ins "::mlir::DataLayoutEntryKey":$key)
+    >
+  ];
+}
+
def DataLayoutEntryInterface : AttrInterface<"DataLayoutEntryInterface"> {
let cppNamespace = "::mlir";
@@ -68,7 +91,7 @@ def DataLayoutEntryInterface : AttrInterface<"DataLayoutEntryInterface"> {
}];
}
-def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
+def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface", [DLTIQueryInterface]> {
let cppNamespace = "::mlir";
let description = [{
@@ -173,7 +196,7 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
/*defaultImplementation=*/[{
return ::mlir::detail::verifyDataLayoutSpec($_attr, loc);
}]
- >,
+ >
];
let extraClassDeclaration = [{
@@ -184,6 +207,15 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
return getSpecForType(TypeID::get<Ty>());
}
+    /// Shared lookup backing implementations of `DLTIQueryInterface::query`:
+    /// yields the value of the first entry whose key matches `key`, and
+    /// failure when no entry of this spec matches.
+    ::mlir::FailureOr<::mlir::Attribute>
+    queryHelper(::mlir::DataLayoutEntryKey key) const {
+      for (::mlir::DataLayoutEntryInterface candidate : getEntries()) {
+        if (candidate.getKey() != key)
+          continue;
+        return candidate.getValue();
+      }
+      return ::mlir::failure();
+    }
+
/// Populates the given maps with lists of entries grouped by the type or
/// identifier they are associated with. Users are not expected to call this
/// method directly.
@@ -194,7 +226,7 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
}];
}
-def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
+def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface", [DLTIQueryInterface]> {
let cppNamespace = "::mlir";
let description = [{
@@ -239,9 +271,20 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
/*defaultImplementation=*/[{ return ::mlir::success(); }]
>
];
+
+  let extraClassDeclaration = [{
+    /// Helper for default implementation of `DLTIQueryInterface`'s `query`.
+    /// Only string keys are looked up (via `getSpecForIdentifier`); type keys
+    /// and unknown identifiers yield failure.
+    ::mlir::FailureOr<::mlir::Attribute>
+    queryHelper(::mlir::DataLayoutEntryKey key) const {
+      if (auto strKey = ::llvm::dyn_cast<::mlir::StringAttr>(key))
+        if (::mlir::DataLayoutEntryInterface spec = getSpecForIdentifier(strKey))
+          return spec.getValue();
+      return ::mlir::failure();
+    }
+  }];
}
-def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
+def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTIQueryInterface]> {
let cppNamespace = "::mlir";
let description = [{
@@ -287,6 +330,15 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
let extraClassDeclaration = [{
using DeviceID = StringAttr;
+
+    /// Backs implementations of `DLTIQueryInterface`'s `query`: a string `key`
+    /// is treated as a device ID and resolved to that device's spec; any other
+    /// key, or an ID without a device spec, yields failure.
+    ::mlir::FailureOr<::mlir::Attribute>
+    queryHelper(::mlir::DataLayoutEntryKey key) const {
+      auto deviceId = llvm::dyn_cast<::mlir::StringAttr>(key);
+      if (!deviceId)
+        return ::mlir::failure();
+      auto deviceSpec = getDeviceSpecForDeviceID(deviceId);
+      if (!deviceSpec)
+        return ::mlir::failure();
+      return *deviceSpec;
+    }
}];
}
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index c995f44c380e57..852c220c189a1d 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -109,12 +109,11 @@ void DataLayoutEntryAttr::print(AsmPrinter &os) const {
}
//===----------------------------------------------------------------------===//
-// DataLayoutSpecAttr
+// DLTIMapAttr
//===----------------------------------------------------------------------===//
-LogicalResult
-DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<DataLayoutEntryInterface> entries) {
+static LogicalResult verifyEntries(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<DataLayoutEntryInterface> entries) {
DenseSet<Type> types;
DenseSet<StringAttr> ids;
for (DataLayoutEntryInterface entry : entries) {
@@ -130,6 +129,21 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+/// Verifies the entries of a `#dlti.map` with the entry checks shared with
+/// `DataLayoutSpecAttr` (see `verifyEntries`).
+LogicalResult MapAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<DataLayoutEntryInterface> entries) {
+  return verifyEntries(emitError, entries);
+}
+
+//===----------------------------------------------------------------------===//
+// DataLayoutSpecAttr
+//===----------------------------------------------------------------------===//
+
+/// Verifies the entries of a data layout spec with the entry checks shared
+/// with `MapAttr` (see `verifyEntries`).
+LogicalResult
+DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                           ArrayRef<DataLayoutEntryInterface> entries) {
+  return verifyEntries(emitError, entries);
+}
+
/// Given a list of old and a list of new entries, overwrites old entries with
/// new ones if they have matching keys, appends new entries to the old entry
/// list otherwise.
@@ -393,6 +407,49 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// DLTIDialect
//===----------------------------------------------------------------------===//
+std::pair<DLTIQueryInterface, Operation *>
+dlti::getClosestQueryable(Operation *op) {
+  // Walk outwards from `op`, returning the first attached attribute that
+  // implements DLTIQueryInterface together with the op carrying it. A null
+  // `op` is tolerated and simply yields the "not found" result.
+  for (Operation *current = op; current; current = current->getParentOp())
+    for (NamedAttribute attr : current->getAttrs())
+      if (auto queryable = llvm::dyn_cast<DLTIQueryInterface>(attr.getValue()))
+        return {queryable, current};
+
+  // Neither `op` nor any of its ancestors carries a queryable attribute.
+  return {DLTIQueryInterface(), nullptr};
+}
+
+FailureOr<Attribute> dlti::query(Operation *op, ArrayRef<StringAttr> keys,
+                                 InFlightDiagnostic *diag) {
+  // Deliberately not using structured bindings: C++17 lambdas (`fail` below)
+  // cannot capture them.
+  std::pair<DLTIQueryInterface, Operation *> closest =
+      dlti::getClosestQueryable(op);
+  DLTIQueryInterface queryable = closest.first;
+  Operation *queryOp = closest.second;
+
+  // Reports a lookup failure. When a diagnostic was provided, attaches a note
+  // pointing at the target op, plus a note composed by streaming `parts` in
+  // order, located at the op carrying the queried attribute (or at the target
+  // op when no queryable attribute was found).
+  auto fail = [&](const auto &...parts) -> FailureOr<Attribute> {
+    if (diag) {
+      diag->attachNote(op->getLoc()) << "target op for DLTI query";
+      Diagnostic &note = diag->attachNote((queryOp ? queryOp : op)->getLoc());
+      (note << ... << parts);
+    }
+    return failure();
+  };
+
+  if (!queryable)
+    return fail("no DLTI-queryable attrs on target op or any of its ancestors");
+
+  // Look up each key in turn on the attribute found for the previous key;
+  // every attribute being queried must itself implement DLTIQueryInterface.
+  Attribute currentAttr = queryable;
+  for (auto &&[idx, key] : llvm::enumerate(keys)) {
+    auto map = llvm::dyn_cast<DLTIQueryInterface>(currentAttr);
+    if (!map)
+      return fail("got non-DLTI-queryable attribute upon looking up keys [",
+                  keys.take_front(idx), "] at op");
+
+    FailureOr<Attribute> maybeAttr = map.query(key);
+    if (failed(maybeAttr))
+      return fail("key ", key, " has no DLTI-mapping per attr: ", map);
+    currentAttr = *maybeAttr;
+  }
+
+  return currentAttr;
+}
+
DataLayoutSpecInterface dlti::getDataLayoutSpec(Operation *op) {
DataLayoutSpecInterface dlSpec = nullptr;
@@ -480,7 +537,9 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
if (isa<ModuleOp>(op))
return detail::verifyDataLayoutOp(op);
return success();
- } else if (attr.getName() == DLTIDialect::kTargetSystemDescAttrName) {
+ }
+
+ if (attr.getName() == DLTIDialect::kTargetSystemDescAttrName) {
if (!llvm::isa<TargetSystemSpecAttr>(attr.getValue())) {
return op->emitError()
<< "'" << DLTIDialect::kTargetSystemDescAttrName
@@ -489,6 +548,14 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
return success();
}
+ if (attr.getName() == DLTIDialect::kMapAttrName) {
+ if (!llvm::isa<MapAttr>(attr.getValue())) {
+ return op->emitError() << "'" << DLTIDialect::kMapAttrName
+ << "' is expected to be a #dlti.map attribute";
+ }
+ return success();
+ }
+
return op->emitError() << "attribute '" << attr.getName().getValue()
<< "' not supported by dialect";
}
diff --git a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
index 032228b2f05ca2..ef22a79d287ac8 100644
--- a/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
+++ b/mlir/lib/Dialect/DLTI/TransformOps/DLTITransformOps.cpp
@@ -33,36 +33,17 @@ void transform::QueryOp::getEffects(
DiagnosedSilenceableFailure transform::QueryOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results, TransformState &state) {
- StringAttr deviceId = getDeviceAttr();
- StringAttr key = getKeyAttr();
-
- DataLayoutEntryInterface entry;
- if (deviceId) {
- TargetSystemSpecInterface sysSpec = dlti::getTargetSystemSpec(target);
- if (!sysSpec)
- return mlir::emitDefiniteFailure(target->getLoc())
- << "no target system spec associated to: " << target;
-
- if (auto targetSpec = sysSpec.getDeviceSpecForDeviceID(deviceId))
- entry = targetSpec->getSpecForIdentifier(key);
- else
- return mlir::emitDefiniteFailure(target->getLoc())
- << "no " << deviceId << " target device spec found";
- } else {
- DataLayoutSpecInterface dlSpec = dlti::getDataLayoutSpec(target);
- if (!dlSpec)
- return mlir::emitDefiniteFailure(target->getLoc())
- << "no data layout spec associated to: " << target;
-
- entry = dlSpec.getSpecForIdentifier(key);
- }
-  if (!entry)
-    return mlir::emitDefiniteFailure(target->getLoc())
-           << "no DLTI entry for key: " << key;
-  results.push_back(entry.getValue());
+  auto keys = SmallVector<StringAttr>(getKeys().getAsRange<StringAttr>());
+
+  // Look up without a diagnostic first: building the error up front (as
+  // `emitOpError` does, possibly including a printout of this op) would be
+  // wasted work whenever the query succeeds.
+  FailureOr<Attribute> result = dlti::query(target, keys);
+  if (failed(result)) {
+    // The lookup only reads attributes, so re-running it is safe; this time
+    // explanatory notes get attached to `inflight`.
+    InFlightDiagnostic inflight = emitOpError("failed to apply");
+    (void)dlti::query(target, keys, &inflight);
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(inflight));
+  }
+  results.push_back(*result);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/DLTI/query.mlir b/mlir/test/Dialect/DLTI/query.mlir
index 2504958de09289..d11ababe922ff9 100644
--- a/mlir/test/Dialect/DLTI/query.mlir
+++ b/mlir/test/Dialect/DLTI/query.mlir
@@ -1,15 +1,32 @@
// RUN: mlir-opt -transform-interpreter -canonicalize -split-input-file -verify-diagnostics %s | FileCheck %s
+// A #dlti.map attached to the module under a non-DLTI attribute name
+// (`test.dlti`) is found and queried when querying at the module itself.
+// expected-remark @below {{associated attr 42 : i32}}
+module attributes { test.dlti = #dlti.map<#dlti.dl_entry<"test.id", 42 : i32>>} {
+  func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+    %funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+    %module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
+    %param = transform.dlti.query ["test.id"] at %module : (!transform.any_op) -> !transform.any_param
+    transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
// expected-remark @below {{associated attr 42 : i32}}
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
- func.func private @f()
+ func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query "test.id" at %module : (!transform.any_op) -> !transform.any_param
+ %param = transform.dlti.query ["test.id"] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
transform.yield
}
@@ -25,7 +42,7 @@ module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query "test.id" at %funcs : (!transform.any_op) -> !transform.any_param
+ %param = transform.dlti.query ["test.id"] at %funcs : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %funcs : !transform.any_param, !transform.any_op
transform.yield
}
@@ -42,7 +59,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %funcs : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query "test.id" at %module : (!transform.any_op) -> !transform.any_param
+ %param = transform.dlti.query ["test.id"] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
transform.yield
}
@@ -65,7 +82,7 @@ module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query "test.id" at %matmul : (!transform.any_op) -> !transform.any_param
+ %param = transform.dlti.query ["test.id"] at %matmul : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %matmul : !transform.any_param, !transform.any_op
transform.yield
}
@@ -88,7 +105,7 @@ module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query "test.id" at %matmul : (!transform.any_op) -> !transform.any_param
+ %param = transform.dlti.query ["test.id"] at %matmul : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %matmul : !transform.any_param, !transform.any_op
transform.yield
}
@@ -97,7 +114,9 @@ module attributes {transform.with_named_sequence} {
// -----
// expected-remark @below {{associated attr 42 : i32}}
-module attributes { test.dlti = #dlti.target_system_spec<"CPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>>} {
+module attributes { test.dlti =
+ #dlti.target_system_spec<"CPU":
+ #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>>} {
func.func private @f()
}
@@ -105,7 +124,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query ::"CPU"::"test.id" at %module : (!transform.any_op) -> !transform.any_param
+ %param = transform.dlti.query ["CPU","test.id"] at %module : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %module : !transform.any_param, !transform.any_op
transform.yield
}
@@ -116,13 +135,13 @@ module attributes {transform.with_named_sequence} {
module attributes { test.dlti = #dlti.target_system_spec<"CPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>,
"GPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 43 : i32>>>} {
// expected-remark @below {{associated attr 43 : i32}}
- func.func private @f()
+ func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query ::"GPU"::"test.id" at %func : (!transform.any_op) -> !transform.any_param
+ %param = transform.dlti.query ["GPU","test.id"] at %func : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %func : !transform.any_param, !transform.any_op
transform.yield
}
@@ -139,7 +158,7 @@ module attributes { test.dlti = #dlti.target_system_spec<"CPU": #dlti.target_dev
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query ::"CPU"::"test.id" at %func : (!transform.any_op) -> !transform.any_param
+ %param = transform.dlti.query ["CPU","test.id"] at %func : (!transform.any_op) -> !transform.any_param
transform.debug.emit_param_as_remark %param, "associated attr" at %func : !transform.any_param, !transform.any_op
transform.yield
}
@@ -147,8 +166,35 @@ module attributes {transform.with_named_sequence} {
// -----
-module attributes { test.dlti = #dlti.target_system_spec<"CPU": #dlti.target_device_spec<#dlti.dl_entry<"cache::L1::size_in_bytes", 65536 : i32>,
- #dlti.dl_entry<"cache::L1d::size_in_bytes", 32768 : i32>>> } {
+module attributes { test.dlti = #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"cache::L1::size_in_bytes", 65536 : i32>,
+ #dlti.dl_entry<"cache::L1d::size_in_bytes", 32768 : i32>>> } {
+ // expected-remark @below {{L1::size_in_bytes 65536 : i32}}
+ // expected-remark @below {{L1d::size_in_bytes 32768 : i32}}
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ %l1_size = transform.dlti.query ["CPU","cache::L1::size_in_bytes"] at %func : (!transform.any_op) -> !transform.param<i32>
+ %l1d_size = transform.dlti.query ["CPU","cache::L1d::size_in_bytes"] at %func : (!transform.any_op) -> !transform.param<i32>
+ transform.debug.emit_param_as_remark %l1_size, "L1::size_in_bytes" at %func : !transform.param<i32>, !transform.any_op
+ transform.debug.emit_param_as_remark %l1d_size, "L1d::size_in_bytes" at %func : !transform.param<i32>, !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#l1_size = #dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>
+#l1d_size = #dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>
+module attributes { test.dlti =
+ #dlti.target_system_spec<"CPU":
+ #dlti.target_device_spec<#dlti.dl_entry<"cache",
+ #dlti.map<#dlti.dl_entry<"L1", #l1_size>,
+ #dlti.dl_entry<"L1d", #l1d_size> >>>> } {
// expected-remark @below {{L1::size_in_bytes 65536 : i32}}
// expected-remark @below {{L1d::size_in_bytes 32768 : i32}}
func.func private @f()
@@ -157,8 +203,8 @@ module attributes { test.dlti = #dlti.target_system_spec<"CPU": #dlti.target_dev
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %l1_size = transform.dlti.query ::"CPU"::"cache::L1::size_in_bytes" at %func : (!transform.any_op) -> !transform.param<i32>
- %l1d_size = transform.dlti.query ::"CPU"::"cache::L1d::size_in_bytes" at %func : (!transform.any_op) -> !transform.param<i32>
+ %l1_size = transform.dlti.query ["CPU","cache","L1","size_in_bytes"] at %func : (!transform.any_op) -> !transform.param<i32>
+ %l1d_size = transform.dlti.query ["CPU","cache","L1d","size_in_bytes"] at %func : (!transform.any_op) -> !transform.param<i32>
transform.debug.emit_param_as_remark %l1_size, "L1::size_in_bytes" at %func : !transform.param<i32>, !transform.any_op
transform.debug.emit_param_as_remark %l1d_size, "L1d::size_in_bytes" at %func : !transform.param<i32>, !transform.any_op
transform.yield
@@ -167,15 +213,16 @@ module attributes {transform.with_named_sequence} {
// -----
-module attributes { test.dlti = #dlti.target_system_spec<"CPU":
- #dlti.target_device_spec<#dlti.dl_entry<"inner_most_tile_size", 42 : i32>>>} {
+module attributes { test.dlti = #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"inner_most_tile_size", 42 : i32>>>} {
// CHECK-LABEL: func @matmul_tensors
func.func @matmul_tensors(
%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
-> tensor<?x?xf32> {
// CHECK: scf.for {{.*}} to {{.*}} step {{.*}}42
// CHECK: tensor.extract_slice
- // CHECK: linalg.matmul
+ // CHECK: linalg.matmul
// CHECK: tensor.insert_slice
// CHECK: scf.yield
%0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
@@ -190,7 +237,7 @@ module attributes { test.dlti = #dlti.target_system_spec<"CPU":
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg : (!transform.any_op) -> !transform.any_op
- %tile_size = transform.dlti.query ::"CPU"::"inner_most_tile_size" at %matmul : (!transform.any_op) -> !transform.param<i32>
+ %tile_size = transform.dlti.query ["CPU","inner_most_tile_size"] at %matmul : (!transform.any_op) -> !transform.param<i32>
transform.structured.tile_using_for %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i32>) -> (!transform.any_op, !transform.any_op)
transform.yield
}
@@ -198,67 +245,91 @@ module attributes {transform.with_named_sequence} {
// -----
-module attributes { test.dlti = #dlti.target_system_spec<"CPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>,
- "GPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 43 : i32>>>} {
- // expected-error @below {{no "NPU" target device spec found}}
- func.func private @f()
+// expected-note @below {{key "NPU" has no DLTI-mapping per attr: #dlti.target_system_spec}}
+module attributes { test.dlti = #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>,
+ "GPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 43 : i32>>>} {
+ // expected-note @below {{target op for DLTI query}}
+ func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query ::"NPU"::"test.id" at %func : (!transform.any_op) -> !transform.any_param
- transform.debug.emit_param_as_remark %param, "associated attr" at %func : !transform.any_param, !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query ["NPU","test.id"] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}
// -----
-module attributes { test.dlti = #dlti.target_system_spec<"CPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>,
- "GPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 43 : i32>>>} {
- // expected-error @below {{no DLTI entry for key: "unspecified"}}
- func.func private @f()
+// expected-note @below {{key "unspecified" has no DLTI-mapping per attr: #dlti.target_device_spec}}
+module attributes { test.dlti = #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>,
+ "GPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 43 : i32>>>} {
+ // expected-note @below {{target op for DLTI query}}
+ func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query ::"CPU"::"unspecified" at %func : (!transform.any_op) -> !transform.any_param
- transform.debug.emit_param_as_remark %param, "associated attr" at %func : !transform.any_param, !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query ["CPU","unspecified"] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}
// -----
-module attributes { test.dlti = #dlti.target_system_spec<"CPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>,
- "GPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 43 : i32>>>} {
- // expected-error @below {{no data layout spec associated to: }}
- func.func private @f()
+// expected-note @below {{key "test.id" has no DLTI-mapping per attr: #dlti.target_system_spec}}
+module attributes { test.dlti = #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 42 : i32>>,
+ "GPU": #dlti.target_device_spec<#dlti.dl_entry<"test.id", 43 : i32>>>} {
+ // expected-note @below {{target op for DLTI query}}
+ func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query "test.id" at %func : (!transform.any_op) -> !transform.any_param
- transform.debug.emit_param_as_remark %param, "associated attr" at %func : !transform.any_param, !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query ["test.id"] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}
// -----
+// expected-note @below {{key "CPU" has no DLTI-mapping per attr: #dlti.dl_spec}}
module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"test.id", 42 : i32>>} {
- // expected-error @below {{no target system spec associated to: }}
- func.func private @f()
+ // expected-note @below {{target op for DLTI query}}
+ func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query ::"CPU"::"test.id" at %func : (!transform.any_op) -> !transform.any_param
- transform.debug.emit_param_as_remark %param, "associated attr" at %func : !transform.any_param, !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query ["CPU","test.id"] at %func : (!transform.any_op) -> !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
+// expected-note @below {{got non-DLTI-queryable attribute upon looking up keys ["CPU"]}}
+module attributes { test.dlti = #dlti.dl_spec<#dlti.dl_entry<"CPU", 42 : i32>>} {
+ // expected-note @below {{target op for DLTI query}}
+ func.func private @f()
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg: !transform.any_op) {
+ %func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query ["CPU","test.id"] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}
@@ -266,15 +337,16 @@ module attributes {transform.with_named_sequence} {
// -----
module {
- // expected-error @below {{no target system spec associated to: }}
- func.func private @f()
+ // expected-note @below {{target op for DLTI query}}
+ // expected-note @below {{no DLTI-queryable attrs on target op or any of its ancestors}}
+ func.func private @f()
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%func = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
- %param = transform.dlti.query ::"CPU"::"test.id" at %func : (!transform.any_op) -> !transform.any_param
- transform.debug.emit_param_as_remark %param, "associated attr" at %func : !transform.any_param, !transform.any_op
+ // expected-error @below {{'transform.dlti.query' op failed to apply}}
+ %param = transform.dlti.query ["CPU","test.id"] at %func : (!transform.any_op) -> !transform.any_param
transform.yield
}
}
@@ -289,8 +361,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg: !transform.any_op) {
%funcs = transform.structured.match ops{["func.func"]} in %arg : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected the type of the parameter attribute ('i32') to match the parameter type ('i64')}}
- %param = transform.dlti.query "test.id" at %funcs : (!transform.any_op) -> !transform.param<i64>
- transform.debug.emit_param_as_remark %param, "associated attr" at %funcs : !transform.param<i64>, !transform.any_op
+ %param = transform.dlti.query ["test.id"] at %funcs : (!transform.any_op) -> !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index d1227b045d4ed3..b667785c16f162 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -58,9 +58,9 @@ class DataLayoutSpecStorage : public AttributeStorage {
/// Simple data layout spec containing a list of entries that always verifies
/// as valid.
struct CustomDataLayoutSpec
- : public Attribute::AttrBase<CustomDataLayoutSpec, Attribute,
- DataLayoutSpecStorage,
- DataLayoutSpecInterface::Trait> {
+ : public Attribute::AttrBase<
+ CustomDataLayoutSpec, Attribute, DataLayoutSpecStorage,
+ DLTIQueryInterface::Trait, DataLayoutSpecInterface::Trait> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
using Base::Base;
@@ -92,6 +92,9 @@ struct CustomDataLayoutSpec
StringAttr getStackAlignmentIdentifier(MLIRContext *context) const {
return Builder(context).getStringAttr(kStackAlignmentKeyName);
}
+ FailureOr<Attribute> query(DataLayoutEntryKey key) const {
+ return llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
+ }
};
class TargetSystemSpecStorage : public AttributeStorage {
@@ -113,9 +116,9 @@ class TargetSystemSpecStorage : public AttributeStorage {
};
struct CustomTargetSystemSpec
- : public Attribute::AttrBase<CustomTargetSystemSpec, Attribute,
- TargetSystemSpecStorage,
- TargetSystemSpecInterface::Trait> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
+ : public Attribute::AttrBase<
+ CustomTargetSystemSpec, Attribute, TargetSystemSpecStorage,
+ DLTIQueryInterface::Trait, TargetSystemSpecInterface::Trait> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomTargetSystemSpec)
using Base::Base;
@@ -138,6 +141,9 @@ struct CustomTargetSystemSpec
}
return std::nullopt;
}
+ FailureOr<Attribute> query(DataLayoutEntryKey key) const {
+ return llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
+ }
};
/// A type subject to data layout that exits the program if it is queried more
More information about the Mlir-commits
mailing list