[Mlir-commits] [mlir] Reimplementing target description concept using DLTI attribute (PR #92138)
Niranjan Hasabnis
llvmlistbot at llvm.org
Mon Jun 17 10:35:43 PDT 2024
https://github.com/nhasabni updated https://github.com/llvm/llvm-project/pull/92138
>From e8e087ce2727181bdbfa6f2d2f27f58e5c81abc6 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Tue, 14 May 2024 07:53:49 -0700
Subject: [PATCH 1/7] Reimplementing target description concept using DLTI
attribute
---
mlir/include/mlir/Dialect/DLTI/DLTI.h | 146 +++++++
mlir/include/mlir/Dialect/DLTI/DLTIBase.td | 44 +++
mlir/include/mlir/Dialect/DLTI/Traits.h | 7 +
mlir/include/mlir/IR/BuiltinOps.td | 1 +
.../mlir/Interfaces/DataLayoutInterfaces.h | 65 ++++
.../mlir/Interfaces/DataLayoutInterfaces.td | 228 +++++++++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 +-
mlir/lib/Dialect/DLTI/DLTI.cpp | 367 +++++++++++++++++-
mlir/lib/Dialect/DLTI/Traits.cpp | 6 +
.../Linalg/Transforms/BlockPackMatmul.cpp | 22 +-
mlir/lib/IR/BuiltinDialect.cpp | 11 +
mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 185 ++++++++-
mlir/lib/Transforms/Canonicalizer.cpp | 51 +++
.../Interfaces/DataLayoutInterfacesTest.cpp | 10 +
14 files changed, 1148 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index 5ac7c11e6ffee..f78e8bdc5eb98 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -21,6 +21,8 @@ namespace mlir {
namespace impl {
class DataLayoutEntryStorage;
class DataLayoutSpecStorage;
+class TargetSystemDescSpecAttrStorage;
+class TargetDeviceDescSpecAttrStorage;
} // namespace impl
//===----------------------------------------------------------------------===//
@@ -124,6 +126,150 @@ class DataLayoutSpecAttr
static constexpr StringLiteral name = "builtin.data_layout_spec";
};
+//===----------------------------------------------------------------------===//
+// TargetSystemDescSpecAttr
+//===----------------------------------------------------------------------===//
+
+/// A target system description attribute is a list of target device
+/// descriptors, each of which must have a unique device ID.
+class TargetSystemDescSpecAttr
+    : public Attribute::AttrBase<TargetSystemDescSpecAttr, Attribute,
+                                 impl::TargetSystemDescSpecAttrStorage,
+                                 TargetSystemDescSpecInterface::Trait> {
+public:
+  using Base::Base;
+
+  /// The keyword used for this attribute in custom syntax.
+  constexpr const static StringLiteral kAttrKeyword = "tsd_spec";
+
+  /// Returns a system descriptor attribute for the given device descriptors.
+  static TargetSystemDescSpecAttr
+  get(MLIRContext *context, ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+  /// Returns the list of device descriptors.
+  TargetDeviceDescSpecListRef getEntries() const;
+
+  /// Returns the device descriptor whose device ID equals `deviceID`.
+  TargetDeviceDescSpecInterface getDeviceDescForDeviceID(uint32_t deviceID);
+
+  /// Returns the system descriptor containing the given device descriptors.
+  /// If the list is invalid (e.g. device IDs are not unique), reports errors
+  /// using the given callback and returns null.
+  static TargetSystemDescSpecAttr
+  getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+             ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+  /// Verifies the list of device descriptors (device IDs must be unique).
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+  /// Parses an instance of this attribute.
+  static TargetSystemDescSpecAttr parse(AsmParser &parser);
+
+  /// Prints this attribute.
+  void print(AsmPrinter &os) const;
+
+  static constexpr StringLiteral name = "builtin.target_system_description";
+};
+
+//===----------------------------------------------------------------------===//
+// TargetDeviceDescSpecAttr
+//===----------------------------------------------------------------------===//
+/// A target device description attribute: key/value entries for one device.
+class TargetDeviceDescSpecAttr
+    : public Attribute::AttrBase<TargetDeviceDescSpecAttr, Attribute,
+                                 impl::TargetDeviceDescSpecAttrStorage,
+                                 TargetDeviceDescSpecInterface::Trait> {
+public:
+  using Base::Base;
+
+  /// The keyword used for this attribute in custom syntax.
+  constexpr const static StringLiteral kAttrKeyword = "tdd_spec";
+
+  /// Returns a device descriptor attribute built from the given entries.
+  static TargetDeviceDescSpecAttr
+  get(MLIRContext *context, ArrayRef<DataLayoutEntryInterface> entries);
+
+  /// Returns the device descriptor containing the given list of entries. If
+  /// the list contains duplicate keys or is otherwise invalid, reports errors
+  /// using the given callback and returns null.
+  static TargetDeviceDescSpecAttr
+  getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+             ArrayRef<DataLayoutEntryInterface> entries);
+
+  /// Checks that the given list of entries does not contain duplicate keys.
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<DataLayoutEntryInterface> entries);
+
+  /// Returns the list of entries.
+  DataLayoutEntryListRef getEntries() const;
+
+  /// Parses an instance of this attribute.
+  static TargetDeviceDescSpecAttr parse(AsmParser &parser);
+
+  /// Prints this attribute.
+  void print(AsmPrinter &os) const;
+
+  /// Returns the device ID identifier.
+  StringAttr getDeviceIDIdentifier(MLIRContext *context);
+
+  /// Returns the device type identifier.
+  StringAttr getDeviceTypeIdentifier(MLIRContext *context);
+
+  /// Returns the max vector op width identifier.
+  StringAttr getMaxVectorOpWidthIdentifier(MLIRContext *context);
+
+  /// Returns the canonicalizer max iterations identifier.
+  StringAttr getCanonicalizerMaxIterationsIdentifier(MLIRContext *context);
+
+  /// Returns the canonicalizer max num rewrites identifier.
+  StringAttr getCanonicalizerMaxNumRewritesIdentifier(MLIRContext *context);
+
+  /// Returns the L1 cache size identifier.
+  StringAttr getL1CacheSizeInBytesIdentifier(MLIRContext *context);
+
+  /// Returns the interface spec for the device ID. The verifier requires
+  /// every device description to contain a device ID, so the returned spec
+  /// is always valid.
+  DataLayoutEntryInterface getSpecForDeviceID(MLIRContext *context);
+
+  /// Returns the interface spec for the device type. The verifier requires
+  /// every device description to contain a device type, so the returned
+  /// spec is always valid.
+  DataLayoutEntryInterface getSpecForDeviceType(MLIRContext *context);
+
+  /// Returns the interface spec for the max vector op width.
+  /// Since max vector op width is an optional property, this function
+  /// returns a valid spec if the property is defined, and an empty spec
+  /// otherwise.
+  DataLayoutEntryInterface getSpecForMaxVectorOpWidth(MLIRContext *context);
+
+  /// Returns the interface spec for the L1 cache size.
+  /// Since L1 cache size is an optional property, this function returns a
+  /// valid spec if the property is defined, and an empty spec
+  /// otherwise.
+  DataLayoutEntryInterface getSpecForL1CacheSizeInBytes(MLIRContext *context);
+
+  /// Returns the interface spec for canonicalizer max iterations.
+  /// Since this is an optional property, this function returns a valid
+  /// spec if the property is defined, and an empty spec
+  /// otherwise.
+  DataLayoutEntryInterface
+  getSpecForCanonicalizerMaxIterations(MLIRContext *context);
+
+  /// Returns the interface spec for canonicalizer max num rewrites.
+  /// Since this is an optional property, this function returns a valid
+  /// spec if the property is defined, and an empty spec
+  /// otherwise.
+  DataLayoutEntryInterface
+  getSpecForCanonicalizerMaxNumRewrites(MLIRContext *context);
+
+  /// Returns the value of the (mandatory) device ID.
+  uint32_t getDeviceID(MLIRContext *context);
+
+  static constexpr StringLiteral name = "builtin.target_device_description";
+};
+
} // namespace mlir
#include "mlir/Dialect/DLTI/DLTIDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
index 3572a99fad874..c9a054b3c1e51 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
@@ -27,6 +27,13 @@ def DLTI_Dialect : Dialect {
constexpr const static ::llvm::StringLiteral
kDataLayoutAttrName = "dlti.dl_spec";
+  // Top-level attribute name for a target system description spec.
+  constexpr const static ::llvm::StringLiteral
+    kTargetSystemDescAttrName = "dlti.tsd_spec";
+  // Attribute name for a target device description spec.
+  constexpr const static ::llvm::StringLiteral
+    kTargetDeviceDescAttrName = "dlti.tdd_spec";
+
// Constants used in entries.
constexpr const static ::llvm::StringLiteral
kDataLayoutEndiannessKey = "dlti.endianness";
@@ -48,6 +55,25 @@ def DLTI_Dialect : Dialect {
constexpr const static ::llvm::StringLiteral
kDataLayoutStackAlignmentKey = "dlti.stack_alignment";
+
+  // Target device description keys. Device ID: mandatory, ui32 value.
+  constexpr const static ::llvm::StringLiteral
+    kTargetDeviceIDKey = "dlti.device_id";
+  // Device type: mandatory, its value must be a string.
+  constexpr const static ::llvm::StringLiteral
+    kTargetDeviceTypeKey = "dlti.device_type";
+  // Optional: max vector op width (presumably in bits -- confirm).
+  constexpr const static ::llvm::StringLiteral
+    kTargetDeviceMaxVectorOpWidthKey = "dlti.max_vector_op_width";
+  // Optional: canonicalizer max iterations.
+  constexpr const static ::llvm::StringLiteral
+    kTargetDeviceCanonicalizerMaxIterationsKey = "dlti.canonicalizer_max_iterations";
+  // Optional: canonicalizer max num rewrites.
+  constexpr const static ::llvm::StringLiteral
+    kTargetDeviceCanonicalizerMaxNumRewritesKey = "dlti.canonicalizer_max_num_rewrites";
+  // Optional: L1 cache size in bytes.
+  constexpr const static ::llvm::StringLiteral
+    kTargetDeviceL1CacheSizeInBytesKey = "dlti.L1_cache_size_in_bytes";
}];
let useDefaultAttributePrinterParser = 1;
@@ -71,6 +97,24 @@ def DLTI_DataLayoutSpecAttr : DialectAttr<
let convertFromStorage = "$_self";
}
+def DLTI_TargetSystemDescSpecAttr : DialectAttr<
+  DLTI_Dialect,
+  CPred<"::llvm::isa<::mlir::TargetSystemDescSpecAttr>($_self)">,
+  "Target system description part of DLTI"> {
+  let storageType = "::mlir::TargetSystemDescSpecAttr";
+  let returnType = "::mlir::TargetSystemDescSpecAttr";
+  let convertFromStorage = "$_self";
+}
+// Like DLTI_TargetSystemDescSpecAttr above, but for TargetDeviceDescSpecAttr.
+def DLTI_TargetDeviceDescSpecAttr : DialectAttr<
+  DLTI_Dialect,
+  CPred<"::llvm::isa<::mlir::TargetDeviceDescSpecAttr>($_self)">,
+  "Target device description part of DLTI"> {
+  let storageType = "::mlir::TargetDeviceDescSpecAttr";
+  let returnType = "::mlir::TargetDeviceDescSpecAttr";
+  let convertFromStorage = "$_self";
+}
+
def HasDefaultDLTIDataLayout : NativeOpTrait<"HasDefaultDLTIDataLayout"> {
let cppNamespace = "::mlir";
}
diff --git a/mlir/include/mlir/Dialect/DLTI/Traits.h b/mlir/include/mlir/Dialect/DLTI/Traits.h
index 5d86195305a95..44083d54c4cad 100644
--- a/mlir/include/mlir/Dialect/DLTI/Traits.h
+++ b/mlir/include/mlir/Dialect/DLTI/Traits.h
@@ -18,6 +18,7 @@ class DataLayoutSpecAttr;
namespace impl {
LogicalResult verifyHasDefaultDLTIDataLayoutTrait(Operation *op);
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
+TargetSystemDescSpecInterface getTargetSystemDescSpec(Operation *op);
} // namespace impl
/// Trait to be used by operations willing to use the implementation of the
@@ -37,6 +38,12 @@ class HasDefaultDLTIDataLayout
DataLayoutSpecInterface getDataLayoutSpec() {
return impl::getDataLayoutSpec(this->getOperation());
}
+
+  /// Returns the target system description specification as provided by the
+  /// DLTI dialect (delegates to impl::getTargetSystemDescSpec).
+  TargetSystemDescSpecInterface getTargetSystemDescSpec() {
+    return impl::getTargetSystemDescSpec(this->getOperation());
+  }
};
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index eda24615c71ea..bdb4ce3ddfe20 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -78,6 +78,7 @@ def ModuleOp : Builtin_Op<"module", [
//===------------------------------------------------------------------===//
DataLayoutSpecInterface getDataLayoutSpec();
+ TargetSystemDescSpecInterface getTargetSystemDescSpec();
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 76bf33e89a716..1584a13247dff 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -23,11 +23,17 @@
namespace mlir {
class DataLayout;
class DataLayoutEntryInterface;
+class TargetDeviceDescSpecInterface;
+class TargetSystemDescSpecInterface;
using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
// Using explicit SmallVector size because we cannot infer the size from the
// forward declaration, and we need the typedef in the actual declaration.
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
+// Non-owning view of a list of target device description specs; the
+// counterpart of DataLayoutEntryListRef for target system descriptions.
+using TargetDeviceDescSpecListRef =
+    llvm::ArrayRef<TargetDeviceDescSpecInterface>;
class DataLayoutOpInterface;
class DataLayoutSpecInterface;
class ModuleOp;
@@ -84,6 +90,24 @@ Attribute getDefaultGlobalMemorySpace(DataLayoutEntryInterface entry);
/// DataLayoutInterface if specified, otherwise returns the default.
uint64_t getDefaultStackAlignment(DataLayoutEntryInterface entry);
+/// Returns the max vector op width from the specified DataLayoutEntry. If
+/// the property is missing from the entry, returns std::nullopt.
+std::optional<uint32_t> getMaxVectorOpWidth(DataLayoutEntryInterface entry);
+
+/// Returns the L1 cache size in bytes from the specified DataLayoutEntry. If
+/// the property is missing from the entry, returns std::nullopt.
+std::optional<uint32_t> getL1CacheSizeInBytes(DataLayoutEntryInterface entry);
+
+/// Returns the canonicalizer max iterations from the specified
+/// DataLayoutEntry. If the property is missing, returns std::nullopt.
+std::optional<int64_t>
+getCanonicalizerMaxIterations(DataLayoutEntryInterface entry);
+
+/// Returns the canonicalizer max num rewrites from the specified
+/// DataLayoutEntry. If the property is missing, returns std::nullopt.
+std::optional<int64_t>
+getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry);
+
/// Given a list of data layout entries, returns a new list containing the
/// entries with keys having the given type ID, i.e. belonging to the same type
/// class.
@@ -95,6 +119,11 @@ DataLayoutEntryList filterEntriesForType(DataLayoutEntryListRef entries,
DataLayoutEntryInterface
filterEntryForIdentifier(DataLayoutEntryListRef entries, StringAttr id);
+/// Given a list of target device description specs, returns the spec that
+/// has the given identifier as key, if such a spec exists in the list.
+TargetDeviceDescSpecInterface
+filterEntryForIdentifier(TargetDeviceDescSpecListRef entries, StringAttr id);
+
/// Verifies that the operation implementing the data layout interface, or a
/// module operation, is valid. This calls the verifier of the spec attribute
/// and checks if the layout is compatible with specs attached to the enclosing
@@ -106,6 +135,12 @@ LogicalResult verifyDataLayoutOp(Operation *op);
/// and dialect interfaces for type and identifier keys respectively.
LogicalResult verifyDataLayoutSpec(DataLayoutSpecInterface spec, Location loc);
+/// Verifies that a target system description spec is valid. This dispatches
+/// to individual entry verifiers, and then to the verifiers implemented by
+/// the relevant dialect interfaces for identifier keys.
+LogicalResult verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
+                                         Location loc);
+
/// Divides the known min value of the numerator by the denominator and rounds
/// the result up to the next integer. Preserves the scalable flag.
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator);
@@ -137,6 +172,13 @@ class DataLayoutDialectInterface
return success();
}
+  /// Verifies a target device description spec, reporting errors at `loc`.
+  /// The default accepts every spec; derived classes should override this.
+  virtual LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
+                                    Location loc) const {
+    return success();
+  }
+
/// Default implementation of entry combination that combines identical
/// entries and returns null otherwise.
static DataLayoutEntryInterface
@@ -214,10 +256,33 @@ class DataLayout {
/// unspecified.
uint64_t getStackAlignment() const;
+  /// Returns the max vector op width if the property is defined for the given
+  /// device ID, otherwise returns std::nullopt.
+  std::optional<uint32_t>
+  getMaxVectorOpWidth(TargetDeviceDescSpecInterface::DeviceID) const;
+
+  /// Returns the L1 cache size if the property is defined for the given
+  /// device ID, otherwise returns std::nullopt.
+  std::optional<uint32_t>
+  getL1CacheSizeInBytes(TargetDeviceDescSpecInterface::DeviceID) const;
+
+  /// Returns the canonicalizer max iterations if the property is defined for
+  /// the given device ID, otherwise returns std::nullopt.
+  std::optional<int64_t> getCanonicalizerMaxIterations(
+      TargetDeviceDescSpecInterface::DeviceID) const;
+
+  /// Returns the canonicalizer max num rewrites if the property is defined
+  /// for the given device ID, otherwise returns std::nullopt.
+  std::optional<int64_t> getCanonicalizerMaxNumRewrites(
+      TargetDeviceDescSpecInterface::DeviceID) const;
+
private:
/// Combined layout spec at the given scope.
const DataLayoutSpecInterface originalLayout;
+  /// Combined target system description spec at the given scope.
+  const TargetSystemDescSpecInterface originalTargetSystemDesc;
+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// List of enclosing layout specs.
SmallVector<DataLayoutSpecInterface, 2> layoutStack;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 9edc885b9c5a9..75e609dde8fcf 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -194,6 +194,182 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
}];
}
+def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    Attribute interface describing a target device description specification.
+
+    A target device description specification is a list of device properties (keys)
+    and their values for a specific device. The device is identified by its "device_id"
+    key (with a ui32 value) and its "device_type" key (with a string value).
+    Both "device_id" and "device_type" are mandatory keys. As an example, L1 cache
+    size could be a device property, and its value would be a device-specific size.
+
+    A target device description specification is attached to a module as a module level
+    attribute.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the list of device property entries.",
+      /*retTy=*/"::mlir::DataLayoutEntryListRef",
+      /*methodName=*/"getEntries",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to the given identifier, if "
+                      "present.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForIdentifier",
+      /*args=*/(ins "::mlir::StringAttr":$identifier),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::filterEntryForIdentifier($_attr.getEntries(),
+        identifier);
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/"Checks that the entry is well-formed, reports errors "
+                      "at the provided location.",
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"verifyEntry",
+      /*args=*/(ins "::mlir::Location":$loc),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return ::mlir::success(); }]
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the device ID identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getDeviceIDIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the device type identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getDeviceTypeIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the max vector op width identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getMaxVectorOpWidthIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns canonicalizer max iterations identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getCanonicalizerMaxIterationsIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns canonicalizer max num rewrites identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getCanonicalizerMaxNumRewritesIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to device ID. The function "
+                      "will crash if the entry is missing.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForDeviceID",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to device type. "
+                      "The function will crash if the entry is missing.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForDeviceType",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to max vector op width, if "
+                      "present. Otherwise, returns an empty spec.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForMaxVectorOpWidth",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to L1 cache size, if "
+                      "present. Otherwise, returns an empty spec.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForL1CacheSizeInBytes",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to canonicalizer max "
+                      "iterations, if present. Otherwise, returns an empty spec.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForCanonicalizerMaxIterations",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to canonicalizer max num "
+                      "rewrites, if present. Otherwise, returns an empty spec.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForCanonicalizerMaxNumRewrites",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the value of the mandatory device ID "
+                      "entry.",
+      /*retTy=*/"uint32_t",
+      /*methodName=*/"getDeviceID",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+  ];
+
+  let extraClassDeclaration = [{
+    using DeviceID = uint32_t;
+  }];
+}
+
+def TargetSystemDescSpecInterface : AttrInterface<"TargetSystemDescSpecInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    Attribute interface describing a target system description specification.
+
+    A target system description specification is a list of target device
+    specifications, with one device specification per device in the system. As
+    such, a target system description specification allows specifying a heterogeneous
+    system, with devices of different types (e.g., CPU, GPU, etc.).
+
+    The only requirement on a valid target system description specification is that
+    the "device_id" in every target device description specification needs to be
+    unique. This is because, ultimately, this "device_id" is used by the user to
+    query a value of a device property.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the list of target device description specs.",
+      /*retTy=*/"::mlir::TargetDeviceDescSpecListRef",
+      /*methodName=*/"getEntries",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the device description spec for the given "
+                      "device ID.",
+      /*retTy=*/"::mlir::TargetDeviceDescSpecInterface",
+      /*methodName=*/"getDeviceDescForDeviceID",
+      /*args=*/(ins "uint32_t":$deviceID)
+    >,
+    InterfaceMethod<
+      /*description=*/"Verifies the validity of the specification and "
+                      "reports "
+                      "any errors at the given location.",
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"verifySpec",
+      /*args=*/(ins "::mlir::Location":$loc),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::verifyTargetSystemDescSpec($_attr, loc);
+      }]
+    >
+  ];
+}
+
//===----------------------------------------------------------------------===//
// Operation interface
//===----------------------------------------------------------------------===//
@@ -227,6 +403,14 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
/*methodName=*/"getDataLayoutSpec",
/*args=*/(ins)
>,
+    InterfaceMethod<
+      /*description=*/"Returns the target system description specification "
+                      "for this op, or null if it does not "
+                      "exist.",
+      /*retTy=*/"::mlir::TargetSystemDescSpecInterface",
+      /*methodName=*/"getTargetSystemDescSpec",
+      /*args=*/(ins)
+    >,
StaticInterfaceMethod<
/*description=*/"Returns the size of the given type computed using the "
"relevant entries. The data layout object can be used "
@@ -362,6 +546,50 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
return ::mlir::detail::getDefaultStackAlignment(entry);
}]
>,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the max vector op width held by `entry`, if "
+                       "defined. Otherwise, it returns std::nullopt.",
+      /*retTy=*/"std::optional<uint32_t>",
+      /*methodName=*/"getMaxVectorOpWidth",
+      /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getMaxVectorOpWidth(entry);
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the L1 cache size in bytes held by `entry`, "
+                       "if defined. Otherwise, it returns std::nullopt.",
+      /*retTy=*/"std::optional<uint32_t>",
+      /*methodName=*/"getL1CacheSizeInBytes",
+      /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getL1CacheSizeInBytes(entry);
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the canonicalizer max iterations held by "
+                       "`entry`, if defined. Otherwise, it returns std::nullopt.",
+      /*retTy=*/"std::optional<int64_t>",
+      /*methodName=*/"getCanonicalizerMaxIterations",
+      /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getCanonicalizerMaxIterations(entry);
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the canonicalizer max num rewrites held by "
+                       "`entry`, if defined. Otherwise, it returns std::nullopt.",
+      /*retTy=*/"std::optional<int64_t>",
+      /*methodName=*/"getCanonicalizerMaxNumRewrites",
+      /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getCanonicalizerMaxNumRewrites(entry);
+      }]
+    >,
];
let verify = [{ return ::mlir::detail::verifyDataLayoutOp($_op); }];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 033e66c6118f3..8ff2b2104b998 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -29,6 +30,8 @@ namespace mlir {
using namespace mlir;
using namespace mlir::amdgpu;
+#define DEBUG_TYPE "amd-gpu-to-rocdl"
+
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type llvmI32 = rewriter.getI32Type();
@@ -49,7 +52,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
: ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
Chipset chipset;
- static constexpr uint32_t maxVectorOpWidth = 128;
LogicalResult
matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
@@ -111,6 +113,15 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
uint32_t elemBits = dataVector.getElementTypeBitWidth();
uint32_t totalBits = elemBits * dataVector.getNumElements();
+ uint32_t maxVectorOpWidth = 128; // default value
+ if (std::optional<uint32_t> v =
+ DataLayout(gpuOp->template getParentOfType<mlir::ModuleOp>())
+ .getMaxVectorOpWidth(1 /* gpu ID*/)) {
+ maxVectorOpWidth = *v;
+ }
+ LLVM_DEBUG(llvm::dbgs() << "[CostModel] GPU MaxVectorWidth:"
+ << maxVectorOpWidth << "\n");
+
if (totalBits > maxVectorOpWidth)
return gpuOp.emitOpError(
"Total width of loads or stores must be no more than " +
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 98a8865ef4da3..a8518469c7824 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -10,14 +10,21 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
using namespace mlir;
#include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"
+#define DEBUG_TYPE "dlti"
+
//===----------------------------------------------------------------------===//
// DataLayoutEntryAttr
//===----------------------------------------------------------------------===//
@@ -337,6 +344,320 @@ void DataLayoutSpecAttr::print(AsmPrinter &os) const {
os << ">";
}
+//===----------------------------------------------------------------------===//
+// TargetDeviceDescSpecAttr
+//===----------------------------------------------------------------------===//
+constexpr const StringLiteral mlir::TargetDeviceDescSpecAttr::kAttrKeyword;
+
+constexpr const StringLiteral mlir::DLTIDialect::kTargetDeviceIDKey;
+constexpr const StringLiteral mlir::DLTIDialect::kTargetDeviceTypeKey;
+constexpr const StringLiteral
+ mlir::DLTIDialect::kTargetDeviceMaxVectorOpWidthKey;
+constexpr const StringLiteral
+ mlir::DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey;
+constexpr const StringLiteral
+ mlir::DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey;
+constexpr const StringLiteral
+ mlir::DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey;
+
+namespace mlir {
+namespace impl {
+/// Uniqued storage backing TargetDeviceDescSpecAttr: an allocator-owned copy
+/// of the key/value entries describing one device.
+class TargetDeviceDescSpecAttrStorage : public AttributeStorage {
+public:
+  using KeyTy = ArrayRef<DataLayoutEntryInterface>;
+
+  TargetDeviceDescSpecAttrStorage(KeyTy entryList) : entries(entryList) {}
+
+  /// Two storages are identical iff their entry lists compare equal.
+  bool operator==(const KeyTy &other) const { return entries == other; }
+
+  /// Creates a storage whose entry list is copied into `allocator`.
+  static TargetDeviceDescSpecAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    void *memory = allocator.allocate<TargetDeviceDescSpecAttrStorage>();
+    KeyTy ownedEntries = allocator.copyInto(key);
+    return new (memory) TargetDeviceDescSpecAttrStorage(ownedEntries);
+  }
+
+  /// Entries of the device spec (storage owned by the attribute allocator).
+  ArrayRef<DataLayoutEntryInterface> entries;
+};
+} // namespace impl
+} // namespace mlir
+
+/// Returns the uniqued device spec holding `entries` (unchecked).
+TargetDeviceDescSpecAttr
+TargetDeviceDescSpecAttr::get(MLIRContext *context,
+                              ArrayRef<DataLayoutEntryInterface> entries) {
+  return Base::get(context, entries);
+}
+
+/// Returns the entries stored in this device spec.
+DataLayoutEntryListRef TargetDeviceDescSpecAttr::getEntries() const {
+  auto *storage = getImpl();
+  return storage->entries;
+}
+
+/// Returns the uniqued device spec holding `entries`, or null after reporting
+/// through `emitError` when the entries fail verification.
+TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+    ArrayRef<DataLayoutEntryInterface> entries) {
+  return Base::getChecked(emitError, context, entries);
+}
+
+/// Verifies the entry list of a tdd_spec: every key must be a string
+/// identifier (type keys are rejected) and unique, and the list must contain
+/// a device ID entry holding a ui32 IntegerAttr and a device type entry
+/// holding a StringAttr.
+LogicalResult
+TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<DataLayoutEntryInterface> entries) {
+  // Entries in tdd_spec can only have StringAttr as key. It does not support
+  // type as a key. Hence not reusing DataLayoutEntryInterface::verify.
+  bool targetDeviceIDKeyPresentAndValid = false;
+  bool targetDeviceTypeKeyPresentAndValid = false;
+
+  DenseSet<StringAttr> ids;
+  for (DataLayoutEntryInterface entry : entries) {
+    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
+      return emitError()
+             << "dlti.tdd_spec attribute does not allow type as a key: "
+             << type;
+
+    auto id = entry.getKey().get<StringAttr>();
+    if (!ids.insert(id).second)
+      return emitError() << "repeated layout entry key: " << id.getValue();
+
+    // Check that Device ID and Device Type are present and well-typed. Both
+    // value checks tolerate a null value (reported below as a missing key)
+    // instead of asserting inside a plain dyn_cast.
+    StringRef entryName = id.strref();
+    Attribute value = entry.getValue();
+    if (entryName == DLTIDialect::kTargetDeviceIDKey) {
+      auto intValue = llvm::dyn_cast_if_present<IntegerAttr>(value);
+      if (intValue && intValue.getType().isUnsignedInteger(32))
+        targetDeviceIDKeyPresentAndValid = true;
+    } else if (entryName == DLTIDialect::kTargetDeviceTypeKey) {
+      if (llvm::isa_and_present<StringAttr>(value))
+        targetDeviceTypeKeyPresentAndValid = true;
+    }
+  }
+
+  // Both DeviceID and DeviceType are mandatory and must be of correct type.
+  if (!targetDeviceIDKeyPresentAndValid)
+    return emitError() << "tdd_spec requires key: "
+                       << DLTIDialect::kTargetDeviceIDKey
+                       << " and its value of ui32 type";
+  if (!targetDeviceTypeKeyPresentAndValid)
+    return emitError() << "tdd_spec requires key: "
+                       << DLTIDialect::kTargetDeviceTypeKey
+                       << " and its value of string type";
+
+  return success();
+}
+
+/// Parses an attribute with syntax
+///   tdd_spec_attr ::= `#dlti.` `tdd_spec` `<` dl-entry-attr-list? `>`
+///   dl-entry-attr-list ::= dl-entry-attr
+///                        | dl-entry-attr `,` dl-entry-attr-list
+/// The result is always built through getChecked. An empty list (`<>`) is
+/// accepted syntactically but still verified. A device spec without the
+/// mandatory device ID and device type keys is invalid, and building it with
+/// the unchecked `get` would skip that verification.
+TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::parse(AsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  SmallVector<DataLayoutEntryInterface> entries;
+  if (failed(parser.parseOptionalGreater())) {
+    if (parser.parseCommaSeparatedList(
+            [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
+        parser.parseGreater())
+      return {};
+  }
+
+  return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
+                    parser.getContext(), entries);
+}
+
+/// Prints the attribute keyword followed by the comma-separated entries in
+/// angle brackets.
+void TargetDeviceDescSpecAttr::print(AsmPrinter &os) const {
+  os << TargetDeviceDescSpecAttr::kAttrKeyword << "<";
+  llvm::interleaveComma(getEntries(), os);
+  os << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// Support for specific keys
+//===----------------------------------------------------------------------===//
+
+/// Returns the identifier used as the key of the device ID entry.
+StringAttr
+TargetDeviceDescSpecAttr::getDeviceIDIdentifier(MLIRContext *context) {
+  return StringAttr::get(context, DLTIDialect::kTargetDeviceIDKey);
+}
+
+/// Returns the identifier used as the key of the device type entry.
+StringAttr
+TargetDeviceDescSpecAttr::getDeviceTypeIdentifier(MLIRContext *context) {
+  return StringAttr::get(context, DLTIDialect::kTargetDeviceTypeKey);
+}
+
+/// Returns the identifier used as the key of the max vector op width entry.
+StringAttr
+TargetDeviceDescSpecAttr::getMaxVectorOpWidthIdentifier(MLIRContext *context) {
+  return StringAttr::get(context,
+                         DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
+}
+
+/// Returns the identifier used as the key of the L1 cache size entry.
+StringAttr TargetDeviceDescSpecAttr::getL1CacheSizeInBytesIdentifier(
+    MLIRContext *context) {
+  return StringAttr::get(context,
+                         DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
+}
+
+/// Returns the identifier used as the key of the canonicalizer max
+/// iterations entry.
+StringAttr TargetDeviceDescSpecAttr::getCanonicalizerMaxIterationsIdentifier(
+    MLIRContext *context) {
+  return StringAttr::get(
+      context, DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey);
+}
+
+/// Returns the identifier used as the key of the canonicalizer max number of
+/// rewrites entry.
+StringAttr TargetDeviceDescSpecAttr::getCanonicalizerMaxNumRewritesIdentifier(
+    MLIRContext *context) {
+  return StringAttr::get(
+      context, DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey);
+}
+
+// Each getSpecFor* helper below returns the entry of this spec keyed by the
+// corresponding identifier (a null entry when the key is absent).
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForDeviceID(MLIRContext *context) {
+  StringAttr key = getDeviceIDIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForDeviceType(MLIRContext *context) {
+  StringAttr key = getDeviceTypeIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForMaxVectorOpWidth(MLIRContext *context) {
+  StringAttr key = getMaxVectorOpWidthIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForL1CacheSizeInBytes(MLIRContext *context) {
+  StringAttr key = getL1CacheSizeInBytesIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForCanonicalizerMaxIterations(
+    MLIRContext *context) {
+  StringAttr key = getCanonicalizerMaxIterationsIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForCanonicalizerMaxNumRewrites(
+    MLIRContext *context) {
+  StringAttr key = getCanonicalizerMaxNumRewritesIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+/// Returns the value of the mandatory device ID entry. The verifier requires
+/// this entry to be present with a ui32 IntegerAttr value. An unverified spec
+/// (e.g. one built with the unchecked `get`) may lack it, so assert rather
+/// than dereference a null entry.
+uint32_t TargetDeviceDescSpecAttr::getDeviceID(MLIRContext *context) {
+  DataLayoutEntryInterface entry = getSpecForDeviceID(context);
+  assert(entry && "tdd_spec is missing the mandatory device ID entry");
+  return llvm::cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
+}
+
+//===----------------------------------------------------------------------===//
+// TargetSystemDescSpecAttr
+//===----------------------------------------------------------------------===//
+
+constexpr const StringLiteral mlir::TargetSystemDescSpecAttr::kAttrKeyword;
+
+namespace mlir {
+namespace impl {
+class TargetSystemDescSpecAttrStorage : public AttributeStorage {
+public:
+ using KeyTy = ArrayRef<TargetDeviceDescSpecInterface>;
+
+ TargetSystemDescSpecAttrStorage(KeyTy entries) : entries(entries) {}
+
+ bool operator==(const KeyTy &key) const { return key == entries; }
+
+ static TargetSystemDescSpecAttrStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ return new (allocator.allocate<TargetSystemDescSpecAttrStorage>())
+ TargetSystemDescSpecAttrStorage(allocator.copyInto(key));
+ }
+
+ // This could be a map of DeviceID to DeviceDesc for faster lookup.
+ ArrayRef<TargetDeviceDescSpecInterface> entries;
+};
+} // namespace impl
+} // namespace mlir
+
+TargetSystemDescSpecAttr
+TargetSystemDescSpecAttr::get(MLIRContext *context,
+ ArrayRef<TargetDeviceDescSpecInterface> entries) {
+ return Base::get(context, entries);
+}
+
+TargetDeviceDescSpecListRef TargetSystemDescSpecAttr::getEntries() const {
+ return getImpl()->entries;
+}
+
+/// Returns the first device spec whose device ID equals `id`, or a null
+/// interface when no device spec of this system has that ID.
+TargetDeviceDescSpecInterface
+TargetSystemDescSpecAttr::getDeviceDescForDeviceID(
+    TargetDeviceDescSpecInterface::DeviceID id) {
+  MLIRContext *context = getContext();
+  TargetDeviceDescSpecListRef deviceSpecs = getEntries();
+  auto match =
+      llvm::find_if(deviceSpecs, [&](TargetDeviceDescSpecInterface spec) {
+        return spec.getDeviceID(context) == id;
+      });
+  if (match == deviceSpecs.end())
+    return TargetDeviceDescSpecInterface();
+  return *match;
+}
+
+/// Returns the uniqued system spec made of `entries`, or null after
+/// reporting through `emitError` when verification fails.
+TargetSystemDescSpecAttr TargetSystemDescSpecAttr::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+    ArrayRef<TargetDeviceDescSpecInterface> entries) {
+  return Base::getChecked(emitError, context, entries);
+}
+
+/// Verifies a tsd_spec: every contained device spec must itself pass
+/// TargetDeviceDescSpecAttr::verify, and no two device specs may share a
+/// device ID.
+LogicalResult TargetSystemDescSpecAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<TargetDeviceDescSpecInterface> entries) {
+  DenseSet<uint32_t> seenDeviceIDs;
+
+  for (TargetDeviceDescSpecInterface deviceSpec : entries) {
+    LogicalResult deviceSpecOk =
+        TargetDeviceDescSpecAttr::verify(emitError, deviceSpec.getEntries());
+    if (failed(deviceSpecOk))
+      return failure();
+
+    uint32_t deviceID = deviceSpec.getDeviceID(deviceSpec.getContext());
+    bool firstOccurrence = seenDeviceIDs.insert(deviceID).second;
+    if (!firstOccurrence)
+      return emitError() << "repeated Device ID in dlti.tsd_spec: "
+                         << deviceID;
+  }
+  return success();
+}
+
+/// Parses an attribute with syntax
+///   tsd_spec_attr ::= `#dlti.` `tsd_spec` `<` tdd-spec-attr-list? `>`
+///   tdd-spec-attr-list ::= tdd_spec_attr
+///                        | tdd_spec_attr `,` tdd-spec-attr-list
+TargetSystemDescSpecAttr TargetSystemDescSpecAttr::parse(AsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  MLIRContext *context = parser.getContext();
+  // `<>` denotes a system with no device specs.
+  if (succeeded(parser.parseOptionalGreater()))
+    return get(context, {});
+
+  SmallVector<TargetDeviceDescSpecInterface> entries;
+  auto parseOneDeviceSpec = [&]() {
+    return parser.parseAttribute(entries.emplace_back());
+  };
+  if (parser.parseCommaSeparatedList(parseOneDeviceSpec) ||
+      parser.parseGreater())
+    return {};
+
+  auto emitDiag = [&] { return parser.emitError(parser.getNameLoc()); };
+  return getChecked(emitDiag, context, entries);
+}
+
+/// Prints the attribute keyword followed by the comma-separated device specs
+/// in angle brackets.
+void TargetSystemDescSpecAttr::print(AsmPrinter &os) const {
+  os << TargetSystemDescSpecAttr::kAttrKeyword;
+  os << "<";
+  llvm::interleaveComma(getEntries(), os);
+  os << ">";
+}
+
//===----------------------------------------------------------------------===//
// DLTIDialect
//===----------------------------------------------------------------------===//
@@ -375,9 +696,36 @@ class TargetDataLayoutInterface : public DataLayoutDialectInterface {
};
} // namespace
+namespace {
+/// DLTI dialect hook that validates the keys used inside a target device
+/// description spec: each key must be a string identifier from the set of
+/// keys known to the DLTI dialect.
+class SystemDescSpecInterface : public DataLayoutDialectInterface {
+public:
+  using DataLayoutDialectInterface::DataLayoutDialectInterface;
+
+  LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
+                            Location loc) const final {
+    // Keys accepted in a device description spec. Keys unknown to us are
+    // currently rejected.
+    static constexpr StringLiteral kKnownKeys[] = {
+        DLTIDialect::kTargetDeviceIDKey,
+        DLTIDialect::kTargetDeviceTypeKey,
+        DLTIDialect::kTargetDeviceMaxVectorOpWidthKey,
+        DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey,
+        DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey,
+        DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey};
+
+    for (DataLayoutEntryInterface dlEntry : entry.getEntries()) {
+      // Device specs only support identifier keys. Report a type key instead
+      // of asserting inside PointerUnion::get<StringAttr>().
+      auto key = llvm::dyn_cast_if_present<StringAttr>(dlEntry.getKey());
+      if (!key)
+        return emitError(loc)
+               << "dlti.tdd_spec attribute does not allow type as a key";
+
+      StringRef entryName = key.strref();
+      if (!llvm::is_contained(kKnownKeys, entryName))
+        return emitError(loc) << "unknown target desc key name: " << entryName;
+    }
+    return success();
+  }
+};
+} // namespace
+
+// Registers the DLTI attributes (including the target system and target
+// device description specs) and the dialect interfaces.
void DLTIDialect::initialize() {
- addAttributes<DataLayoutEntryAttr, DataLayoutSpecAttr>();
- addInterfaces<TargetDataLayoutInterface>();
+ addAttributes<DataLayoutEntryAttr, DataLayoutSpecAttr,
+ TargetSystemDescSpecAttr, TargetDeviceDescSpecAttr>();
+ addInterfaces<TargetDataLayoutInterface, SystemDescSpecInterface>();
}
Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
@@ -390,6 +738,10 @@ Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
return DataLayoutEntryAttr::parse(parser);
if (attrKind == DataLayoutSpecAttr::kAttrKeyword)
return DataLayoutSpecAttr::parse(parser);
+ if (attrKind == TargetSystemDescSpecAttr::kAttrKeyword)
+ return TargetSystemDescSpecAttr::parse(parser);
+ if (attrKind == TargetDeviceDescSpecAttr::kAttrKeyword)
+ return TargetDeviceDescSpecAttr::parse(parser);
parser.emitError(parser.getNameLoc(), "unknown attrribute type: ")
<< attrKind;
@@ -398,8 +750,8 @@ Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
+// Prints a DLTI attribute by dispatching to the attribute's own print method;
+// any other attribute kind is unreachable here.
void DLTIDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
llvm::TypeSwitch<Attribute>(attr)
- .Case<DataLayoutEntryAttr, DataLayoutSpecAttr>(
- [&](auto a) { a.print(os); })
+ .Case<DataLayoutEntryAttr, DataLayoutSpecAttr, TargetSystemDescSpecAttr,
+ TargetDeviceDescSpecAttr>([&](auto a) { a.print(os); })
.Default([](Attribute) { llvm_unreachable("unknown attribute kind"); });
}
@@ -413,6 +765,13 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
if (isa<ModuleOp>(op))
return detail::verifyDataLayoutOp(op);
return success();
+ } else if (attr.getName() == DLTIDialect::kTargetSystemDescAttrName) {
+ if (!llvm::isa<TargetSystemDescSpecAttr>(attr.getValue())) {
+ return op->emitError()
+ << "'" << DLTIDialect::kTargetSystemDescAttrName
+ << "' is expected to be a #dlti.tsd_spec attribute";
+ }
+ return success();
}
return op->emitError() << "attribute '" << attr.getName().getValue()
diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp
index 85acbee46defd..ead656774a27c 100644
--- a/mlir/lib/Dialect/DLTI/Traits.cpp
+++ b/mlir/lib/Dialect/DLTI/Traits.cpp
@@ -27,3 +27,9 @@ DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) {
return op->getAttrOfType<DataLayoutSpecAttr>(
DLTIDialect::kDataLayoutAttrName);
}
+
+/// Returns the TargetSystemDescSpecAttr attached to `op` under
+/// DLTIDialect::kTargetSystemDescAttrName, or null if it is absent or is not
+/// a TargetSystemDescSpecAttr.
+TargetSystemDescSpecInterface
+mlir::impl::getTargetSystemDescSpec(Operation *op) {
+  auto systemDesc = op->getAttrOfType<TargetSystemDescSpecAttr>(
+      DLTIDialect::kTargetSystemDescAttrName);
+  return systemDesc;
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 91d4efa3372b7..c08224f7af54e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -18,6 +18,8 @@
#include <optional>
+// Debug tag for the LLVM_DEBUG output emitted by this file.
+#define DEBUG_TYPE "block-pack-matmul"
+
namespace mlir {
#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
#include "mlir/Dialect/Linalg/Passes.h.inc"
@@ -134,6 +136,24 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
return packTransposedMatmul;
}
+/// Returns default {M, N, K} block factors for packing `linalgOp`. They are
+/// derived from the L1 cache size of device 0 in the enclosing module's
+/// target system description, or 4096 bytes if that is not specified. The same
+/// block size is used for all dims: the largest power of two whose square
+/// does not exceed the L1 size, i.e. sqrt(L1 size) rounded down to a power
+/// of two. The result is always >= 1.
+static SmallVector<int64_t> getDefaultBlockFactors(linalg::LinalgOp linalgOp) {
+  // Get the L1 cache size first.
+  uint64_t l1CacheSize = 4096; // default value
+  uint32_t cpuID = 0;
+  ModuleOp moduleOp = linalgOp->getParentOfType<ModuleOp>();
+  if (std::optional<uint32_t> v =
+          DataLayout(moduleOp).getL1CacheSizeInBytes(cpuID))
+    l1CacheSize = *v;
+
+  // Exact integer computation. The previous floating-point
+  // pow(2, floor(log2(sqrt(size)))) produced 0 (an invalid tile factor) for a
+  // cache size of 0. l1CacheSize fits in 32 bits, so (2 * blockSize)^2 never
+  // overflows 64 bits.
+  uint64_t blockSize = 1;
+  while ((2 * blockSize) * (2 * blockSize) <= l1CacheSize)
+    blockSize *= 2;
+
+  // We use the same block size for all dims.
+  int64_t factor = static_cast<int64_t>(blockSize);
+  LLVM_DEBUG(llvm::dbgs() << "block_size:" << factor << "\n");
+  return {factor, factor, factor};
+}
+
/// Pack a matmul operation into blocked 4D layout.
FailureOr<PackResult>
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
@@ -146,7 +166,7 @@ linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
if (options->blockFactors.size() != 3)
- return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
+ options->blockFactors = getDefaultBlockFactors(linalgOp);
SmallVector<OpFoldResult> mnkTiles =
getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index dcb1119fe5207..1d57e0bdef187 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -155,6 +155,17 @@ DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
return {};
}
+/// Returns the first attribute of this module that implements
+/// TargetSystemDescSpecInterface, or null if there is none. This is a linear
+/// scan over the module's attributes; DataLayout construction calls it once
+/// and then reuses the result for repeated queries.
+TargetSystemDescSpecInterface ModuleOp::getTargetSystemDescSpec() {
+  for (const NamedAttribute &namedAttr : getOperation()->getAttrs()) {
+    auto spec =
+        llvm::dyn_cast<TargetSystemDescSpecInterface>(namedAttr.getValue());
+    if (spec)
+      return spec;
+  }
+  return {};
+}
+
LogicalResult ModuleOp::verify() {
// Check that none of the attributes are non-dialect attributes, except for
// the symbol related attributes.
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 15cfb3dbaf745..857ccf03ed8c0 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -293,6 +293,50 @@ mlir::detail::getDefaultStackAlignment(DataLayoutEntryInterface entry) {
return value.getValue().getZExtValue();
}
+/// Returns the IntegerAttr value held by `entry`. Returns null if the entry
+/// is missing (the spec does not set the property) or its value is not an
+/// integer attribute. These property values are not type-checked by the
+/// tdd_spec verifier, so an unchecked cast here could crash on user input.
+static IntegerAttr getTargetDescIntegerValue(DataLayoutEntryInterface entry) {
+  if (entry == DataLayoutEntryInterface())
+    return {};
+  return llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
+}
+
+// Returns the max vector op width if specified in the given entry. If the
+// entry is empty (meaning the spec is missing) or its value is not an
+// integer, returns std::nullopt.
+std::optional<uint32_t>
+mlir::detail::getMaxVectorOpWidth(DataLayoutEntryInterface entry) {
+  if (IntegerAttr value = getTargetDescIntegerValue(entry))
+    return value.getValue().getZExtValue();
+  return std::nullopt;
+}
+
+// Returns the L1 cache size if specified in the given entry. If the entry
+// is empty (meaning the spec is missing) or its value is not an integer,
+// returns std::nullopt.
+std::optional<uint32_t>
+mlir::detail::getL1CacheSizeInBytes(DataLayoutEntryInterface entry) {
+  if (IntegerAttr value = getTargetDescIntegerValue(entry))
+    return value.getValue().getZExtValue();
+  return std::nullopt;
+}
+
+// Returns the canonicalizer max iterations if specified in the given entry.
+// If the entry is empty (meaning the spec is missing) or its value is not an
+// integer, returns std::nullopt.
+std::optional<int64_t>
+mlir::detail::getCanonicalizerMaxIterations(DataLayoutEntryInterface entry) {
+  if (IntegerAttr value = getTargetDescIntegerValue(entry))
+    return value.getValue().getSExtValue();
+  return std::nullopt;
+}
+
+// Returns the canonicalizer max num rewrites if specified in the given entry.
+// If the entry is empty (meaning the spec is missing) or its value is not an
+// integer, returns std::nullopt.
+std::optional<int64_t>
+mlir::detail::getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry) {
+  if (IntegerAttr value = getTargetDescIntegerValue(entry))
+    return value.getValue().getSExtValue();
+  return std::nullopt;
+}
+
DataLayoutEntryList
mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
TypeID typeID) {
@@ -324,6 +368,20 @@ static DataLayoutSpecInterface getSpec(Operation *operation) {
});
}
+/// Returns the target system description spec of `operation` itself if it is
+/// a ModuleOp, or else of its closest enclosing ModuleOp. Returns a null
+/// interface if `operation` is null or is not nested in any module; in that
+/// case there is no module to query, and calling a method on the null
+/// ModuleOp would dereference a null pointer.
+static TargetSystemDescSpecInterface
+getTargetSystemDescSpec(Operation *operation) {
+  if (!operation)
+    return TargetSystemDescSpecInterface();
+
+  auto moduleOp = llvm::dyn_cast<ModuleOp>(operation);
+  if (!moduleOp)
+    moduleOp = operation->getParentOfType<ModuleOp>();
+  if (!moduleOp)
+    return TargetSystemDescSpecInterface();
+
+  return moduleOp.getTargetSystemDescSpec();
+}
+
/// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that
/// are either modules or implement the `DataLayoutOpInterface`.
static void
@@ -435,7 +493,8 @@ mlir::DataLayout::DataLayout() : DataLayout(ModuleOp()) {}
mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
: originalLayout(getCombinedDataLayout(op)), scope(op),
allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
- globalMemorySpace(std::nullopt), stackAlignment(std::nullopt) {
+ globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
+ originalTargetSystemDesc(getTargetSystemDescSpec(op)) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
checkMissingLayout(originalLayout, op);
collectParentLayouts(op, layoutStack);
@@ -445,7 +504,8 @@ mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
mlir::DataLayout::DataLayout(ModuleOp op)
: originalLayout(getCombinedDataLayout(op)), scope(op),
allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
- globalMemorySpace(std::nullopt), stackAlignment(std::nullopt) {
+ globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
+ originalTargetSystemDesc(getTargetSystemDescSpec(op)) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
checkMissingLayout(originalLayout, op);
collectParentLayouts(op, layoutStack);
@@ -640,6 +700,78 @@ uint64_t mlir::DataLayout::getStackAlignment() const {
return *stackAlignment;
}
+/// Returns the description of device `deviceID` in `systemDesc`. Returns a
+/// null interface when there is no system description, or when it does not
+/// describe that device: getDeviceDescForDeviceID returns null for unknown
+/// IDs, and callers must not invoke methods on that null result.
+static TargetDeviceDescSpecInterface
+lookupTargetDeviceDesc(TargetSystemDescSpecInterface systemDesc,
+                       TargetDeviceDescSpecInterface::DeviceID deviceID) {
+  if (!systemDesc)
+    return TargetDeviceDescSpecInterface();
+  return systemDesc.getDeviceDescForDeviceID(deviceID);
+}
+
+// The target-description queries below do not cache their results, because
+// they do not compute default values. If a property is missing (or the
+// requested device is not described), they return std::nullopt so that users
+// can fall back to whatever default they want.
+
+std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  DataLayoutEntryInterface entry;
+  if (TargetDeviceDescSpecInterface deviceDesc =
+          lookupTargetDeviceDesc(originalTargetSystemDesc, deviceID))
+    entry = deviceDesc.getSpecForMaxVectorOpWidth(
+        originalTargetSystemDesc.getContext());
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getMaxVectorOpWidth(entry);
+  return detail::getMaxVectorOpWidth(entry);
+}
+
+std::optional<uint32_t> mlir::DataLayout::getL1CacheSizeInBytes(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  DataLayoutEntryInterface entry;
+  if (TargetDeviceDescSpecInterface deviceDesc =
+          lookupTargetDeviceDesc(originalTargetSystemDesc, deviceID))
+    entry = deviceDesc.getSpecForL1CacheSizeInBytes(
+        originalTargetSystemDesc.getContext());
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getL1CacheSizeInBytes(entry);
+  return detail::getL1CacheSizeInBytes(entry);
+}
+
+std::optional<int64_t> mlir::DataLayout::getCanonicalizerMaxIterations(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  DataLayoutEntryInterface entry;
+  if (TargetDeviceDescSpecInterface deviceDesc =
+          lookupTargetDeviceDesc(originalTargetSystemDesc, deviceID))
+    entry = deviceDesc.getSpecForCanonicalizerMaxIterations(
+        originalTargetSystemDesc.getContext());
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getCanonicalizerMaxIterations(entry);
+  return detail::getCanonicalizerMaxIterations(entry);
+}
+
+std::optional<int64_t> mlir::DataLayout::getCanonicalizerMaxNumRewrites(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  DataLayoutEntryInterface entry;
+  if (TargetDeviceDescSpecInterface deviceDesc =
+          lookupTargetDeviceDesc(originalTargetSystemDesc, deviceID))
+    entry = deviceDesc.getSpecForCanonicalizerMaxNumRewrites(
+        originalTargetSystemDesc.getContext());
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getCanonicalizerMaxNumRewrites(entry);
+  return detail::getCanonicalizerMaxNumRewrites(entry);
+}
+
//===----------------------------------------------------------------------===//
// DataLayoutSpecInterface
//===----------------------------------------------------------------------===//
@@ -744,6 +876,55 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
return success();
}
+/// Verifies a target system description spec, reporting problems at `loc`.
+/// The following must hold:
+///   - each device spec passes its own verifyEntry;
+///   - device IDs are unique;
+///   - all keys are identifiers;
+///   - every identifier key belonging to a loaded dialect is accepted by that
+///     dialect's DataLayoutDialectInterface.
+LogicalResult
+mlir::detail::verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
+                                         Location loc) {
+  DenseMap<StringAttr, DataLayoutEntryInterface> deviceDescKeys;
+  DenseSet<uint32_t> deviceIDs;
+  for (TargetDeviceDescSpecInterface tddSpec : spec.getEntries()) {
+    // First, verify individual target device desc specs.
+    if (failed(tddSpec.verifyEntry(loc)))
+      return failure();
+
+    // Check that device IDs are unique across all entries. Emit a diagnostic
+    // so the failure is not silent.
+    uint32_t deviceID = tddSpec.getDeviceID(tddSpec.getContext());
+    if (!deviceIDs.insert(deviceID).second)
+      return emitError(loc) << "repeated Device ID in dlti.tsd_spec: "
+                            << deviceID;
+
+    // Collect all the keys used by all the tdd_specs. When several device
+    // specs use the same key, the last entry seen is the one checked below.
+    for (DataLayoutEntryInterface entry : tddSpec.getEntries()) {
+      if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
+        return emitError(loc)
+               << "dlti.tdd_spec attribute does not allow type as a key: "
+               << type;
+      deviceDescKeys[entry.getKey().get<StringAttr>()] = entry;
+    }
+  }
+
+  for (const auto &kvp : deviceDescKeys) {
+    StringAttr identifier = kvp.first;
+    Dialect *dialect = identifier.getReferencedDialect();
+
+    // Ignore attributes that belong to an unknown dialect; the dialect may
+    // actually implement the relevant interface but we don't know about that.
+    if (!dialect)
+      continue;
+
+    const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
+    if (!iface)
+      return emitError(loc)
+             << "the '" << dialect->getNamespace()
+             << "' dialect does not support identifier data layout entries";
+    if (failed(iface->verifyEntry(kvp.second, loc)))
+      return failure();
+  }
+
+  return success();
+}
+
#include "mlir/Interfaces/DataLayoutAttrInterface.cpp.inc"
#include "mlir/Interfaces/DataLayoutOpInterface.cpp.inc"
#include "mlir/Interfaces/DataLayoutTypeInterface.cpp.inc"
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee5..2948804b8f92a 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -23,6 +23,8 @@ namespace mlir {
using namespace mlir;
+// Debug tag for the LLVM_DEBUG output emitted by this file.
+#define DEBUG_TYPE "canonicalizer"
+
namespace {
/// Canonicalize operations in nested regions.
struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
@@ -48,6 +50,20 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
config.maxIterations = maxIterations;
config.maxNumRewrites = maxNumRewrites;
+ LLVM_DEBUG(llvm::dbgs()
+ << "[CostModel] Canonicalizer MaxIterations (default):"
+ << config.maxIterations << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "[CostModel] Canonicalizer MaxNumRewrites (default):"
+ << config.maxNumRewrites << "\n");
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "[CostModel] Canonicalizer MaxIterations (default):"
+ << config.maxIterations << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "[CostModel] Canonicalizer MaxNumRewrites (default):"
+ << config.maxNumRewrites << "\n");
+
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
@@ -59,6 +75,41 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
return success();
}
void runOnOperation() override {
+ Operation *op = getOperation();
+ uint32_t cpuID = 0;
+
+ if (isa<ModuleOp>(op)) {
+ if (std::optional<int64_t> v =
+ DataLayout(llvm::dyn_cast<ModuleOp>(*op))
+ .getCanonicalizerMaxIterations(cpuID)) {
+ config.maxIterations = *v;
+ }
+ } else {
+ ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
+ if (std::optional<int64_t> v =
+ DataLayout(moduleOp).getCanonicalizerMaxIterations(cpuID)) {
+ config.maxIterations = *v;
+ }
+ }
+
+ if (isa<ModuleOp>(op)) {
+ if (std::optional<int64_t> v =
+ DataLayout(llvm::dyn_cast<ModuleOp>(*op))
+ .getCanonicalizerMaxNumRewrites(cpuID)) {
+ config.maxNumRewrites = *v;
+ }
+ } else {
+ ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
+ if (std::optional<int64_t> v =
+ DataLayout(moduleOp).getCanonicalizerMaxNumRewrites(cpuID)) {
+ config.maxNumRewrites = *v;
+ }
+ }
+ LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxIterations (new):"
+ << config.maxIterations << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxNumRewrites (new):"
+ << config.maxNumRewrites << "\n");
+
LogicalResult converged =
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
// Canonicalization is best-effort. Non-convergence is not a pass failure.
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 5f484294268ab..38f0ed0ed8da3 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -197,6 +197,11 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
}
+  /// Returns the attribute stored under `kAttrName` if it implements
+  /// TargetSystemDescSpecInterface, or null otherwise.
+  /// NOTE(review): `kAttrName` is also the key read by getDataLayoutSpec(), so
+  /// this presumably yields null whenever that key holds a data layout spec;
+  /// confirm this is the intended test behavior.
+  TargetSystemDescSpecInterface getTargetSystemDescSpec() {
+    Operation *op = getOperation();
+    return op->getAttrOfType<TargetSystemDescSpecInterface>(kAttrName);
+  }
+
static llvm::TypeSize getTypeSizeInBits(Type type,
const DataLayout &dataLayout,
DataLayoutEntryListRef params) {
@@ -244,6 +249,11 @@ struct OpWith7BitByte
return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
}
+  /// Returns the attribute stored under `kAttrName` if it implements
+  /// TargetSystemDescSpecInterface, or null otherwise.
+  /// NOTE(review): `kAttrName` is also the key read by getDataLayoutSpec(), so
+  /// this presumably yields null whenever that key holds a data layout spec;
+  /// confirm this is the intended test behavior.
+  TargetSystemDescSpecInterface getTargetSystemDescSpec() {
+    Operation *op = getOperation();
+    return op->getAttrOfType<TargetSystemDescSpecInterface>(kAttrName);
+  }
+
// Bytes are assumed to be 7-bit here.
static llvm::TypeSize getTypeSize(Type type, const DataLayout &dataLayout,
DataLayoutEntryListRef params) {
>From ccd436f80e6bcfacda9887b0226671811cd29de7 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Fri, 17 May 2024 04:44:49 -0700
Subject: [PATCH 2/7] Addressing review comments
1. Use the ODS framework for all DLTI attributes
2. Remove the need for MLIRContext in APIs
3. Remove the canonicalizer heuristics from this PR
---
mlir/include/mlir/Dialect/DLTI/CMakeLists.txt | 6 +
mlir/include/mlir/Dialect/DLTI/DLTI.h | 257 +------------
mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td | 303 +++++++++++++++
mlir/include/mlir/Dialect/DLTI/DLTIBase.td | 46 +--
.../mlir/Interfaces/DataLayoutInterfaces.h | 26 +-
.../mlir/Interfaces/DataLayoutInterfaces.td | 70 +---
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 7 +-
mlir/lib/Dialect/DLTI/DLTI.cpp | 360 +++---------------
mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 78 +---
mlir/lib/Transforms/Canonicalizer.cpp | 49 ---
10 files changed, 388 insertions(+), 814 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
diff --git a/mlir/include/mlir/Dialect/DLTI/CMakeLists.txt b/mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
index e0b18b12cda34..44a814f1c8e82 100644
--- a/mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/DLTI/CMakeLists.txt
@@ -1,2 +1,8 @@
add_mlir_dialect(DLTI dlti)
add_mlir_doc(DLTI DLTIDialect Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS DLTIAttrs.td)
+mlir_tablegen(DLTIAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=dlti)
+mlir_tablegen(DLTIAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=dlti)
+add_public_tablegen_target(MLIRDLTIAttrsIncGen)
+add_dependencies(mlir-headers MLIRDLTIAttrsIncGen)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index f78e8bdc5eb98..f50a654f3885d 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -18,260 +18,13 @@
#include "mlir/Interfaces/DataLayoutInterfaces.h"
namespace mlir {
-namespace impl {
-class DataLayoutEntryStorage;
-class DataLayoutSpecStorage;
-class TargetSystemDescSpecAttrStorage;
-class TargetDeviceDescSpecAttrStorage;
-} // namespace impl
-
-//===----------------------------------------------------------------------===//
-// DataLayoutEntryAttr
-//===----------------------------------------------------------------------===//
-
-/// A data layout entry attribute is a key-value pair where the key is a type or
-/// an identifier and the value is another attribute. These entries form a data
-/// layout specification.
-class DataLayoutEntryAttr
- : public Attribute::AttrBase<DataLayoutEntryAttr, Attribute,
- impl::DataLayoutEntryStorage,
- DataLayoutEntryInterface::Trait> {
-public:
- using Base::Base;
-
- /// The keyword used for this attribute in custom syntax.
- constexpr const static llvm::StringLiteral kAttrKeyword = "dl_entry";
-
- /// Returns the entry with the given key and value.
- static DataLayoutEntryAttr get(StringAttr key, Attribute value);
- static DataLayoutEntryAttr get(Type key, Attribute value);
-
- /// Returns the key of this entry.
- DataLayoutEntryKey getKey() const;
-
- /// Returns the value of this entry.
- Attribute getValue() const;
-
- /// Parses an instance of this attribute.
- static DataLayoutEntryAttr parse(AsmParser &parser);
-
- /// Prints this attribute.
- void print(AsmPrinter &os) const;
-
- static constexpr StringLiteral name = "builtin.data_layout_entry";
-};
-
-//===----------------------------------------------------------------------===//
-// DataLayoutSpecAttr
-//===----------------------------------------------------------------------===//
-
-/// A data layout specification is a list of entries that specify (partial) data
-/// layout information. It is expected to be attached to operations that serve
-/// as scopes for data layout requests.
-class DataLayoutSpecAttr
- : public Attribute::AttrBase<DataLayoutSpecAttr, Attribute,
- impl::DataLayoutSpecStorage,
- DataLayoutSpecInterface::Trait> {
-public:
- using Base::Base;
-
- /// The keyword used for this attribute in custom syntax.
- constexpr const static StringLiteral kAttrKeyword = "dl_spec";
-
- /// Returns the specification containing the given list of keys.
- static DataLayoutSpecAttr get(MLIRContext *ctx,
- ArrayRef<DataLayoutEntryInterface> entries);
-
- /// Returns the specification containing the given list of keys. If the list
- /// contains duplicate keys or is otherwise invalid, reports errors using the
- /// given callback and returns null.
- static DataLayoutSpecAttr
- getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
- ArrayRef<DataLayoutEntryInterface> entries);
-
- /// Checks that the given list of entries does not contain duplicate keys.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<DataLayoutEntryInterface> entries);
-
- /// Combines this specification with `specs`, enclosing specifications listed
- /// from outermost to innermost. This overwrites the older entries with the
- /// same key as the newer entries if the entries are compatible. Returns null
- /// if the specifications are not compatible.
- DataLayoutSpecAttr combineWith(ArrayRef<DataLayoutSpecInterface> specs) const;
-
- /// Returns the list of entries.
- DataLayoutEntryListRef getEntries() const;
-
- /// Returns the endiannes identifier.
- StringAttr getEndiannessIdentifier(MLIRContext *context) const;
-
- /// Returns the alloca memory space identifier.
- StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const;
-
- /// Returns the program memory space identifier.
- StringAttr getProgramMemorySpaceIdentifier(MLIRContext *context) const;
-
- /// Returns the global memory space identifier.
- StringAttr getGlobalMemorySpaceIdentifier(MLIRContext *context) const;
-
- /// Returns the stack alignment identifier.
- StringAttr getStackAlignmentIdentifier(MLIRContext *context) const;
-
- /// Parses an instance of this attribute.
- static DataLayoutSpecAttr parse(AsmParser &parser);
-
- /// Prints this attribute.
- void print(AsmPrinter &os) const;
-
- static constexpr StringLiteral name = "builtin.data_layout_spec";
-};
-
-//===----------------------------------------------------------------------===//
-// TargetSystemDescSpecAttr
-//===----------------------------------------------------------------------===//
-
-/// A system description attribute is a list of device descriptors, each
-/// having a unique device ID
-class TargetSystemDescSpecAttr
- : public Attribute::AttrBase<TargetSystemDescSpecAttr, Attribute,
- impl::TargetSystemDescSpecAttrStorage,
- TargetSystemDescSpecInterface::Trait> {
-public:
- using Base::Base;
-
- /// The keyword used for this attribute in custom syntax.
- constexpr const static StringLiteral kAttrKeyword = "tsd_spec";
-
- /// Returns a system descriptor attribute from the given system descriptor
- static TargetSystemDescSpecAttr
- get(MLIRContext *context, ArrayRef<TargetDeviceDescSpecInterface> entries);
-
- /// Returns the list of entries.
- TargetDeviceDescSpecListRef getEntries() const;
-
- /// Return the device descriptor that matches the given device ID
- TargetDeviceDescSpecInterface getDeviceDescForDeviceID(uint32_t deviceID);
-
- /// Returns the specification containing the given list of keys. If the list
- /// contains duplicate keys or is otherwise invalid, reports errors using the
- /// given callback and returns null.
- static TargetSystemDescSpecAttr
- getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
- ArrayRef<TargetDeviceDescSpecInterface> entries);
-
- /// Checks that the given list of entries does not contain duplicate keys.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<TargetDeviceDescSpecInterface> entries);
-
- /// Parses an instance of this attribute.
- static TargetSystemDescSpecAttr parse(AsmParser &parser);
-
- /// Prints this attribute.
- void print(AsmPrinter &os) const;
-
- static constexpr StringLiteral name = "builtin.target_system_description";
-};
-
-//===----------------------------------------------------------------------===//
-// TargetDeviceDescSpecAttr
-//===----------------------------------------------------------------------===//
-
-class TargetDeviceDescSpecAttr
- : public Attribute::AttrBase<TargetDeviceDescSpecAttr, Attribute,
- impl::TargetDeviceDescSpecAttrStorage,
- TargetDeviceDescSpecInterface::Trait> {
-public:
- using Base::Base;
-
- /// The keyword used for this attribute in custom syntax.
- constexpr const static StringLiteral kAttrKeyword = "tdd_spec";
-
- /// Returns a system descriptor attribute from the given system descriptor
- static TargetDeviceDescSpecAttr
- get(MLIRContext *context, ArrayRef<DataLayoutEntryInterface> entries);
-
- /// Returns the specification containing the given list of keys. If the list
- /// contains duplicate keys or is otherwise invalid, reports errors using the
- /// given callback and returns null.
- static TargetDeviceDescSpecAttr
- getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
- ArrayRef<DataLayoutEntryInterface> entries);
-
- /// Checks that the given list of entries does not contain duplicate keys.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<DataLayoutEntryInterface> entries);
-
- /// Returns the list of entries.
- DataLayoutEntryListRef getEntries() const;
-
- /// Parses an instance of this attribute.
- static TargetDeviceDescSpecAttr parse(AsmParser &parser);
-
- /// Prints this attribute.
- void print(AsmPrinter &os) const;
-
- /// Returns the device ID identifier.
- StringAttr getDeviceIDIdentifier(MLIRContext *context);
-
- /// Returns the device type identifier.
- StringAttr getDeviceTypeIdentifier(MLIRContext *context);
-
- /// Returns max vector op width identifier.
- StringAttr getMaxVectorOpWidthIdentifier(MLIRContext *context);
-
- /// Returns canonicalizer max iterations identifier.
- StringAttr getCanonicalizerMaxIterationsIdentifier(MLIRContext *context);
-
- /// Returns canonicalizer max num rewrites identifier.
- StringAttr getCanonicalizerMaxNumRewritesIdentifier(MLIRContext *context);
-
- /// Returns L1 cache size identifier
- StringAttr getL1CacheSizeInBytesIdentifier(MLIRContext *context);
-
- /// Returns the interface spec for device ID
- /// Since we verify that the spec contains device ID the function
- /// will return a valid spec.
- DataLayoutEntryInterface getSpecForDeviceID(MLIRContext *context);
-
- /// Returns the interface spec for device type
- /// Since we verify that the spec contains device type the function
- /// will return a valid spec.
- DataLayoutEntryInterface getSpecForDeviceType(MLIRContext *context);
-
- /// Returns the interface spec for max vector op width
- /// Since max vector op width is an optional property, this function will
- /// return a valid spec if the property is defined, otherwise it
- /// will return an empty spec.
- DataLayoutEntryInterface getSpecForMaxVectorOpWidth(MLIRContext *context);
-
- /// Returns the interface spec for L1 cache size
- /// Since L1 cache size is an optional property, this function will
- /// return a valid spec if the property is defined, otherwise it
- /// will return an empty spec.
- DataLayoutEntryInterface getSpecForL1CacheSizeInBytes(MLIRContext *context);
-
- /// Returns the interface spec for canonicalizer max iterations.
- /// Since this is an optional property, this function will
- /// return a valid spec if the property is defined, otherwise it
- /// will return an empty spec.
- DataLayoutEntryInterface
- getSpecForCanonicalizerMaxIterations(MLIRContext *context);
-
- /// Returns the interface spec for canonicalizer max num rewrites.
- /// Since this is an optional property, this function will
- /// return a valid spec if the property is defined, otherwise it
- /// will return an empty spec.
- DataLayoutEntryInterface
- getSpecForCanonicalizerMaxNumRewrites(MLIRContext *context);
-
- /// Return the value of device ID
- uint32_t getDeviceID(MLIRContext *context);
-
- static constexpr StringLiteral name = "builtin.target_device_description";
-};
-
+namespace detail {
+class DataLayoutEntryAttrStorage;
+} // namespace detail
} // namespace mlir
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/DLTI/DLTIAttrs.h.inc"
#include "mlir/Dialect/DLTI/DLTIDialect.h.inc"
#endif // MLIR_DIALECT_DLTI_DLTI_H
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
new file mode 100644
index 0000000000000..b2b3b8cd31d20
--- /dev/null
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -0,0 +1,303 @@
+//===- DLTIAttrs.td - DLTI dialect attributes definition --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_DLTI_DLTIATTRS_TD
+#define MLIR_DIALECT_DLTI_DLTIATTRS_TD
+
+include "mlir/Dialect/DLTI/DLTI.td"
+include "mlir/IR/AttrTypeBase.td"
+
+class DLTIAttr<string name, list<Trait> traits = [],
+ string baseCppClass = "::mlir::Attribute">
+ : AttrDef<DLTI_Dialect, name, traits, baseCppClass> { }
+
+//===----------------------------------------------------------------------===//
+// DataLayoutEntryAttr
+//===----------------------------------------------------------------------===//
+
+def DataLayoutEntryTrait
+ : NativeAttrTrait<"DataLayoutEntryInterface::Trait"> {
+ let cppNamespace = "::mlir";
+}
+
+def DLTI_DataLayoutEntryAttr :
+ DLTIAttr<"DataLayoutEntry", [DataLayoutEntryTrait]> {
+ let summary = [{
+ An attribute to represent an entry of a data layout specification
+ }];
+ let description = [{
+ A data layout entry attribute is a key-value pair where the key is a type or
+ an identifier and the value is another attribute. These entries form a data
+ layout specification.
+ }];
+ let parameters = (ins
+ "DataLayoutEntryKey":$key, "Attribute":$value
+ );
+ // We do not generate a storage class because llvm::PointerUnion
+ // does not work with the hash_key method.
+ let genStorageClass = 0;
+ let mnemonic = "dl_entry";
+ let genVerifyDecl = 0;
+ let hasCustomAssemblyFormat = 1;
+ let extraClassDeclaration = [{
+ /// Returns the entry with the given key and value.
+ static DataLayoutEntryAttr get(StringAttr key, Attribute value);
+ static DataLayoutEntryAttr get(MLIRContext *context, Type key, Attribute value);
+ static DataLayoutEntryAttr get(Type key, Attribute value);
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// DataLayoutSpecAttr
+//===----------------------------------------------------------------------===//
+def DataLayoutSpecTrait
+ : NativeAttrTrait<"DataLayoutSpecInterface::Trait"> {
+ let cppNamespace = "::mlir";
+}
+
+def DLTI_DataLayoutSpecAttr :
+ DLTIAttr<"DataLayoutSpec", [DataLayoutSpecTrait]> {
+ let summary = [{An attribute to represent a data layout specification}];
+ let description = [{
+ A data layout specification is a list of entries that specify (partial) data
+ layout information. It is expected to be attached to operations that serve
+ as scopes for data layout requests.
+ }];
+ let parameters = (ins
+ ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
+ );
+ let mnemonic = "dl_spec";
+ let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
+ let extraClassDeclaration = [{
+ /// Combines this specification with `specs`, enclosing specifications listed
+ /// from outermost to innermost. This overwrites the older entries with the
+ /// same key as the newer entries if the entries are compatible. Returns null
+ /// if the specifications are not compatible.
+ DataLayoutSpecAttr combineWith(ArrayRef<DataLayoutSpecInterface> specs) const;
+
+ /// Returns the endianness identifier.
+ StringAttr getEndiannessIdentifier(MLIRContext *context) const;
+
+ /// Returns the alloca memory space identifier.
+ StringAttr getAllocaMemorySpaceIdentifier(MLIRContext *context) const;
+
+ /// Returns the program memory space identifier.
+ StringAttr getProgramMemorySpaceIdentifier(MLIRContext *context) const;
+
+ /// Returns the global memory space identifier.
+ StringAttr getGlobalMemorySpaceIdentifier(MLIRContext *context) const;
+
+ /// Returns the stack alignment identifier.
+ StringAttr getStackAlignmentIdentifier(MLIRContext *context) const;
+ }];
+ let extraClassDefinition = [{
+ StringAttr
+ $cppClass::getEndiannessIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey);
+ }
+
+ StringAttr
+ $cppClass::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
+ }
+
+ StringAttr $cppClass::getProgramMemorySpaceIdentifier(
+ MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutProgramMemorySpaceKey);
+ }
+
+ StringAttr
+ $cppClass::getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutGlobalMemorySpaceKey);
+ }
+
+ StringAttr
+ $cppClass::getStackAlignmentIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutStackAlignmentKey);
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// TargetSystemDescSpecAttr
+//===----------------------------------------------------------------------===//
+
+def TargetSystemDescSpecTrait
+ : NativeAttrTrait<"TargetSystemDescSpecInterface::Trait"> {
+ let cppNamespace = "::mlir";
+}
+
+def DLTI_TargetSystemDescSpecAttr :
+ DLTIAttr<"TargetSystemDescSpec", [TargetSystemDescSpecTrait]> {
+ let summary = [{An attribute to represent target system description}];
+ let description = [{
+ A system description specification describes the overall system
+ containing multiple devices, with each device having a unique ID
+ and its corresponding TargetDeviceDescSpec object.
+
+ Example:
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">>,
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 1: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "GPU">,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 2: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "XPU">>>
+ }];
+ let parameters = (ins
+ ArrayRefParameter<"TargetDeviceDescSpecInterface", "">:$entries
+ );
+ let mnemonic = "target_system_desc_spec";
+ let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
+ let extraClassDeclaration = [{
+ /// Returns the device descriptor whose ID is `deviceID`, or std::nullopt.
+ std::optional<TargetDeviceDescSpecInterface>
+ getDeviceDescForDeviceID(TargetDeviceDescSpecInterface::DeviceID deviceID);
+ }];
+ let extraClassDefinition = [{
+ std::optional<TargetDeviceDescSpecInterface>
+ $cppClass::getDeviceDescForDeviceID(
+ TargetDeviceDescSpecInterface::DeviceID deviceID) {
+ for (TargetDeviceDescSpecInterface entry : getEntries()) {
+ if (entry.getDeviceID() == deviceID)
+ return entry;
+ }
+ return std::nullopt;
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// TargetDeviceDescSpecAttr
+//===----------------------------------------------------------------------===//
+
+def TargetDeviceDescSpecTrait
+ : NativeAttrTrait<"TargetDeviceDescSpecInterface::Trait"> {
+ let cppNamespace = "::mlir";
+}
+
+def DLTI_TargetDeviceDescSpecAttr :
+ DLTIAttr<"TargetDeviceDescSpec", [TargetDeviceDescSpecTrait]> {
+ let summary = [{An attribute to represent target device description}];
+ let description = [{
+ Each device description specification describes a single device and
+ its hardware properties. Each device description must have a device_id
+ and a device_type. In addition, the description can contain any number
+ of optional hardware properties (e.g., max_vector_op_width below).
+
+ Example:
+ #dlti.target_device_desc_spec <
+ #dlti.dl_entry<"dlti.device_id", 1: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "GPU">,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
+ }];
+ let parameters = (ins
+ ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
+ );
+ let mnemonic = "target_device_desc_spec";
+ let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
+ let extraClassDeclaration = [{
+ /// Returns the device ID identifier.
+ StringAttr getDeviceIDIdentifier();
+
+ /// Returns the device type identifier.
+ StringAttr getDeviceTypeIdentifier();
+
+ /// Returns max vector op width identifier.
+ StringAttr getMaxVectorOpWidthIdentifier();
+
+ /// Returns the L1 cache size identifier.
+ StringAttr getL1CacheSizeInBytesIdentifier();
+
+ /// Returns the interface spec for device ID
+ /// Since we verify that the spec contains device ID the function
+ /// will return a valid spec.
+ DataLayoutEntryInterface getSpecForDeviceID();
+
+ /// Returns the interface spec for device type
+ /// Since we verify that the spec contains device type the function
+ /// will return a valid spec.
+ DataLayoutEntryInterface getSpecForDeviceType();
+
+ /// Returns the interface spec for max vector op width
+ /// Since max vector op width is an optional property, this function will
+ /// return a valid spec if the property is defined, otherwise it
+ /// will return an empty spec.
+ DataLayoutEntryInterface getSpecForMaxVectorOpWidth();
+
+ /// Returns the interface spec for L1 cache size
+ /// Since L1 cache size is an optional property, this function will
+ /// return a valid spec if the property is defined, otherwise it
+ /// will return an empty spec.
+ DataLayoutEntryInterface getSpecForL1CacheSizeInBytes();
+
+ /// Returns the value of the device ID entry as a uint32_t.
+ uint32_t getDeviceID();
+ }];
+
+ let extraClassDefinition = [{
+ StringAttr
+ $cppClass::getDeviceIDIdentifier() {
+ return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceIDKey);
+ }
+
+ StringAttr
+ $cppClass::getDeviceTypeIdentifier() {
+ return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceTypeKey);
+ }
+
+ StringAttr
+ $cppClass::getMaxVectorOpWidthIdentifier() {
+ return Builder(getContext()).getStringAttr(
+ DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
+ }
+
+ StringAttr $cppClass::getL1CacheSizeInBytesIdentifier() {
+ return Builder(getContext()).getStringAttr(
+ DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
+ }
+
+ DataLayoutEntryInterface
+ $cppClass::getSpecForDeviceID() {
+ return getSpecForIdentifier(getDeviceIDIdentifier());
+ }
+
+ DataLayoutEntryInterface
+ $cppClass::getSpecForDeviceType() {
+ return getSpecForIdentifier(getDeviceTypeIdentifier());
+ }
+
+ DataLayoutEntryInterface
+ $cppClass::getSpecForMaxVectorOpWidth() {
+ return getSpecForIdentifier(getMaxVectorOpWidthIdentifier());
+ }
+
+ DataLayoutEntryInterface
+ $cppClass::getSpecForL1CacheSizeInBytes() {
+ return getSpecForIdentifier(getL1CacheSizeInBytesIdentifier());
+ }
+
+ uint32_t $cppClass::getDeviceID() {
+ DataLayoutEntryInterface entry = getSpecForDeviceID();
+ return llvm::cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
+ }
+ }];
+}
+
+#endif // MLIR_DIALECT_DLTI_DLTIATTRS_TD
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
index c9a054b3c1e51..61a3d2553d20e 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
@@ -29,10 +29,10 @@ def DLTI_Dialect : Dialect {
// Top level attribute name for target system description
constexpr const static ::llvm::StringLiteral
- kTargetSystemDescAttrName = "dlti.tsd_spec";
+ kTargetSystemDescAttrName = "dlti.target_system_desc_spec";
constexpr const static ::llvm::StringLiteral
- kTargetDeviceDescAttrName = "dlti.tdd_spec";
+ kTargetDeviceDescAttrName = "dlti.target_device_desc_spec";
// Constants used in entries.
constexpr const static ::llvm::StringLiteral
@@ -66,12 +66,6 @@ def DLTI_Dialect : Dialect {
constexpr const static ::llvm::StringLiteral
kTargetDeviceMaxVectorOpWidthKey = "dlti.max_vector_op_width";
- constexpr const static ::llvm::StringLiteral
- kTargetDeviceCanonicalizerMaxIterationsKey = "dlti.canonicalizer_max_iterations";
-
- constexpr const static ::llvm::StringLiteral
- kTargetDeviceCanonicalizerMaxNumRewritesKey = "dlti.canonicalizer_max_num_rewrites";
-
constexpr const static ::llvm::StringLiteral
kTargetDeviceL1CacheSizeInBytesKey = "dlti.L1_cache_size_in_bytes";
}];
@@ -79,42 +73,6 @@ def DLTI_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
}
-def DLTI_DataLayoutEntryAttr : DialectAttr<
- DLTI_Dialect,
- CPred<"::llvm::isa<::mlir::DataLayoutEntryAttr>($_self)">,
- "Target data layout entry"> {
- let storageType = "::mlir::DataLayoutEntryAttr";
- let returnType = "::mlir::DataLayoutEntryAttr";
- let convertFromStorage = "$_self";
-}
-
-def DLTI_DataLayoutSpecAttr : DialectAttr<
- DLTI_Dialect,
- CPred<"::llvm::isa<::mlir::DataLayoutSpecAttr>($_self)">,
- "Target data layout specification"> {
- let storageType = "::mlir::DataLayoutSpecAttr";
- let returnType = "::mlir::DataLayoutSpecAttr";
- let convertFromStorage = "$_self";
-}
-
-def DLTI_TargetSystemDescSpecAttr : DialectAttr<
- DLTI_Dialect,
- CPred<"::llvm::isa<::mlir::TargetSystemDescSpecAttr>($_self)">,
- "Target system description part of DLTI"> {
- let storageType = "::mlir::TargetSystemDescSpecAttr";
- let returnType = "::mlir::TargetSystemDescSpecAttr";
- let convertFromStorage = "$_self";
-}
-
-def DLTI_TargetDeviceDescSpecAttr : DialectAttr<
- DLTI_Dialect,
- CPred<"::llvm::isa<::mlir::TargetDeviceDescSpecAttr>($_self)">,
- "Target device description part of DLTI"> {
- let storageType = "::mlir::TargetDeviceDescSpecAttr";
- let returnType = "::mlir::TargetDeviceDescSpecAttr";
- let convertFromStorage = "$_self";
-}
-
def HasDefaultDLTIDataLayout : NativeOpTrait<"HasDefaultDLTIDataLayout"> {
let cppNamespace = "::mlir";
}
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 1584a13247dff..625ac2e9367dc 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -30,8 +30,6 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
// forward declaration, and we need the typedef in the actual declaration.
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
-// using TargetDeviceDescSpecList =
-// llvm::SmallVector<TargetDeviceDescSpecInterface, 4>;
using TargetDeviceDescSpecListRef =
llvm::ArrayRef<TargetDeviceDescSpecInterface>;
class DataLayoutOpInterface;
@@ -90,24 +88,14 @@ Attribute getDefaultGlobalMemorySpace(DataLayoutEntryInterface entry);
/// DataLayoutInterface if specified, otherwise returns the default.
uint64_t getDefaultStackAlignment(DataLayoutEntryInterface entry);
-/// return max vector op width from the specified DataLayoutEntry. If the
+/// Return max vector op width from the specified DataLayoutEntry. If the
/// property is missing from the entry, then return std::nullopt.
std::optional<uint32_t> getMaxVectorOpWidth(DataLayoutEntryInterface entry);
-/// return L1 cache size in bytes from the specified DataLayoutEntry. If the
+/// Return L1 cache size in bytes from the specified DataLayoutEntry. If the
/// property is missing from the entry, then return std::nullopt.
std::optional<uint32_t> getL1CacheSizeInBytes(DataLayoutEntryInterface entry);
-/// return canonicalizer max iterations from the specified DataLayoutEntry.
-/// If the property is missing from the entry, then return std::nullopt.
-std::optional<int64_t>
-getCanonicalizerMaxIterations(DataLayoutEntryInterface entry);
-
-/// returncanonicalizer max num rewrites from the specified DataLayoutEntry.
-/// If the property is missing from the entry, then return std::nullopt.
-std::optional<int64_t>
-getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry);
-
/// Given a list of data layout entries, returns a new list containing the
/// entries with keys having the given type ID, i.e. belonging to the same type
/// class.
@@ -266,16 +254,6 @@ class DataLayout {
std::optional<uint32_t>
getL1CacheSizeInBytes(TargetDeviceDescSpecInterface::DeviceID) const;
- /// Returns for canonicalizer max iterations if the property is defined for
- /// the given device ID, otherwise return std::nullopt.
- std::optional<int64_t> getCanonicalizerMaxIterations(
- TargetDeviceDescSpecInterface::DeviceID) const;
-
- /// Returns for canonicalizer max rewrites if the property is defined for
- /// the given device ID, otherwise return std::nullopt.
- std::optional<int64_t> getCanonicalizerMaxNumRewrites(
- TargetDeviceDescSpecInterface::DeviceID) const;
-
private:
/// Combined layout spec at the given scope.
const DataLayoutSpecInterface originalLayout;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 75e609dde8fcf..6547b83080804 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -242,80 +242,54 @@ def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface
/*description=*/"Returns the device ID identifier.",
/*retTy=*/"::mlir::StringAttr",
/*methodName=*/"getDeviceIDIdentifier",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
+ /*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the device type identifier.",
/*retTy=*/"::mlir::StringAttr",
/*methodName=*/"getDeviceTypeIdentifier",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
+ /*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the L1 cache size identifier.",
/*retTy=*/"::mlir::StringAttr",
/*methodName=*/"getMaxVectorOpWidthIdentifier",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
- >,
- InterfaceMethod<
- /*description=*/"Returns canonicalizer max iterations identifier.",
- /*retTy=*/"::mlir::StringAttr",
- /*methodName=*/"getCanonicalizerMaxIterationsIdentifier",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
- >,
- InterfaceMethod<
- /*description=*/"Returns canonicalizer max num rewrites identifier.",
- /*retTy=*/"::mlir::StringAttr",
- /*methodName=*/"getCanonicalizerMaxNumRewritesIdentifier",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
+ /*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the entry related to Device ID. The function"
"will crash if the entry is missing.",
/*retTy=*/"::mlir::DataLayoutEntryInterface",
/*methodName=*/"getSpecForDeviceID",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
+ /*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the entry related to the given identifier. "
"The function will crash if the entry is missing.",
/*retTy=*/"::mlir::DataLayoutEntryInterface",
/*methodName=*/"getSpecForDeviceType",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
+ /*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the entry related to the given identifier, if "
"present. Otherwise, return empty spec.",
/*retTy=*/"::mlir::DataLayoutEntryInterface",
/*methodName=*/"getSpecForMaxVectorOpWidth",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
+ /*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the entry related to the given identifier, if "
"present. Otherwise, return empty spec.",
/*retTy=*/"::mlir::DataLayoutEntryInterface",
/*methodName=*/"getSpecForL1CacheSizeInBytes",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
- >,
- InterfaceMethod<
- /*description=*/"Returns the entry related to the given identifier, if "
- "present. Otherwise, return empty spec.",
- /*retTy=*/"::mlir::DataLayoutEntryInterface",
- /*methodName=*/"getSpecForCanonicalizerMaxIterations",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
- >,
- InterfaceMethod<
- /*description=*/"Returns the entry related to the given identifier, if "
- "present. Otherwise, return empty spec.",
- /*retTy=*/"::mlir::DataLayoutEntryInterface",
- /*methodName=*/"getSpecForCanonicalizerMaxNumRewrites",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
+ /*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the entry related to the given identifier, if "
"present.",
/*retTy=*/"uint32_t",
/*methodName=*/"getDeviceID",
- /*args=*/(ins "::mlir::MLIRContext *":$context)
+ /*args=*/(ins)
>,
];
@@ -344,14 +318,14 @@ def TargetSystemDescSpecInterface : AttrInterface<"TargetSystemDescSpecInterface
let methods = [
InterfaceMethod<
/*description=*/"Returns the list of layout entries.",
- /*retTy=*/"::mlir::TargetDeviceDescSpecListRef",
+ /*retTy=*/"llvm::ArrayRef<::mlir::TargetDeviceDescSpecInterface>",
/*methodName=*/"getEntries",
/*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the device description spec for given device "
"ID",
- /*retTy=*/"::mlir::TargetDeviceDescSpecInterface",
+ /*retTy=*/"std::optional<::mlir::TargetDeviceDescSpecInterface>",
/*methodName=*/"getDeviceDescForDeviceID",
/*args=*/(ins "int":$deviceID)
>,
@@ -567,29 +541,7 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
/*defaultImplementation=*/[{
return ::mlir::detail::getL1CacheSizeInBytes(entry);
}]
- >,
- StaticInterfaceMethod<
- /*description=*/"Returns the canonicalizer max iterations, if the "
- "property is defined. Otherwise, it returns std::nullopt.",
- /*retTy=*/"std::optional<int64_t>",
- /*methodName=*/"getCanonicalizerMaxIterations",
- /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return ::mlir::detail::getCanonicalizerMaxIterations(entry);
- }]
- >,
- StaticInterfaceMethod<
- /*description=*/"Returns the canonicalizer max num rewrites, if the "
- "property is defined. Otherwise, it returns std::nullopt.",
- /*retTy=*/"std::optional<int64_t>",
- /*methodName=*/"getCanonicalizerMaxNumRewrites",
- /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return ::mlir::detail::getCanonicalizerMaxNumRewrites(entry);
- }]
- >,
+ >
];
let verify = [{ return ::mlir::detail::verifyDataLayoutOp($_op); }];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 8ff2b2104b998..f3e1f83b56550 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -114,9 +114,10 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
uint32_t elemBits = dataVector.getElementTypeBitWidth();
uint32_t totalBits = elemBits * dataVector.getNumElements();
uint32_t maxVectorOpWidth = 128; // default value
- if (std::optional<uint32_t> v =
- DataLayout(gpuOp->template getParentOfType<mlir::ModuleOp>())
- .getMaxVectorOpWidth(1 /* gpu ID*/)) {
+ ModuleOp moduleOp = gpuOp->template getParentOfType<mlir::ModuleOp>();
+ std::optional<uint32_t> v = std::nullopt;
+ if (moduleOp &&
+ (v = DataLayout(moduleOp).getMaxVectorOpWidth(1 /* gpu ID*/))) {
maxVectorOpWidth = *v;
}
LLVM_DEBUG(llvm::dbgs() << "[CostModel] GPU MaxVectorWidth:"
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index a8518469c7824..9d2865e8152ef 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -28,22 +28,19 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
// DataLayoutEntryAttr
//===----------------------------------------------------------------------===//
-//
-constexpr const StringLiteral mlir::DataLayoutEntryAttr::kAttrKeyword;
-
namespace mlir {
-namespace impl {
-class DataLayoutEntryStorage : public AttributeStorage {
+namespace detail {
+class DataLayoutEntryAttrStorage : public AttributeStorage {
public:
using KeyTy = std::pair<DataLayoutEntryKey, Attribute>;
- DataLayoutEntryStorage(DataLayoutEntryKey entryKey, Attribute value)
+ DataLayoutEntryAttrStorage(DataLayoutEntryKey entryKey, Attribute value)
: entryKey(entryKey), value(value) {}
- static DataLayoutEntryStorage *construct(AttributeStorageAllocator &allocator,
- const KeyTy &key) {
- return new (allocator.allocate<DataLayoutEntryStorage>())
- DataLayoutEntryStorage(key.first, key.second);
+ static DataLayoutEntryAttrStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ return new (allocator.allocate<DataLayoutEntryAttrStorage>())
+ DataLayoutEntryAttrStorage(key.first, key.second);
}
bool operator==(const KeyTy &other) const {
@@ -53,7 +50,7 @@ class DataLayoutEntryStorage : public AttributeStorage {
DataLayoutEntryKey entryKey;
Attribute value;
};
-} // namespace impl
+} // namespace detail
} // namespace mlir
DataLayoutEntryAttr DataLayoutEntryAttr::get(StringAttr key, Attribute value) {
@@ -72,7 +69,7 @@ Attribute DataLayoutEntryAttr::getValue() const { return getImpl()->value; }
/// Parses an attribute with syntax:
/// attr ::= `#target.` `dl_entry` `<` (type | quoted-string) `,` attr `>`
-DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
+Attribute DataLayoutEntryAttr::parse(AsmParser &parser, Type ty) {
if (failed(parser.parseLess()))
return {};
@@ -100,7 +97,7 @@ DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
}
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
- os << DataLayoutEntryAttr::kAttrKeyword << "<";
+ os << "<";
if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
os << type;
else
@@ -111,51 +108,6 @@ void DataLayoutEntryAttr::print(AsmPrinter &os) const {
//===----------------------------------------------------------------------===//
// DataLayoutSpecAttr
//===----------------------------------------------------------------------===//
-//
-constexpr const StringLiteral mlir::DataLayoutSpecAttr::kAttrKeyword;
-constexpr const StringLiteral
- mlir::DLTIDialect::kDataLayoutAllocaMemorySpaceKey;
-constexpr const StringLiteral
- mlir::DLTIDialect::kDataLayoutProgramMemorySpaceKey;
-constexpr const StringLiteral
- mlir::DLTIDialect::kDataLayoutGlobalMemorySpaceKey;
-
-constexpr const StringLiteral mlir::DLTIDialect::kDataLayoutStackAlignmentKey;
-
-namespace mlir {
-namespace impl {
-class DataLayoutSpecStorage : public AttributeStorage {
-public:
- using KeyTy = ArrayRef<DataLayoutEntryInterface>;
-
- DataLayoutSpecStorage(ArrayRef<DataLayoutEntryInterface> entries)
- : entries(entries) {}
-
- bool operator==(const KeyTy &key) const { return key == entries; }
-
- static DataLayoutSpecStorage *construct(AttributeStorageAllocator &allocator,
- const KeyTy &key) {
- return new (allocator.allocate<DataLayoutSpecStorage>())
- DataLayoutSpecStorage(allocator.copyInto(key));
- }
-
- ArrayRef<DataLayoutEntryInterface> entries;
-};
-} // namespace impl
-} // namespace mlir
-
-DataLayoutSpecAttr
-DataLayoutSpecAttr::get(MLIRContext *ctx,
- ArrayRef<DataLayoutEntryInterface> entries) {
- return Base::get(ctx, entries);
-}
-
-DataLayoutSpecAttr
-DataLayoutSpecAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
- MLIRContext *context,
- ArrayRef<DataLayoutEntryInterface> entries) {
- return Base::getChecked(emitError, context, entries);
-}
LogicalResult
DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
@@ -284,43 +236,11 @@ DataLayoutSpecAttr::combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
return DataLayoutSpecAttr::get(getContext(), entries);
}
-DataLayoutEntryListRef DataLayoutSpecAttr::getEntries() const {
- return getImpl()->entries;
-}
-
-StringAttr
-DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey);
-}
-
-StringAttr
-DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
-}
-
-StringAttr DataLayoutSpecAttr::getProgramMemorySpaceIdentifier(
- MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutProgramMemorySpaceKey);
-}
-
-StringAttr
-DataLayoutSpecAttr::getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutGlobalMemorySpaceKey);
-}
-StringAttr
-DataLayoutSpecAttr::getStackAlignmentIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutStackAlignmentKey);
-}
-
/// Parses an attribute with syntax
/// attr ::= `#target.` `dl_spec` `<` attr-list? `>`
/// attr-list ::= attr
/// | attr `,` attr-list
-DataLayoutSpecAttr DataLayoutSpecAttr::parse(AsmParser &parser) {
+Attribute DataLayoutSpecAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
@@ -339,7 +259,7 @@ DataLayoutSpecAttr DataLayoutSpecAttr::parse(AsmParser &parser) {
}
void DataLayoutSpecAttr::print(AsmPrinter &os) const {
- os << DataLayoutSpecAttr::kAttrKeyword << "<";
+ os << "<";
llvm::interleaveComma(getEntries(), os);
os << ">";
}
@@ -347,55 +267,6 @@ void DataLayoutSpecAttr::print(AsmPrinter &os) const {
//===----------------------------------------------------------------------===//
// TargetDeviceDescSpecAttr
//===----------------------------------------------------------------------===//
-constexpr const StringLiteral mlir::TargetDeviceDescSpecAttr::kAttrKeyword;
-
-constexpr const StringLiteral mlir::DLTIDialect::kTargetDeviceIDKey;
-constexpr const StringLiteral mlir::DLTIDialect::kTargetDeviceTypeKey;
-constexpr const StringLiteral
- mlir::DLTIDialect::kTargetDeviceMaxVectorOpWidthKey;
-constexpr const StringLiteral
- mlir::DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey;
-constexpr const StringLiteral
- mlir::DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey;
-constexpr const StringLiteral
- mlir::DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey;
-
-namespace mlir {
-namespace impl {
-class TargetDeviceDescSpecAttrStorage : public AttributeStorage {
-public:
- using KeyTy = ArrayRef<DataLayoutEntryInterface>;
-
- TargetDeviceDescSpecAttrStorage(KeyTy entries) : entries(entries) {}
-
- bool operator==(const KeyTy &key) const { return key == entries; }
-
- static TargetDeviceDescSpecAttrStorage *
- construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
- return new (allocator.allocate<TargetDeviceDescSpecAttrStorage>())
- TargetDeviceDescSpecAttrStorage(allocator.copyInto(key));
- }
-
- ArrayRef<DataLayoutEntryInterface> entries;
-};
-} // namespace impl
-} // namespace mlir
-
-TargetDeviceDescSpecAttr
-TargetDeviceDescSpecAttr::get(MLIRContext *ctx,
- ArrayRef<DataLayoutEntryInterface> entries) {
- return Base::get(ctx, entries);
-}
-
-DataLayoutEntryListRef TargetDeviceDescSpecAttr::getEntries() const {
- return getImpl()->entries;
-}
-
-TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::getChecked(
- function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
- ArrayRef<DataLayoutEntryInterface> entries) {
- return Base::getChecked(emitError, context, entries);
-}
LogicalResult
TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
@@ -409,7 +280,7 @@ TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
for (DataLayoutEntryInterface entry : entries) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return emitError()
- << "dlti.tdd_spec attribute does not allow type as a key: "
+ << "dlti.target_device_desc_spec does not allow type as a key: "
<< type;
} else {
auto id = entry.getKey().get<StringAttr>();
@@ -431,18 +302,22 @@ TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
if (auto value = llvm::dyn_cast<StringAttr>(entry.getValue())) {
targetDeviceTypeKeyPresentAndValid = true;
}
+ } else if (entryName != DLTIDialect::kTargetDeviceMaxVectorOpWidthKey &&
+ entryName != DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey) {
+ return emitError() << "unknown target device desc key name: "
+ << entryName;
}
}
// check that both DeviceID and DeviceType are present
// and are of correct type.
if (!targetDeviceIDKeyPresentAndValid) {
- return emitError() << "tdd_spec requires key: "
+ return emitError() << "target_device_desc_spec requires key: "
<< DLTIDialect::kTargetDeviceIDKey
<< " and its value of ui32 type";
}
if (!targetDeviceTypeKeyPresentAndValid) {
- return emitError() << "tdd_spec requires key: "
+ return emitError() << "target_device_desc_spec requires key: "
<< DLTIDialect::kTargetDeviceTypeKey
<< " and its value of string type";
}
@@ -451,10 +326,11 @@ TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
}
/// Parses an attribute with syntax
-/// tdd_spec_attr ::= `#target.` `tdd_spec` `<` dl-entry-attr-list? `>`
+/// target_device_desc_spec_attr ::=
+/// `#target.` `target_device_desc_spec` `<` dl-entry-attr-list? `>`
/// dl-entry-attr-list ::= dl-entry-attr
/// | dl-entry-attr `,` dl-entry-attr-list
-TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::parse(AsmParser &parser) {
+Attribute TargetDeviceDescSpecAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
@@ -473,141 +349,15 @@ TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::parse(AsmParser &parser) {
}
void TargetDeviceDescSpecAttr::print(AsmPrinter &os) const {
- os << TargetDeviceDescSpecAttr::kAttrKeyword << "<";
+ os << "<";
llvm::interleaveComma(getEntries(), os);
os << ">";
}
-// ---------------------------------------------------------------------------//
-// Support for specific keys
-// ---------------------------------------------------------------------------//
-
-StringAttr
-TargetDeviceDescSpecAttr::getDeviceIDIdentifier(MLIRContext *context) {
- return Builder(context).getStringAttr(DLTIDialect::kTargetDeviceIDKey);
-}
-
-StringAttr
-TargetDeviceDescSpecAttr::getDeviceTypeIdentifier(MLIRContext *context) {
- return Builder(context).getStringAttr(DLTIDialect::kTargetDeviceTypeKey);
-}
-
-StringAttr
-TargetDeviceDescSpecAttr::getMaxVectorOpWidthIdentifier(MLIRContext *context) {
- return Builder(context).getStringAttr(
- DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
-}
-
-StringAttr TargetDeviceDescSpecAttr::getL1CacheSizeInBytesIdentifier(
- MLIRContext *context) {
- return Builder(context).getStringAttr(
- DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
-}
-
-StringAttr TargetDeviceDescSpecAttr::getCanonicalizerMaxIterationsIdentifier(
- MLIRContext *context) {
- return Builder(context).getStringAttr(
- DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey);
-}
-
-StringAttr TargetDeviceDescSpecAttr::getCanonicalizerMaxNumRewritesIdentifier(
- MLIRContext *context) {
- return Builder(context).getStringAttr(
- DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey);
-}
-
-DataLayoutEntryInterface
-TargetDeviceDescSpecAttr::getSpecForDeviceID(MLIRContext *context) {
- return getSpecForIdentifier(getDeviceIDIdentifier(context));
-}
-
-DataLayoutEntryInterface
-TargetDeviceDescSpecAttr::getSpecForDeviceType(MLIRContext *context) {
- return getSpecForIdentifier(getDeviceTypeIdentifier(context));
-}
-
-DataLayoutEntryInterface
-TargetDeviceDescSpecAttr::getSpecForMaxVectorOpWidth(MLIRContext *context) {
- return getSpecForIdentifier(getMaxVectorOpWidthIdentifier(context));
-}
-
-DataLayoutEntryInterface
-TargetDeviceDescSpecAttr::getSpecForL1CacheSizeInBytes(MLIRContext *context) {
- return getSpecForIdentifier(getL1CacheSizeInBytesIdentifier(context));
-}
-
-DataLayoutEntryInterface
-TargetDeviceDescSpecAttr::getSpecForCanonicalizerMaxIterations(
- MLIRContext *context) {
- return getSpecForIdentifier(getCanonicalizerMaxIterationsIdentifier(context));
-}
-
-DataLayoutEntryInterface
-TargetDeviceDescSpecAttr::getSpecForCanonicalizerMaxNumRewrites(
- MLIRContext *context) {
- return getSpecForIdentifier(
- getCanonicalizerMaxNumRewritesIdentifier(context));
-}
-
-uint32_t TargetDeviceDescSpecAttr::getDeviceID(MLIRContext *context) {
- DataLayoutEntryInterface entry = getSpecForDeviceID(context);
- return llvm::cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
-}
-
//===----------------------------------------------------------------------===//
// TargetSystemDescSpecAttr
//===----------------------------------------------------------------------===//
-constexpr const StringLiteral mlir::TargetSystemDescSpecAttr::kAttrKeyword;
-
-namespace mlir {
-namespace impl {
-class TargetSystemDescSpecAttrStorage : public AttributeStorage {
-public:
- using KeyTy = ArrayRef<TargetDeviceDescSpecInterface>;
-
- TargetSystemDescSpecAttrStorage(KeyTy entries) : entries(entries) {}
-
- bool operator==(const KeyTy &key) const { return key == entries; }
-
- static TargetSystemDescSpecAttrStorage *
- construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
- return new (allocator.allocate<TargetSystemDescSpecAttrStorage>())
- TargetSystemDescSpecAttrStorage(allocator.copyInto(key));
- }
-
- // This could be a map of DeviceID to DeviceDesc for faster lookup.
- ArrayRef<TargetDeviceDescSpecInterface> entries;
-};
-} // namespace impl
-} // namespace mlir
-
-TargetSystemDescSpecAttr
-TargetSystemDescSpecAttr::get(MLIRContext *context,
- ArrayRef<TargetDeviceDescSpecInterface> entries) {
- return Base::get(context, entries);
-}
-
-TargetDeviceDescSpecListRef TargetSystemDescSpecAttr::getEntries() const {
- return getImpl()->entries;
-}
-
-TargetDeviceDescSpecInterface
-TargetSystemDescSpecAttr::getDeviceDescForDeviceID(
- TargetDeviceDescSpecInterface::DeviceID DeviceID) {
- for (TargetDeviceDescSpecInterface entry : getEntries()) {
- if (entry.getDeviceID(getContext()) == DeviceID)
- return entry;
- }
- return TargetDeviceDescSpecInterface();
-}
-
-TargetSystemDescSpecAttr TargetSystemDescSpecAttr::getChecked(
- function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
- ArrayRef<TargetDeviceDescSpecInterface> entries) {
- return Base::getChecked(emitError, context, entries);
-}
-
LogicalResult TargetSystemDescSpecAttr::verify(
function_ref<InFlightDiagnostic()> emitError,
ArrayRef<TargetDeviceDescSpecInterface> entries) {
@@ -620,21 +370,24 @@ LogicalResult TargetSystemDescSpecAttr::verify(
return failure();
// Check that device IDs are unique across all entries.
- MLIRContext *context = tdd_spec.getContext();
- uint32_t device_id = tdd_spec.getDeviceID(context);
+ uint32_t device_id = tdd_spec.getDeviceID();
if (!device_ids.insert(device_id).second) {
- return emitError() << "repeated Device ID in dlti.tsd_spec: "
- << device_id;
+ return emitError()
+ << "repeated Device ID in dlti.target_system_desc_spec: "
+ << device_id;
}
}
return success();
}
/// Parses an attribute with syntax
-/// attr ::= `#target.` `tsd_spec` `<` tdd-spec-attr-list? `>`
-/// tdd-spec-attr-list ::= tdd_spec
-/// | tdd_spec `,` tdd_spec_attr_list
-TargetSystemDescSpecAttr TargetSystemDescSpecAttr::parse(AsmParser &parser) {
+/// attr ::=
+/// `#target.` `target_system_desc_spec` `<`
+/// target-device-desc-spec-attr-list? `>`
+/// target-device-desc-spec-attr-list ::= target_device_desc_spec
+/// | target_device_desc_spec `,`
+/// target-device-desc-spec-attr-list
+Attribute TargetSystemDescSpecAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
@@ -653,11 +406,14 @@ TargetSystemDescSpecAttr TargetSystemDescSpecAttr::parse(AsmParser &parser) {
}
void TargetSystemDescSpecAttr::print(AsmPrinter &os) const {
- os << TargetSystemDescSpecAttr::kAttrKeyword << "<";
+ os << "<";
llvm::interleaveComma(getEntries(), os);
os << ">";
}
+#define GET_ATTRDEF_CLASSES
+#include <mlir/Dialect/DLTI/DLTIAttrs.cpp.inc>
+
//===----------------------------------------------------------------------===//
// DLTIDialect
//===----------------------------------------------------------------------===//
@@ -711,10 +467,7 @@ class SystemDescSpecInterface : public DataLayoutDialectInterface {
if (entryName != DLTIDialect::kTargetDeviceIDKey &&
entryName != DLTIDialect::kTargetDeviceTypeKey &&
entryName != DLTIDialect::kTargetDeviceMaxVectorOpWidthKey &&
- entryName != DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey &&
- entryName !=
- DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey &&
- entryName != DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey)
+ entryName != DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey)
return emitError(loc) << "unknown target desc key name: " << entryName;
}
return success();
@@ -723,38 +476,13 @@ class SystemDescSpecInterface : public DataLayoutDialectInterface {
} // namespace
void DLTIDialect::initialize() {
- addAttributes<DataLayoutEntryAttr, DataLayoutSpecAttr,
- TargetSystemDescSpecAttr, TargetDeviceDescSpecAttr>();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include <mlir/Dialect/DLTI/DLTIAttrs.cpp.inc>
+ >();
addInterfaces<TargetDataLayoutInterface, SystemDescSpecInterface>();
}
-Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
- Type type) const {
- StringRef attrKind;
- if (parser.parseKeyword(&attrKind))
- return {};
-
- if (attrKind == DataLayoutEntryAttr::kAttrKeyword)
- return DataLayoutEntryAttr::parse(parser);
- if (attrKind == DataLayoutSpecAttr::kAttrKeyword)
- return DataLayoutSpecAttr::parse(parser);
- if (attrKind == TargetSystemDescSpecAttr::kAttrKeyword)
- return TargetSystemDescSpecAttr::parse(parser);
- if (attrKind == TargetDeviceDescSpecAttr::kAttrKeyword)
- return TargetDeviceDescSpecAttr::parse(parser);
-
- parser.emitError(parser.getNameLoc(), "unknown attrribute type: ")
- << attrKind;
- return {};
-}
-
-void DLTIDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
- llvm::TypeSwitch<Attribute>(attr)
- .Case<DataLayoutEntryAttr, DataLayoutSpecAttr, TargetSystemDescSpecAttr,
- TargetDeviceDescSpecAttr>([&](auto a) { a.print(os); })
- .Default([](Attribute) { llvm_unreachable("unknown attribute kind"); });
-}
-
LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
if (attr.getName() == DLTIDialect::kDataLayoutAttrName) {
@@ -769,7 +497,7 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
if (!llvm::isa<TargetSystemDescSpecAttr>(attr.getValue())) {
return op->emitError()
<< "'" << DLTIDialect::kTargetSystemDescAttrName
- << "' is expected to be a #dlti.tsd_spec attribute";
+ << "' is expected to be a #dlti.target_system_desc_spec attribute";
}
return success();
}
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 857ccf03ed8c0..19a662b3793ba 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -315,28 +315,6 @@ mlir::detail::getL1CacheSizeInBytes(DataLayoutEntryInterface entry) {
return value.getValue().getZExtValue();
}
-// Returns the canonicalizer max iterations if specified in the given entry.
-// If the entry is empty (meaning the spec is missing), returns std::nullopt.
-std::optional<int64_t>
-mlir::detail::getCanonicalizerMaxIterations(DataLayoutEntryInterface entry) {
- if (entry == DataLayoutEntryInterface())
- return std::nullopt;
-
- auto value = cast<IntegerAttr>(entry.getValue());
- return value.getValue().getSExtValue();
-}
-
-// Returns the canonicalizer max num rewrites if specified in the given entry.
-// If the entry is empty (meaning the spec is missing), returns std::nullopt.
-std::optional<int64_t>
-mlir::detail::getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry) {
- if (entry == DataLayoutEntryInterface())
- return std::nullopt;
-
- auto value = cast<IntegerAttr>(entry.getValue());
- return value.getValue().getSExtValue();
-}
-
DataLayoutEntryList
mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
TypeID typeID) {
@@ -704,10 +682,11 @@ std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
TargetDeviceDescSpecInterface::DeviceID deviceID) const {
checkValid();
DataLayoutEntryInterface entry;
- if (originalTargetSystemDesc)
- entry =
- originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID)
- .getSpecForMaxVectorOpWidth(originalTargetSystemDesc.getContext());
+ if (originalTargetSystemDesc) {
+ if (auto device =
+ originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID))
+ entry = device->getSpecForMaxVectorOpWidth();
+ }
// Currently I am not caching the results because we do not return
// default values of these properties. Instead if the property is
// missing, we return std::nullopt so that the users can resort to
@@ -722,10 +701,11 @@ std::optional<uint32_t> mlir::DataLayout::getL1CacheSizeInBytes(
TargetDeviceDescSpecInterface::DeviceID deviceID) const {
checkValid();
DataLayoutEntryInterface entry;
- if (originalTargetSystemDesc)
- entry = originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID)
- .getSpecForL1CacheSizeInBytes(
- originalTargetSystemDesc.getContext());
+ if (originalTargetSystemDesc) {
+ if (auto device =
+ originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID))
+ entry = device->getSpecForL1CacheSizeInBytes();
+ }
// Currently I am not caching the results because we do not return
// default values of these properties. Instead if the property is
// missing, we return std::nullopt so that the users can resort to
@@ -736,42 +716,6 @@ std::optional<uint32_t> mlir::DataLayout::getL1CacheSizeInBytes(
return detail::getL1CacheSizeInBytes(entry);
}
-std::optional<int64_t> mlir::DataLayout::getCanonicalizerMaxIterations(
- TargetDeviceDescSpecInterface::DeviceID deviceID) const {
- checkValid();
- DataLayoutEntryInterface entry;
- if (originalTargetSystemDesc)
- entry = originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID)
- .getSpecForCanonicalizerMaxIterations(
- originalTargetSystemDesc.getContext());
- // Currently I am not caching the results because we do not return
- // default values of these properties. Instead if the property is
- // missing, we return std::nullopt so that the users can resort to
- // the default value however they want.
- if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
- return iface.getCanonicalizerMaxIterations(entry);
- else
- return detail::getCanonicalizerMaxIterations(entry);
-}
-
-std::optional<int64_t> mlir::DataLayout::getCanonicalizerMaxNumRewrites(
- TargetDeviceDescSpecInterface::DeviceID deviceID) const {
- checkValid();
- DataLayoutEntryInterface entry;
- if (originalTargetSystemDesc)
- entry = originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID)
- .getSpecForCanonicalizerMaxNumRewrites(
- originalTargetSystemDesc.getContext());
- // Currently I am not caching the results because we do not return
- // default values of these properties. Instead if the property is
- // missing, we return std::nullopt so that the users can resort to
- // the default value however they want.
- if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
- return iface.getCanonicalizerMaxNumRewrites(entry);
- else
- return detail::getCanonicalizerMaxNumRewrites(entry);
-}
-
//===----------------------------------------------------------------------===//
// DataLayoutSpecInterface
//===----------------------------------------------------------------------===//
@@ -888,7 +832,7 @@ mlir::detail::verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
// Check that device IDs are unique across all entries.
MLIRContext *context = tdd_spec.getContext();
- uint32_t device_id = tdd_spec.getDeviceID(context);
+ uint32_t device_id = tdd_spec.getDeviceID();
if (!device_ids.insert(device_id).second) {
return failure();
}
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 2948804b8f92a..1a66f14bc606c 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -50,20 +50,6 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
config.maxIterations = maxIterations;
config.maxNumRewrites = maxNumRewrites;
- LLVM_DEBUG(llvm::dbgs()
- << "[CostModel] Canonicalizer MaxIterations (default):"
- << config.maxIterations << "\n");
- LLVM_DEBUG(llvm::dbgs()
- << "[CostModel] Canonicalizer MaxNumRewrites (default):"
- << config.maxNumRewrites << "\n");
-
- LLVM_DEBUG(llvm::dbgs()
- << "[CostModel] Canonicalizer MaxIterations (default):"
- << config.maxIterations << "\n");
- LLVM_DEBUG(llvm::dbgs()
- << "[CostModel] Canonicalizer MaxNumRewrites (default):"
- << config.maxNumRewrites << "\n");
-
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
@@ -75,41 +61,6 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
return success();
}
void runOnOperation() override {
- Operation *op = getOperation();
- uint32_t cpuID = 0;
-
- if (isa<ModuleOp>(op)) {
- if (std::optional<int64_t> v =
- DataLayout(llvm::dyn_cast<ModuleOp>(*op))
- .getCanonicalizerMaxIterations(cpuID)) {
- config.maxIterations = *v;
- }
- } else {
- ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
- if (std::optional<int64_t> v =
- DataLayout(moduleOp).getCanonicalizerMaxIterations(cpuID)) {
- config.maxIterations = *v;
- }
- }
-
- if (isa<ModuleOp>(op)) {
- if (std::optional<int64_t> v =
- DataLayout(llvm::dyn_cast<ModuleOp>(*op))
- .getCanonicalizerMaxNumRewrites(cpuID)) {
- config.maxNumRewrites = *v;
- }
- } else {
- ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
- if (std::optional<int64_t> v =
- DataLayout(moduleOp).getCanonicalizerMaxNumRewrites(cpuID)) {
- config.maxNumRewrites = *v;
- }
- }
- LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxIterations (new):"
- << config.maxIterations << "\n");
- LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxNumRewrites (new):"
- << config.maxNumRewrites << "\n");
-
LogicalResult converged =
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
// Canonicalization is best-effort. Non-convergence is not a pass failure.
>From 9144c519e23db6fe001bb2837812ad2242dd9a39 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Mon, 20 May 2024 09:03:26 -0700
Subject: [PATCH 3/7] Addressing review comments; adding unit tests
---
mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td | 4 +-
.../mlir/Interfaces/DataLayoutInterfaces.td | 6 +
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 14 +-
mlir/lib/Dialect/DLTI/DLTI.cpp | 17 ++-
.../Linalg/Transforms/BlockPackMatmul.cpp | 22 +--
mlir/lib/Transforms/Canonicalizer.cpp | 2 -
mlir/test/Dialect/DLTI/invalid.mlir | 131 ++++++++++++++++-
mlir/test/Dialect/DLTI/roundtrip.mlir | 26 ++++
mlir/test/Dialect/DLTI/valid.mlir | 101 +++++++++++++
.../Interfaces/DataLayoutInterfacesTest.cpp | 138 +++++++++++++++++-
10 files changed, 418 insertions(+), 43 deletions(-)
create mode 100644 mlir/test/Dialect/DLTI/valid.mlir
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
index b2b3b8cd31d20..7f915b7a015a6 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -172,9 +172,9 @@ def DLTI_TargetSystemDescSpecAttr :
let extraClassDefinition = [{
std::optional<TargetDeviceDescSpecInterface>
$cppClass::getDeviceDescForDeviceID(
- TargetDeviceDescSpecInterface::DeviceID DeviceID) {
+ TargetDeviceDescSpecInterface::DeviceID deviceID) {
for (TargetDeviceDescSpecInterface entry : getEntries()) {
- if (entry.getDeviceID() == DeviceID)
+ if (entry.getDeviceID() == deviceID)
return entry;
}
return std::nullopt;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 6547b83080804..5c431097c7e77 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -253,6 +253,12 @@ def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface
InterfaceMethod<
/*description=*/"Returns the L1 cache size identifier.",
/*retTy=*/"::mlir::StringAttr",
+ /*methodName=*/"getL1CacheSizeInBytesIdentifier",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the max vector op width identifier.",
+ /*retTy=*/"::mlir::StringAttr",
/*methodName=*/"getMaxVectorOpWidthIdentifier",
/*args=*/(ins)
>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f3e1f83b56550..033e66c6118f3 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -30,8 +29,6 @@ namespace mlir {
using namespace mlir;
using namespace mlir::amdgpu;
-#define DEBUG_TYPE "amd-gpu-to-rocdl"
-
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type llvmI32 = rewriter.getI32Type();
@@ -52,6 +49,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
: ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
Chipset chipset;
+ static constexpr uint32_t maxVectorOpWidth = 128;
LogicalResult
matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
@@ -113,16 +111,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
uint32_t elemBits = dataVector.getElementTypeBitWidth();
uint32_t totalBits = elemBits * dataVector.getNumElements();
- uint32_t maxVectorOpWidth = 128; // default value
- ModuleOp moduleOp = gpuOp->template getParentOfType<mlir::ModuleOp>();
- std::optional<uint32_t> v = std::nullopt;
- if (moduleOp &&
- (v = DataLayout(moduleOp).getMaxVectorOpWidth(1 /* gpu ID*/))) {
- maxVectorOpWidth = *v;
- }
- LLVM_DEBUG(llvm::dbgs() << "[CostModel] GPU MaxVectorWidth:"
- << maxVectorOpWidth << "\n");
-
if (totalBits > maxVectorOpWidth)
return gpuOp.emitOpError(
"Total width of loads or stores must be no more than " +
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 9d2865e8152ef..fe2b18b966c6b 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -302,8 +302,21 @@ TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
if (auto value = llvm::dyn_cast<StringAttr>(entry.getValue())) {
targetDeviceTypeKeyPresentAndValid = true;
}
- } else if (entryName != DLTIDialect::kTargetDeviceMaxVectorOpWidthKey &&
- entryName != DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey) {
+ } else if (entryName == DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey) {
+ IntegerAttr value =
+ llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
+ if (!value || !value.getType().isUnsignedInteger(32))
+ return emitError() << "target_device_desc_spec requires value of key: "
+ << DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey
+ << " to be of ui32 type";
+ } else if (entryName == DLTIDialect::kTargetDeviceMaxVectorOpWidthKey) {
+ IntegerAttr value =
+ llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
+ if (!value || !value.getType().isUnsignedInteger(32))
+ return emitError() << "target_device_desc_spec requires value of key: "
+ << DLTIDialect::kTargetDeviceMaxVectorOpWidthKey
+ << " to be of ui32 type";
+ } else {
return emitError() << "unknown target device desc key name: "
<< entryName;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index c08224f7af54e..91d4efa3372b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -18,8 +18,6 @@
#include <optional>
-#define DEBUG_TYPE "block-pack-matmul"
-
namespace mlir {
#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
#include "mlir/Dialect/Linalg/Passes.h.inc"
@@ -136,24 +134,6 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
return packTransposedMatmul;
}
-static SmallVector<int64_t> getDefaultBlockFactors(linalg::LinalgOp linalgOp) {
- // get L1 cache size first.
- uint32_t L1_cache_size = 4096; // default value
- uint32_t cpuID = 0;
- ModuleOp moduleOp = linalgOp->getParentOfType<ModuleOp>();
- if (std::optional<int64_t> v =
- DataLayout(moduleOp).getL1CacheSizeInBytes(cpuID)) {
- L1_cache_size = *v;
- }
-
- // block_size = sqrt(L1_cache_size) rounded down to nearest power of 2.
- int64_t block_size =
- std::pow(2, std::floor(std::log2(std::sqrt(L1_cache_size))));
- // we use same block size for all dims.
- LLVM_DEBUG(llvm::dbgs() << "block_size:" << block_size << "\n");
- return {block_size, block_size, block_size};
-}
-
/// Pack a matmul operation into blocked 4D layout.
FailureOr<PackResult>
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
@@ -166,7 +146,7 @@ linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
if (options->blockFactors.size() != 3)
- options->blockFactors = getDefaultBlockFactors(linalgOp);
+ return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
SmallVector<OpFoldResult> mnkTiles =
getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 1a66f14bc606c..d50019bd6aee5 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -23,8 +23,6 @@ namespace mlir {
using namespace mlir;
-#define DEBUG_TYPE "canonicalizer"
-
namespace {
/// Canonicalize operations in nested regions.
struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir
index 465ec72106f70..7810c659f3158 100644
--- a/mlir/test/Dialect/DLTI/invalid.mlir
+++ b/mlir/test/Dialect/DLTI/invalid.mlir
@@ -36,7 +36,7 @@
// -----
-// expected-error at below {{unknown attrribute type: unknown}}
+// expected-error at below {{unknown attribute `unknown` in dialect `dlti`}}
"test.unknown_op"() { test.unknown_attr = #dlti.unknown } : () -> ()
// -----
@@ -90,3 +90,132 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown
// expected-note at above {{enclosing op with data layout}}
"test.op_with_data_layout"() { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>>} : () -> ()
}
+
+// -----
+
+// expected-error at below {{'dlti.target_system_desc_spec' is expected to be a #dlti.target_system_desc_spec attribute}}
+"test.unknown_op"() { dlti.target_system_desc_spec = 42 } : () -> ()
+
+// -----
+
+// expected-error at below {{invalid kind of attribute specified}}
+"test.unknown_op"() { dlti.target_system_desc_spec = #dlti.target_system_desc_spec<[]> } : () -> ()
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{target_device_desc_spec requires key: dlti.device_id and its value of ui32 type}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_type", "CPU">>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{target_device_desc_spec requires key: dlti.device_type and its value of string type}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{target_device_desc_spec requires key: dlti.device_id and its value of ui32 type}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: i32>>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{target_device_desc_spec requires key: dlti.device_type and its value of string type}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+ #dlti.dl_entry<"dlti.device_type", 0: i32>>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{repeated layout entry key: dlti.device_id}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+ #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.L1_cache_size", 4096 : i32>>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{repeated layout entry key: dlti.device_type}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.device_type", "GPU">,
+ #dlti.dl_entry<"dlti.L1_cache_size", 4096 : i32>>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{target_device_desc_spec requires value of key: dlti.L1_cache_size_in_bytes to be of ui32 type}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096.1 : f32>>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{target_device_desc_spec requires value of key: dlti.max_vector_op_width to be of ui32 type}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 4096.1 : f32>>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{unknown target device desc key name: dlti.L2_cache_size_in_bytes}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.L2_cache_size_in_bytes", 4096 : i32>>
+ >} {}
+
+// -----
+
+module attributes {
+ // expected-error at +2 {{unknown target device desc key name: dlti.unknown_key}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.unknown_key", 42>>
+ >} {}
+
+// -----
+
+module attributes {
+ // unexpected-error at below {{repeated Device ID in dlti.target_system_desc_spec: 0}}
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">>,
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">>
+ >} {}
+
\ No newline at end of file
diff --git a/mlir/test/Dialect/DLTI/roundtrip.mlir b/mlir/test/Dialect/DLTI/roundtrip.mlir
index 613dc354d895d..80330273b8de6 100644
--- a/mlir/test/Dialect/DLTI/roundtrip.mlir
+++ b/mlir/test/Dialect/DLTI/roundtrip.mlir
@@ -53,3 +53,29 @@
}) { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>> } : () -> ()
"test.maybe_terminator_op"() : () -> ()
}) { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown", 32>> } : () -> ()
+
+// A valid target system description
+// CHECK: module attributes {
+// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
+// CHECK-SAME: >} {
+// CHECK: }
+module attributes {
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 1: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "GPU">,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
+ >} {}
+
\ No newline at end of file
diff --git a/mlir/test/Dialect/DLTI/valid.mlir b/mlir/test/Dialect/DLTI/valid.mlir
new file mode 100644
index 0000000000000..8f44e9568bc5d
--- /dev/null
+++ b/mlir/test/Dialect/DLTI/valid.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK: module attributes {
+// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">>,
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">>
+// CHECK-SAME: >} {
+// CHECK: }
+module attributes {
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">>,
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 1: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "GPU">>
+ >} {}
+
+// -----
+
+// CHECK: module attributes {
+// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
+// CHECK-SAME: >} {
+// CHECK: }
+module attributes {
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 1: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "GPU">,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
+ >} {}
+
+// -----
+
+// CHECK: module attributes {
+// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : ui32>>
+// CHECK-SAME: >} {
+// CHECK: }
+module attributes {
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 1: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "GPU">,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192: ui32>>
+ >} {}
+
+// -----
+
+// CHECK: module attributes {
+// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
+// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
+// CHECK-SAME: >} {
+// CHECK: }
+module attributes {
+ dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64: ui32>>,
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 1: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "GPU">,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
+ >} {}
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 38f0ed0ed8da3..542e9753fe0b9 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -32,6 +32,9 @@ constexpr static llvm::StringLiteral kGlobalKeyName =
constexpr static llvm::StringLiteral kStackAlignmentKeyName =
"dltest.stack_alignment";
+constexpr static llvm::StringLiteral kTargetSystemDescAttrName =
+ "dl_target_sys_desc_test.target_system_desc_spec";
+
/// Trivial array storage for the custom data layout spec attribute, just a list
/// of entries.
class DataLayoutSpecStorage : public AttributeStorage {
@@ -91,6 +94,65 @@ struct CustomDataLayoutSpec
}
};
+class TargetSystemDescSpecStorage : public AttributeStorage {
+public:
+ using KeyTy = ArrayRef<TargetDeviceDescSpecInterface>;
+
+ TargetSystemDescSpecStorage(ArrayRef<TargetDeviceDescSpecInterface> entries)
+ : entries(entries) {}
+
+ bool operator==(const KeyTy &key) const { return key == entries; }
+
+ static TargetSystemDescSpecStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ return new (allocator.allocate<TargetSystemDescSpecStorage>())
+ TargetSystemDescSpecStorage(allocator.copyInto(key));
+ }
+
+ ArrayRef<TargetDeviceDescSpecInterface> entries;
+};
+
+struct CustomTargetSystemDescSpec
+ : public Attribute::AttrBase<CustomTargetSystemDescSpec, Attribute,
+ TargetSystemDescSpecStorage,
+ TargetSystemDescSpecInterface::Trait> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomTargetSystemDescSpec)
+
+ using Base::Base;
+
+ static constexpr StringLiteral name = "test.custom_target_system_desc_spec";
+
+ static CustomTargetSystemDescSpec
+ get(MLIRContext *ctx, ArrayRef<TargetDeviceDescSpecInterface> entries) {
+ return Base::get(ctx, entries);
+ }
+ TargetDeviceDescSpecListRef getEntries() const { return getImpl()->entries; }
+ LogicalResult verifySpec(Location loc) { return success(); }
+ std::optional<TargetDeviceDescSpecInterface>
+ getDeviceDescForDeviceID(uint32_t deviceID) {
+ for (TargetDeviceDescSpecInterface entry : getEntries()) {
+ if (entry.getDeviceID() == deviceID)
+ return entry;
+ }
+ return std::nullopt;
+ }
+ StringAttr getDeviceIDIdentifier() {
+ return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceIDKey);
+ }
+ StringAttr getDeviceTypeIdentifier() {
+ return Builder(getContext())
+ .getStringAttr(DLTIDialect::kTargetDeviceTypeKey);
+ }
+ StringAttr getMaxVectorOpWidthIdentifier() {
+ return Builder(getContext())
+ .getStringAttr(DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
+ }
+ StringAttr getL1CacheSizeInBytesIdentifier() {
+ return Builder(getContext())
+ .getStringAttr(DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
+ }
+};
+
/// A type subject to data layout that exits the program if it is queried more
/// than once. Handy to check if the cache works.
struct SingleQueryType
@@ -199,7 +261,7 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
TargetSystemDescSpecInterface getTargetSystemDescSpec() {
return getOperation()->getAttrOfType<TargetSystemDescSpecInterface>(
- kAttrName);
+ kTargetSystemDescAttrName);
}
static llvm::TypeSize getTypeSizeInBits(Type type,
@@ -251,7 +313,7 @@ struct OpWith7BitByte
TargetSystemDescSpecInterface getTargetSystemDescSpec() {
return getOperation()->getAttrOfType<TargetSystemDescSpecInterface>(
- kAttrName);
+ kTargetSystemDescAttrName);
}
// Bytes are assumed to be 7-bit here.
@@ -318,6 +380,47 @@ struct DLTestDialect : Dialect {
}
};
+struct DLTargetSystemDescTestDialect : Dialect {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTargetSystemDescTestDialect)
+
+ explicit DLTargetSystemDescTestDialect(MLIRContext *ctx)
+ : Dialect(getDialectNamespace(), ctx,
+ TypeID::get<DLTargetSystemDescTestDialect>()) {
+ ctx->getOrLoadDialect<DLTIDialect>();
+ addAttributes<CustomTargetSystemDescSpec>();
+ }
+ static StringRef getDialectNamespace() { return "dl_target_sys_desc_test"; }
+
+ void printAttribute(Attribute attr,
+ DialectAsmPrinter &printer) const override {
+ printer << "target_system_desc_spec<";
+ llvm::interleaveComma(cast<CustomTargetSystemDescSpec>(attr).getEntries(),
+ printer);
+ printer << ">";
+ }
+
+ Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
+ bool ok = succeeded(parser.parseKeyword("target_system_desc_spec")) &&
+ succeeded(parser.parseLess());
+ (void)ok;
+ assert(ok);
+ if (succeeded(parser.parseOptionalGreater()))
+ return CustomTargetSystemDescSpec::get(parser.getContext(), {});
+
+ SmallVector<TargetDeviceDescSpecInterface> entries;
+ ok = succeeded(parser.parseCommaSeparatedList([&]() {
+ entries.emplace_back();
+ ok = succeeded(parser.parseAttribute(entries.back()));
+ assert(ok);
+ return success();
+ }));
+ assert(ok);
+ ok = succeeded(parser.parseGreater());
+ assert(ok);
+ return CustomTargetSystemDescSpec::get(parser.getContext(), entries);
+ }
+};
+
} // namespace
TEST(DataLayout, FallbackDefault) {
@@ -377,6 +480,9 @@ TEST(DataLayout, NullSpec) {
EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
EXPECT_EQ(layout.getStackAlignment(), 0u);
+
+ EXPECT_EQ(layout.getL1CacheSizeInBytes(0 /* device ID*/), std::nullopt);
+ EXPECT_EQ(layout.getMaxVectorOpWidth(0 /* device ID*/), std::nullopt);
}
TEST(DataLayout, EmptySpec) {
@@ -408,6 +514,9 @@ TEST(DataLayout, EmptySpec) {
EXPECT_EQ(layout.getProgramMemorySpace(), Attribute());
EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
EXPECT_EQ(layout.getStackAlignment(), 0u);
+
+ EXPECT_EQ(layout.getL1CacheSizeInBytes(0 /* device ID*/), std::nullopt);
+ EXPECT_EQ(layout.getMaxVectorOpWidth(0 /* device ID*/), std::nullopt);
}
TEST(DataLayout, SpecWithEntries) {
@@ -459,6 +568,31 @@ TEST(DataLayout, SpecWithEntries) {
EXPECT_EQ(layout.getStackAlignment(), 128u);
}
+TEST(DataLayout, SpecWithTargetSystemDescEntries) {
+ const char *ir = R"MLIR(
+ module attributes { dl_target_sys_desc_test.target_system_desc_spec =
+ #dl_target_sys_desc_test.target_system_desc_spec<
+ #dlti.target_device_desc_spec<
+ #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
+ #dlti.dl_entry<"dlti.device_type", "CPU">,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>
+ >
+ > } {}
+ )MLIR";
+
+ DialectRegistry registry;
+ registry.insert<DLTIDialect, DLTargetSystemDescTestDialect>();
+ MLIRContext ctx(registry);
+
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ DataLayout layout(*module);
+ EXPECT_EQ(layout.getL1CacheSizeInBytes(0 /* device ID*/),
+ std::optional<uint32_t>(4096));
+ EXPECT_EQ(layout.getMaxVectorOpWidth(0 /* device ID*/),
+ std::optional<uint32_t>(128));
+}
+
TEST(DataLayout, Caching) {
const char *ir = R"MLIR(
"dltest.op_with_layout"() { dltest.layout = #dltest.spec<> } : () -> ()
>From 211002337138ec31a64b603a5cc2f4251b843639 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Thu, 30 May 2024 04:39:14 -0700
Subject: [PATCH 4/7] Addressing review comments
Mainly as follows:
- removing custom attribute parser/printer
- renaming SystemDescSpec to SystemSpec and TargetDeviceDescSpec to
TargetDeviceSpec
- using DeviceID as a type instead of uint32_t
- grammatical errors
---
mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td | 175 ++++++------------
mlir/include/mlir/Dialect/DLTI/DLTIBase.td | 6 +-
mlir/include/mlir/Dialect/DLTI/Traits.h | 6 +-
mlir/include/mlir/IR/BuiltinOps.td | 2 +-
.../mlir/Interfaces/DataLayoutInterfaces.h | 23 ++-
.../mlir/Interfaces/DataLayoutInterfaces.td | 50 ++---
mlir/lib/Dialect/DLTI/DLTI.cpp | 150 ++++++++-------
mlir/lib/Dialect/DLTI/Traits.cpp | 5 +-
mlir/lib/IR/BuiltinDialect.cpp | 5 +-
mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 29 ++-
mlir/test/Dialect/DLTI/invalid.mlir | 85 +++++----
mlir/test/Dialect/DLTI/roundtrip.mlir | 12 +-
mlir/test/Dialect/DLTI/valid.mlir | 48 ++---
.../Interfaces/DataLayoutInterfacesTest.cpp | 81 ++++----
14 files changed, 289 insertions(+), 388 deletions(-)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
index 7f915b7a015a6..7f45c4acd4164 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -28,7 +28,7 @@ def DataLayoutEntryTrait
def DLTI_DataLayoutEntryAttr :
DLTIAttr<"DataLayoutEntry", [DataLayoutEntryTrait]> {
let summary = [{
- An attribute to represent an entry of a data layout specification
+ An attribute to represent an entry of a data layout specification.
}];
let description = [{
A data layout entry attribute is a key-value pair where the key is a type or
@@ -38,7 +38,7 @@ def DLTI_DataLayoutEntryAttr :
let parameters = (ins
"DataLayoutEntryKey":$key, "Attribute":$value
);
- // We do not generate storage class because llvm::PointerUnion
+ // TODO: We do not generate storage class because llvm::PointerUnion
// does not work with hash_key method.
let genStorageClass = 0;
let mnemonic = "dl_entry";
@@ -62,7 +62,9 @@ def DataLayoutSpecTrait
def DLTI_DataLayoutSpecAttr :
DLTIAttr<"DataLayoutSpec", [DataLayoutSpecTrait]> {
- let summary = [{An attribute to represent a data layout specification}];
+ let summary = [{
+ An attribute to represent a data layout specification.
+ }];
let description = [{
A data layout specification is a list of entries that specify (partial) data
layout information. It is expected to be attached to operations that serve
@@ -96,84 +98,57 @@ def DLTI_DataLayoutSpecAttr :
/// Returns the stack alignment identifier.
StringAttr getStackAlignmentIdentifier(MLIRContext *context) const;
}];
- let extraClassDefinition = [{
- StringAttr
- $cppClass::getEndiannessIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey);
- }
-
- StringAttr
- $cppClass::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
- }
-
- StringAttr $cppClass::getProgramMemorySpaceIdentifier(
- MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutProgramMemorySpaceKey);
- }
-
- StringAttr
- $cppClass::getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutGlobalMemorySpaceKey);
- }
-
- StringAttr
- $cppClass::getStackAlignmentIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutStackAlignmentKey);
- }
- }];
}
//===----------------------------------------------------------------------===//
-// TargetSystemDescSpecAttr
+// TargetSystemSpecAttr
//===----------------------------------------------------------------------===//
-def TargetSystemDescSpecTrait
- : NativeAttrTrait<"TargetSystemDescSpecInterface::Trait"> {
+def TargetSystemSpecTrait
+ : NativeAttrTrait<"TargetSystemSpecInterface::Trait"> {
let cppNamespace = "::mlir";
}
-def DLTI_TargetSystemDescSpecAttr :
- DLTIAttr<"TargetSystemDescSpec", [TargetSystemDescSpecTrait]> {
- let summary = [{An attribute to represent target system description}];
+def DLTI_TargetSystemSpecAttr :
+ DLTIAttr<"TargetSystemSpec", [TargetSystemSpecTrait]> {
+ let summary = [{
+ An attribute to represent target system specification.
+ }];
let description = [{
- A system description specification describes the overall system
- containing multiple devices, with each device having a unique ID
- and its corresponding TargetDeviceDescSpec object.
+ A system specification describes the overall system containing
+ multiple devices, with each device having a unique ID
+ and its corresponding TargetDeviceSpec object.
Example:
- dlti.target_system_desc_spec =
- #dlti.target_device_desc_spec<
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">>,
- #dlti.target_device_desc_spec <
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
#dlti.dl_entry<"dlti.device_type", "GPU">,
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
- #dlti.target_device_desc_spec <
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 2: ui32>,
#dlti.dl_entry<"dlti.device_type", "XPU">>>
}];
let parameters = (ins
- ArrayRefParameter<"TargetDeviceDescSpecInterface", "">:$entries
+ ArrayRefParameter<"TargetDeviceSpecInterface", "">:$entries
);
- let mnemonic = "target_system_desc_spec";
+ let mnemonic = "target_system_spec";
let genVerifyDecl = 1;
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "`<` $entries `>`";
let extraClassDeclaration = [{
- /// Return the device descriptor that matches the given device ID
- std::optional<TargetDeviceDescSpecInterface>
- getDeviceDescForDeviceID(uint32_t deviceID);
+ /// Return the device specification that matches the given device ID
+ std::optional<TargetDeviceSpecInterface>
+ getDeviceSpecForDeviceID(
+ TargetDeviceSpecInterface::DeviceID deviceID);
}];
let extraClassDefinition = [{
- std::optional<TargetDeviceDescSpecInterface>
- $cppClass::getDeviceDescForDeviceID(
- TargetDeviceDescSpecInterface::DeviceID deviceID) {
- for (TargetDeviceDescSpecInterface entry : getEntries()) {
+ std::optional<TargetDeviceSpecInterface>
+ $cppClass::getDeviceSpecForDeviceID(
+ TargetDeviceSpecInterface::DeviceID deviceID) {
+ for (TargetDeviceSpecInterface entry : getEntries()) {
if (entry.getDeviceID() == deviceID)
return entry;
}
@@ -183,35 +158,37 @@ def DLTI_TargetSystemDescSpecAttr :
}
//===----------------------------------------------------------------------===//
-// TargetDeviceDescSpecAttr
+// TargetDeviceSpecAttr
//===----------------------------------------------------------------------===//
-def TargetDeviceDescSpecTrait
- : NativeAttrTrait<"TargetDeviceDescSpecInterface::Trait"> {
+def TargetDeviceSpecTrait
+ : NativeAttrTrait<"TargetDeviceSpecInterface::Trait"> {
let cppNamespace = "::mlir";
}
-def DLTI_TargetDeviceDescSpecAttr :
- DLTIAttr<"TargetDeviceDescSpec", [TargetDeviceDescSpecTrait]> {
- let summary = [{An attribute to represent target device description}];
+def DLTI_TargetDeviceSpecAttr :
+ DLTIAttr<"TargetDeviceSpec", [TargetDeviceSpecTrait]> {
+ let summary = [{
+ An attribute to represent target device specification.
+ }];
let description = [{
- Each device description specification describes a single device and
- its hardware properties. Each device description must have a device_id
- and a device_type. In addition, the description can contain any number
- of optional hardware properties (e.g., max_vector_op_width below).
+ Each device specification describes a single device and its
+ hardware properties. Each device specification must have a device_id
+ and a device_type. In addition, the specification can contain any number
+ of optional hardware properties (e.g., max_vector_op_width below).
- Example:
- #dlti.target_device_desc_spec <
- #dlti.dl_entry<"dlti.device_id", 1: ui32>,
- #dlti.dl_entry<"dlti.device_type", "GPU">,
- #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
- }];
+ Example:
+ #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.device_id", 1: ui32>,
+ #dlti.dl_entry<"dlti.device_type", "GPU">,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
+ }];
let parameters = (ins
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
);
- let mnemonic = "target_device_desc_spec";
+ let mnemonic = "target_device_spec";
let genVerifyDecl = 1;
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "`<` $entries `>`";
let extraClassDeclaration = [{
/// Returns the device ID identifier.
StringAttr getDeviceIDIdentifier();
@@ -248,55 +225,7 @@ def DLTI_TargetDeviceDescSpecAttr :
DataLayoutEntryInterface getSpecForL1CacheSizeInBytes();
/// Return the value of device ID
- uint32_t getDeviceID();
- }];
-
- let extraClassDefinition = [{
- StringAttr
- $cppClass::getDeviceIDIdentifier() {
- return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceIDKey);
- }
-
- StringAttr
- $cppClass::getDeviceTypeIdentifier() {
- return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceTypeKey);
- }
-
- StringAttr
- $cppClass::getMaxVectorOpWidthIdentifier() {
- return Builder(getContext()).getStringAttr(
- DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
- }
-
- StringAttr $cppClass::getL1CacheSizeInBytesIdentifier() {
- return Builder(getContext()).getStringAttr(
- DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
- }
-
- DataLayoutEntryInterface
- $cppClass::getSpecForDeviceID() {
- return getSpecForIdentifier(getDeviceIDIdentifier());
- }
-
- DataLayoutEntryInterface
- $cppClass::getSpecForDeviceType() {
- return getSpecForIdentifier(getDeviceTypeIdentifier());
- }
-
- DataLayoutEntryInterface
- $cppClass::getSpecForMaxVectorOpWidth() {
- return getSpecForIdentifier(getMaxVectorOpWidthIdentifier());
- }
-
- DataLayoutEntryInterface
- $cppClass::getSpecForL1CacheSizeInBytes() {
- return getSpecForIdentifier(getL1CacheSizeInBytesIdentifier());
- }
-
- uint32_t $cppClass::getDeviceID() {
- DataLayoutEntryInterface entry = getSpecForDeviceID();
- return llvm::cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
- }
+ TargetDeviceSpecInterface::DeviceID getDeviceID();
}];
}
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
index 61a3d2553d20e..c4e37db232ddb 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
@@ -29,10 +29,10 @@ def DLTI_Dialect : Dialect {
// Top level attribute name for target system description
constexpr const static ::llvm::StringLiteral
- kTargetSystemDescAttrName = "dlti.target_system_desc_spec";
+ kTargetSystemDescAttrName = "dlti.target_system_spec";
constexpr const static ::llvm::StringLiteral
- kTargetDeviceDescAttrName = "dlti.target_device_desc_spec";
+ kTargetDeviceDescAttrName = "dlti.target_device_spec";
// Constants used in entries.
constexpr const static ::llvm::StringLiteral
@@ -56,7 +56,7 @@ def DLTI_Dialect : Dialect {
constexpr const static ::llvm::StringLiteral
kDataLayoutStackAlignmentKey = "dlti.stack_alignment";
- // Constants used in target description part of DLTI
+ // Constants used in target description part of DLTI.
constexpr const static ::llvm::StringLiteral
kTargetDeviceIDKey = "dlti.device_id";
diff --git a/mlir/include/mlir/Dialect/DLTI/Traits.h b/mlir/include/mlir/Dialect/DLTI/Traits.h
index 44083d54c4cad..edfbdffbd1ba1 100644
--- a/mlir/include/mlir/Dialect/DLTI/Traits.h
+++ b/mlir/include/mlir/Dialect/DLTI/Traits.h
@@ -18,7 +18,7 @@ class DataLayoutSpecAttr;
namespace impl {
LogicalResult verifyHasDefaultDLTIDataLayoutTrait(Operation *op);
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
-TargetSystemDescSpecInterface getTargetSystemDescSpec(Operation *op);
+TargetSystemSpecInterface getTargetSystemSpec(Operation *op);
} // namespace impl
/// Trait to be used by operations willing to use the implementation of the
@@ -41,8 +41,8 @@ class HasDefaultDLTIDataLayout
/// Returns the target system description specification as provided by DLTI
/// dialect
- TargetSystemDescSpecInterface getTargetSystemDescSpec() {
- return impl::getTargetSystemDescSpec(this->getOperation());
+ TargetSystemSpecInterface getTargetSystemSpec() {
+ return impl::getTargetSystemSpec(this->getOperation());
}
};
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index bdb4ce3ddfe20..56edd7519cd67 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -78,7 +78,7 @@ def ModuleOp : Builtin_Op<"module", [
//===------------------------------------------------------------------===//
DataLayoutSpecInterface getDataLayoutSpec();
- TargetSystemDescSpecInterface getTargetSystemDescSpec();
+ TargetSystemSpecInterface getTargetSystemSpec();
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 625ac2e9367dc..f5bf63a5b5c90 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -23,15 +23,14 @@
namespace mlir {
class DataLayout;
class DataLayoutEntryInterface;
-class TargetDeviceDescSpecInterface;
-class TargetSystemDescSpecInterface;
+class TargetDeviceSpecInterface;
+class TargetSystemSpecInterface;
using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
// Using explicit SmallVector size because we cannot infer the size from the
// forward declaration, and we need the typedef in the actual declaration.
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
-using TargetDeviceDescSpecListRef =
- llvm::ArrayRef<TargetDeviceDescSpecInterface>;
+using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
class DataLayoutOpInterface;
class DataLayoutSpecInterface;
class ModuleOp;
@@ -109,8 +108,8 @@ filterEntryForIdentifier(DataLayoutEntryListRef entries, StringAttr id);
/// Given a list of target device entries, returns the entry that has the given
/// identifier as key, if such an entry exists in the list.
-TargetDeviceDescSpecInterface
-filterEntryForIdentifier(TargetDeviceDescSpecListRef entries, StringAttr id);
+TargetDeviceSpecInterface
+filterEntryForIdentifier(TargetDeviceSpecListRef entries, StringAttr id);
/// Verifies that the operation implementing the data layout interface, or a
/// module operation, is valid. This calls the verifier of the spec attribute
@@ -126,8 +125,8 @@ LogicalResult verifyDataLayoutSpec(DataLayoutSpecInterface spec, Location loc);
/// Verifies that a target system desc spec is valid. This dispatches to
/// individual entry verifiers, and then to the verifiers implemented by the
/// relevant dialect interfaces for identifier keys.
-LogicalResult verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
- Location loc);
+LogicalResult verifyTargetSystemSpec(TargetSystemSpecInterface spec,
+ Location loc);
/// Divides the known min value of the numerator by the denominator and rounds
/// the result up to the next integer. Preserves the scalable flag.
@@ -162,7 +161,7 @@ class DataLayoutDialectInterface
/// Checks whether the given data layout entry is valid and reports any errors
/// at the provided location. Derived classes should override this.
- virtual LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
+ virtual LogicalResult verifyEntry(TargetDeviceSpecInterface entry,
Location loc) const {
return success();
}
@@ -247,19 +246,19 @@ class DataLayout {
/// Returns for max vector op width if the property is defined for the given
/// device ID, otherwise return std::nullopt.
std::optional<uint32_t>
- getMaxVectorOpWidth(TargetDeviceDescSpecInterface::DeviceID) const;
+ getMaxVectorOpWidth(TargetDeviceSpecInterface::DeviceID) const;
/// Returns for L1 cache size if the property is defined for the given
/// device ID, otherwise return std::nullopt.
std::optional<uint32_t>
- getL1CacheSizeInBytes(TargetDeviceDescSpecInterface::DeviceID) const;
+ getL1CacheSizeInBytes(TargetDeviceSpecInterface::DeviceID) const;
private:
/// Combined layout spec at the given scope.
const DataLayoutSpecInterface originalLayout;
/// Combined target system desc spec at the given scope.
- const TargetSystemDescSpecInterface originalTargetSystemDesc;
+ const TargetSystemSpecInterface originalTargetSystemDesc;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// List of enclosing layout specs.
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 5c431097c7e77..1d4ede62a337d 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -194,7 +194,7 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
}];
}
-def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface"> {
+def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
let cppNamespace = "::mlir";
let description = [{
@@ -238,30 +238,6 @@ def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{ return ::mlir::success(); }]
>,
- InterfaceMethod<
- /*description=*/"Returns the device ID identifier.",
- /*retTy=*/"::mlir::StringAttr",
- /*methodName=*/"getDeviceIDIdentifier",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*description=*/"Returns the device type identifier.",
- /*retTy=*/"::mlir::StringAttr",
- /*methodName=*/"getDeviceTypeIdentifier",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*description=*/"Returns the L1 cache size identifier.",
- /*retTy=*/"::mlir::StringAttr",
- /*methodName=*/"getL1CacheSizeInBytesIdentifier",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*description=*/"Returns the max vector op width identifier.",
- /*retTy=*/"::mlir::StringAttr",
- /*methodName=*/"getMaxVectorOpWidthIdentifier",
- /*args=*/(ins)
- >,
InterfaceMethod<
/*description=*/"Returns the entry related to Device ID. The function"
"will crash if the entry is missing.",
@@ -270,7 +246,7 @@ def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface
/*args=*/(ins)
>,
InterfaceMethod<
- /*description=*/"Returns the entry related to the given identifier. "
+ /*description=*/"Returns the entry related to Device Type. "
"The function will crash if the entry is missing.",
/*retTy=*/"::mlir::DataLayoutEntryInterface",
/*methodName=*/"getSpecForDeviceType",
@@ -304,7 +280,7 @@ def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface
}];
}
-def TargetSystemDescSpecInterface : AttrInterface<"TargetSystemDescSpecInterface"> {
+def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
let cppNamespace = "::mlir";
let description = [{
@@ -324,27 +300,26 @@ def TargetSystemDescSpecInterface : AttrInterface<"TargetSystemDescSpecInterface
let methods = [
InterfaceMethod<
/*description=*/"Returns the list of layout entries.",
- /*retTy=*/"llvm::ArrayRef<::mlir::TargetDeviceDescSpecInterface>",
+ /*retTy=*/"llvm::ArrayRef<::mlir::TargetDeviceSpecInterface>",
/*methodName=*/"getEntries",
/*args=*/(ins)
>,
InterfaceMethod<
/*description=*/"Returns the device description spec for given device "
"ID",
- /*retTy=*/"std::optional<::mlir::TargetDeviceDescSpecInterface>",
- /*methodName=*/"getDeviceDescForDeviceID",
- /*args=*/(ins "int":$deviceID)
+ /*retTy=*/"std::optional<::mlir::TargetDeviceSpecInterface>",
+ /*methodName=*/"getDeviceSpecForDeviceID",
+ /*args=*/(ins "TargetDeviceSpecInterface::DeviceID":$deviceID)
>,
InterfaceMethod<
/*description=*/"Verifies the validity of the specification and "
- "reports "
- "any errors at the given location.",
+ "reports any errors at the given location.",
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"verifySpec",
/*args=*/(ins "::mlir::Location":$loc),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return ::mlir::detail::verifyTargetSystemDescSpec($_attr, loc);
+ return ::mlir::detail::verifyTargetSystemSpec($_attr, loc);
}]
>
];
@@ -385,10 +360,9 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
>,
InterfaceMethod<
/*description=*/"Returns the target system desc specification for this "
- "op, or "
- "null if it does not exist.",
- /*retTy=*/"::mlir::TargetSystemDescSpecInterface",
- /*methodName=*/"getTargetSystemDescSpec",
+ "op, or null if it does not exist.",
+ /*retTy=*/"::mlir::TargetSystemSpecInterface",
+ /*methodName=*/"getTargetSystemSpec",
/*args=*/(ins)
>,
StaticInterfaceMethod<
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index fe2b18b966c6b..23d40e7f4fd51 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -23,6 +23,9 @@ using namespace mlir;
#include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include <mlir/Dialect/DLTI/DLTIAttrs.cpp.inc>
+
#define DEBUG_TYPE "dlti"
//===----------------------------------------------------------------------===//
@@ -265,12 +268,12 @@ void DataLayoutSpecAttr::print(AsmPrinter &os) const {
}
//===----------------------------------------------------------------------===//
-// TargetDeviceDescSpecAttr
+// TargetDeviceSpecAttr
//===----------------------------------------------------------------------===//
LogicalResult
-TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<DataLayoutEntryInterface> entries) {
+TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<DataLayoutEntryInterface> entries) {
// Entries in tdd_spec can only have StringAttr as key. It does not support
// type as a key. Hence not reusing DataLayoutEntryInterface::verify.
bool targetDeviceIDKeyPresentAndValid = false;
@@ -280,7 +283,7 @@ TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
for (DataLayoutEntryInterface entry : entries) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return emitError()
- << "dlti.target_device_desc_spec does not allow type as a key: "
+ << "dlti.target_device_spec does not allow type as a key: "
<< type;
} else {
auto id = entry.getKey().get<StringAttr>();
@@ -306,18 +309,18 @@ TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr value =
llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
if (!value || !value.getType().isUnsignedInteger(32))
- return emitError() << "target_device_desc_spec requires value of key: "
+ return emitError() << "target_device_spec requires value of key: "
<< DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey
<< " to be of ui32 type";
} else if (entryName == DLTIDialect::kTargetDeviceMaxVectorOpWidthKey) {
IntegerAttr value =
llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
if (!value || !value.getType().isUnsignedInteger(32))
- return emitError() << "target_device_desc_spec requires value of key: "
+ return emitError() << "target_device_spec requires value of key: "
<< DLTIDialect::kTargetDeviceMaxVectorOpWidthKey
<< " to be of ui32 type";
} else {
- return emitError() << "unknown target device desc key name: "
+ return emitError() << "unknown target device spec key name: "
<< entryName;
}
}
@@ -325,12 +328,12 @@ TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// check that both DeviceID and DeviceType are present
// and are of correct type.
if (!targetDeviceIDKeyPresentAndValid) {
- return emitError() << "target_device_desc_spec requires key: "
+ return emitError() << "target_device_spec requires key: "
<< DLTIDialect::kTargetDeviceIDKey
<< " and its value of ui32 type";
}
if (!targetDeviceTypeKeyPresentAndValid) {
- return emitError() << "target_device_desc_spec requires key: "
+ return emitError() << "target_device_spec requires key: "
<< DLTIDialect::kTargetDeviceTypeKey
<< " and its value of string type";
}
@@ -338,94 +341,97 @@ TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-/// Parses an attribute with syntax
-/// target_device_desc_spec_attr ::=
-/// `#target.` `target_device_desc_spec` `<` dl-entry-attr-list? `>`
-/// dl-entry-attr-list ::= dl-entry-attr
-/// | dl-entry-attr `,` dl-entry-attr-list
-Attribute TargetDeviceDescSpecAttr::parse(AsmParser &parser, Type type) {
- if (failed(parser.parseLess()))
- return {};
+StringAttr TargetDeviceSpecAttr::getDeviceIDIdentifier() {
+ return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceIDKey);
+}
- // Empty spec.
- if (succeeded(parser.parseOptionalGreater()))
- return get(parser.getContext(), {});
+StringAttr TargetDeviceSpecAttr::getDeviceTypeIdentifier() {
+ return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceTypeKey);
+}
- SmallVector<DataLayoutEntryInterface> entries;
- if (parser.parseCommaSeparatedList(
- [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
- parser.parseGreater())
- return {};
+StringAttr TargetDeviceSpecAttr::getMaxVectorOpWidthIdentifier() {
+ return Builder(getContext())
+ .getStringAttr(DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
+}
- return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
- parser.getContext(), entries);
+StringAttr TargetDeviceSpecAttr::getL1CacheSizeInBytesIdentifier() {
+ return Builder(getContext())
+ .getStringAttr(DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
}
-void TargetDeviceDescSpecAttr::print(AsmPrinter &os) const {
- os << "<";
- llvm::interleaveComma(getEntries(), os);
- os << ">";
+DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForDeviceID() {
+ return getSpecForIdentifier(getDeviceIDIdentifier());
+}
+
+DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForDeviceType() {
+ return getSpecForIdentifier(getDeviceTypeIdentifier());
+}
+
+DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForMaxVectorOpWidth() {
+ return getSpecForIdentifier(getMaxVectorOpWidthIdentifier());
+}
+
+DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForL1CacheSizeInBytes() {
+ return getSpecForIdentifier(getL1CacheSizeInBytesIdentifier());
+}
+
+TargetDeviceSpecInterface::DeviceID TargetDeviceSpecAttr::getDeviceID() {
+ DataLayoutEntryInterface entry = getSpecForDeviceID();
+ return llvm::cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
}
//===----------------------------------------------------------------------===//
-// TargetSystemDescSpecAttr
+// TargetSystemSpecAttr
//===----------------------------------------------------------------------===//
-LogicalResult TargetSystemDescSpecAttr::verify(
- function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<TargetDeviceDescSpecInterface> entries) {
+LogicalResult
+TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<TargetDeviceSpecInterface> entries) {
DenseSet<uint32_t> device_ids;
- for (TargetDeviceDescSpecInterface tdd_spec : entries) {
- // First verify that a target device desc spec is valid.
- if (failed(
- TargetDeviceDescSpecAttr::verify(emitError, tdd_spec.getEntries())))
+ for (TargetDeviceSpecInterface tdd_spec : entries) {
+ // First verify that a target device spec is valid.
+ if (failed(TargetDeviceSpecAttr::verify(emitError, tdd_spec.getEntries())))
return failure();
// Check that device IDs are unique across all entries.
uint32_t device_id = tdd_spec.getDeviceID();
if (!device_ids.insert(device_id).second) {
- return emitError()
- << "repeated Device ID in dlti.target_system_desc_spec: "
- << device_id;
+ return emitError() << "repeated Device ID in dlti.target_system_spec: "
+ << device_id;
}
}
return success();
}
-/// Parses an attribute with syntax
-/// attr ::=
-/// `#target.` `target_system_desc_spec` `<`
-/// target-device-desc-spec-attr-list? `>`
-/// target-device-desc-spec-attr-list ::= target_device_desc_spec
-/// | target_device_desc_spec `,`
-/// target-device-desc-spec-attr-list
-Attribute TargetSystemDescSpecAttr::parse(AsmParser &parser, Type type) {
- if (failed(parser.parseLess()))
- return {};
-
- // Empty spec.
- if (succeeded(parser.parseOptionalGreater()))
- return get(parser.getContext(), {});
+StringAttr
+DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey);
+}
- SmallVector<TargetDeviceDescSpecInterface> entries;
- if (parser.parseCommaSeparatedList(
- [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
- parser.parseGreater())
- return {};
+StringAttr
+DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
+}
- return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
- parser.getContext(), entries);
+StringAttr DataLayoutSpecAttr::getProgramMemorySpaceIdentifier(
+ MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutProgramMemorySpaceKey);
}
-void TargetSystemDescSpecAttr::print(AsmPrinter &os) const {
- os << "<";
- llvm::interleaveComma(getEntries(), os);
- os << ">";
+StringAttr
+DataLayoutSpecAttr::getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutGlobalMemorySpaceKey);
}
-#define GET_ATTRDEF_CLASSES
-#include <mlir/Dialect/DLTI/DLTIAttrs.cpp.inc>
+StringAttr
+DataLayoutSpecAttr::getStackAlignmentIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutStackAlignmentKey);
+}
//===----------------------------------------------------------------------===//
// DLTIDialect
@@ -470,7 +476,7 @@ class SystemDescSpecInterface : public DataLayoutDialectInterface {
public:
using DataLayoutDialectInterface::DataLayoutDialectInterface;
- LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
+ LogicalResult verifyEntry(TargetDeviceSpecInterface entry,
Location loc) const final {
for (DataLayoutEntryInterface dl_entry : entry.getEntries()) {
@@ -507,10 +513,10 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
return detail::verifyDataLayoutOp(op);
return success();
} else if (attr.getName() == DLTIDialect::kTargetSystemDescAttrName) {
- if (!llvm::isa<TargetSystemDescSpecAttr>(attr.getValue())) {
+ if (!llvm::isa<TargetSystemSpecAttr>(attr.getValue())) {
return op->emitError()
<< "'" << DLTIDialect::kTargetSystemDescAttrName
- << "' is expected to be a #dlti.target_system_desc_spec attribute";
+ << "' is expected to be a #dlti.target_system_spec attribute";
}
return success();
}
diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp
index ead656774a27c..34f2dd5896083 100644
--- a/mlir/lib/Dialect/DLTI/Traits.cpp
+++ b/mlir/lib/Dialect/DLTI/Traits.cpp
@@ -28,8 +28,7 @@ DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) {
DLTIDialect::kDataLayoutAttrName);
}
-TargetSystemDescSpecInterface
-mlir::impl::getTargetSystemDescSpec(Operation *op) {
- return op->getAttrOfType<TargetSystemDescSpecAttr>(
+TargetSystemSpecInterface mlir::impl::getTargetSystemSpec(Operation *op) {
+ return op->getAttrOfType<TargetSystemSpecAttr>(
DLTIDialect::kTargetSystemDescAttrName);
}
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 1d57e0bdef187..99796c5f1c371 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -155,13 +155,12 @@ DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
return {};
}
-TargetSystemDescSpecInterface ModuleOp::getTargetSystemDescSpec() {
+TargetSystemSpecInterface ModuleOp::getTargetSystemSpec() {
// Take the first and only (if present) attribute that implements the
// interface. This needs a linear search, but is called only once per data
// layout object construction that is used for repeated queries.
for (NamedAttribute attr : getOperation()->getAttrs())
- if (auto spec =
- llvm::dyn_cast<TargetSystemDescSpecInterface>(attr.getValue()))
+ if (auto spec = llvm::dyn_cast<TargetSystemSpecInterface>(attr.getValue()))
return spec;
return {};
}
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 19a662b3793ba..3456b804b9800 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -346,8 +346,7 @@ static DataLayoutSpecInterface getSpec(Operation *operation) {
});
}
-static TargetSystemDescSpecInterface
-getTargetSystemDescSpec(Operation *operation) {
+static TargetSystemSpecInterface getTargetSystemSpec(Operation *operation) {
if (operation) {
ModuleOp moduleOp;
if (isa<ModuleOp>(operation)) {
@@ -355,9 +354,9 @@ getTargetSystemDescSpec(Operation *operation) {
} else {
moduleOp = operation->getParentOfType<ModuleOp>();
}
- return moduleOp.getTargetSystemDescSpec();
+ return moduleOp.getTargetSystemSpec();
} else
- return TargetSystemDescSpecInterface();
+ return TargetSystemSpecInterface();
}
/// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that
@@ -472,7 +471,7 @@ mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
: originalLayout(getCombinedDataLayout(op)), scope(op),
allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
- originalTargetSystemDesc(getTargetSystemDescSpec(op)) {
+ originalTargetSystemDesc(getTargetSystemSpec(op)) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
checkMissingLayout(originalLayout, op);
collectParentLayouts(op, layoutStack);
@@ -483,7 +482,7 @@ mlir::DataLayout::DataLayout(ModuleOp op)
: originalLayout(getCombinedDataLayout(op)), scope(op),
allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
- originalTargetSystemDesc(getTargetSystemDescSpec(op)) {
+ originalTargetSystemDesc(getTargetSystemSpec(op)) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
checkMissingLayout(originalLayout, op);
collectParentLayouts(op, layoutStack);
@@ -679,12 +678,12 @@ uint64_t mlir::DataLayout::getStackAlignment() const {
}
std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
- TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+ TargetDeviceSpecInterface::DeviceID deviceID) const {
checkValid();
DataLayoutEntryInterface entry;
if (originalTargetSystemDesc) {
if (auto device =
- originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID))
+ originalTargetSystemDesc.getDeviceSpecForDeviceID(deviceID))
entry = device->getSpecForMaxVectorOpWidth();
}
// Currently I am not caching the results because we do not return
@@ -698,12 +697,12 @@ std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
}
std::optional<uint32_t> mlir::DataLayout::getL1CacheSizeInBytes(
- TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+ TargetDeviceSpecInterface::DeviceID deviceID) const {
checkValid();
DataLayoutEntryInterface entry;
if (originalTargetSystemDesc) {
if (auto device =
- originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID))
+ originalTargetSystemDesc.getDeviceSpecForDeviceID(deviceID))
entry = device->getSpecForL1CacheSizeInBytes();
}
// Currently I am not caching the results because we do not return
@@ -821,18 +820,18 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
}
LogicalResult
-mlir::detail::verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
- Location loc) {
+mlir::detail::verifyTargetSystemSpec(TargetSystemSpecInterface spec,
+ Location loc) {
DenseMap<StringAttr, DataLayoutEntryInterface> device_desc_keys;
- DenseSet<uint32_t> device_ids;
- for (TargetDeviceDescSpecInterface tdd_spec : spec.getEntries()) {
+ DenseSet<TargetDeviceSpecInterface::DeviceID> device_ids;
+ for (TargetDeviceSpecInterface tdd_spec : spec.getEntries()) {
// First, verify individual target device desc specs.
if (failed(tdd_spec.verifyEntry(loc)))
return failure();
// Check that device IDs are unique across all entries.
MLIRContext *context = tdd_spec.getContext();
- uint32_t device_id = tdd_spec.getDeviceID();
+ TargetDeviceSpecInterface::DeviceID device_id = tdd_spec.getDeviceID();
if (!device_ids.insert(device_id).second) {
return failure();
}
diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir
index 7810c659f3158..28be8d37a50bf 100644
--- a/mlir/test/Dialect/DLTI/invalid.mlir
+++ b/mlir/test/Dialect/DLTI/invalid.mlir
@@ -93,47 +93,52 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown
// -----
-// expected-error at below {{'dlti.target_system_desc_spec' is expected to be a #dlti.target_system_desc_spec attribute}}
-"test.unknown_op"() { dlti.target_system_desc_spec = 42 } : () -> ()
+// expected-error at below {{'dlti.target_system_spec' is expected to be a #dlti.target_system_spec attribute}}
+"test.unknown_op"() { dlti.target_system_spec = 42 } : () -> ()
// -----
// expected-error at below {{invalid kind of attribute specified}}
-"test.unknown_op"() { dlti.target_system_desc_spec = #dlti.target_system_desc_spec<[]> } : () -> ()
+// expected-error at below {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+"test.unknown_op"() { dlti.target_system_spec = #dlti.target_system_spec<[]> } : () -> ()
// -----
module attributes {
- // expected-error at +2 {{target_device_desc_spec requires key: dlti.device_id and its value of ui32 type}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{target_device_spec requires key: dlti.device_id and its value of ui32 type}}
+ // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_type", "CPU">>
>} {}
// -----
module attributes {
- // expected-error at +2 {{target_device_desc_spec requires key: dlti.device_type and its value of string type}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{target_device_spec requires key: dlti.device_type and its value of string type}}
+ // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>>
>} {}
// -----
module attributes {
- // expected-error at +2 {{target_device_desc_spec requires key: dlti.device_id and its value of ui32 type}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{target_device_spec requires key: dlti.device_id and its value of ui32 type}}
+ // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: i32>>
>} {}
// -----
module attributes {
- // expected-error at +2 {{target_device_desc_spec requires key: dlti.device_type and its value of string type}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{target_device_spec requires key: dlti.device_type and its value of string type}}
+ // expected-error at +5 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0 : ui32>,
#dlti.dl_entry<"dlti.device_type", 0: i32>>
>} {}
@@ -141,9 +146,10 @@ module attributes {
// -----
module attributes {
- // expected-error at +2 {{repeated layout entry key: dlti.device_id}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{repeated layout entry key: dlti.device_id}}
+ // expected-error at +7 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0 : ui32>,
#dlti.dl_entry<"dlti.device_id", 1 : ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
@@ -153,9 +159,10 @@ module attributes {
// -----
module attributes {
- // expected-error at +2 {{repeated layout entry key: dlti.device_type}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{repeated layout entry key: dlti.device_type}}
+ // expected-error at +7 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0 : ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.device_type", "GPU">,
@@ -165,9 +172,10 @@ module attributes {
// -----
module attributes {
- // expected-error at +2 {{target_device_desc_spec requires value of key: dlti.L1_cache_size_in_bytes to be of ui32 type}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{target_device_spec requires value of key: dlti.L1_cache_size_in_bytes to be of ui32 type}}
+ // expected-error at +6 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0 : ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096.1 : f32>>
@@ -176,9 +184,10 @@ module attributes {
// -----
module attributes {
- // expected-error at +2 {{target_device_desc_spec requires value of key: dlti.max_vector_op_width to be of ui32 type}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{target_device_spec requires value of key: dlti.max_vector_op_width to be of ui32 type}}
+ // expected-error at +6 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0 : ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.max_vector_op_width", 4096.1 : f32>>
@@ -187,9 +196,10 @@ module attributes {
// -----
module attributes {
- // expected-error at +2 {{unknown target device desc key name: dlti.L2_cache_size_in_bytes}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{unknown target device spec key name: dlti.L2_cache_size_in_bytes}}
+ // expected-error at +6 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0 : ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.L2_cache_size_in_bytes", 4096 : i32>>
@@ -198,9 +208,10 @@ module attributes {
// -----
module attributes {
- // expected-error at +2 {{unknown target device desc key name: dlti.unknown_key}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // expected-error at +3 {{unknown target device spec key name: dlti.unknown_key}}
+ // expected-error at +6 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.unknown_key", 42>>
@@ -209,12 +220,12 @@ module attributes {
// -----
module attributes {
- // unexpected-error at below {{repeated Device ID in dlti.target_system_desc_spec: 0}}
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ // unexpected-error at below {{repeated Device ID in dlti.target_system_spec: 0}}
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">>,
- #dlti.target_device_desc_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">>
>} {}
diff --git a/mlir/test/Dialect/DLTI/roundtrip.mlir b/mlir/test/Dialect/DLTI/roundtrip.mlir
index 80330273b8de6..ccd80cda6f75f 100644
--- a/mlir/test/Dialect/DLTI/roundtrip.mlir
+++ b/mlir/test/Dialect/DLTI/roundtrip.mlir
@@ -56,24 +56,24 @@
// A valid target system description
// CHECK: module attributes {
-// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
- #dlti.target_device_desc_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
#dlti.dl_entry<"dlti.device_type", "GPU">,
#dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
diff --git a/mlir/test/Dialect/DLTI/valid.mlir b/mlir/test/Dialect/DLTI/valid.mlir
index 8f44e9568bc5d..dff1434a26f48 100644
--- a/mlir/test/Dialect/DLTI/valid.mlir
+++ b/mlir/test/Dialect/DLTI/valid.mlir
@@ -3,21 +3,21 @@
// -----
// CHECK: module attributes {
-// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">>,
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">>,
- #dlti.target_device_desc_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
#dlti.dl_entry<"dlti.device_type", "GPU">>
>} {}
@@ -25,24 +25,24 @@ module attributes {
// -----
// CHECK: module attributes {
-// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
- #dlti.target_device_desc_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
#dlti.dl_entry<"dlti.device_type", "GPU">,
#dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
@@ -51,24 +51,24 @@ module attributes {
// -----
// CHECK: module attributes {
-// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : ui32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
- #dlti.target_device_desc_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
#dlti.dl_entry<"dlti.device_type", "GPU">,
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192: ui32>>
@@ -77,24 +77,24 @@ module attributes {
// -----
// CHECK: module attributes {
-// CHECK-SAME: dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
-// CHECK-SAME: #dlti.target_device_desc_spec<
+// CHECK-SAME: #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
- dlti.target_system_desc_spec = #dlti.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ dlti.target_system_spec = #dlti.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.max_vector_op_width", 64: ui32>>,
- #dlti.target_device_desc_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
#dlti.dl_entry<"dlti.device_type", "GPU">,
#dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 542e9753fe0b9..a0ba54f32a9e3 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -33,7 +33,7 @@ constexpr static llvm::StringLiteral kStackAlignmentKeyName =
"dltest.stack_alignment";
constexpr static llvm::StringLiteral kTargetSystemDescAttrName =
- "dl_target_sys_desc_test.target_system_desc_spec";
+ "dl_target_sys_desc_test.target_system_spec";
/// Trivial array storage for the custom data layout spec attribute, just a list
/// of entries.
@@ -94,63 +94,48 @@ struct CustomDataLayoutSpec
}
};
-class TargetSystemDescSpecStorage : public AttributeStorage {
+class TargetSystemSpecStorage : public AttributeStorage {
public:
- using KeyTy = ArrayRef<TargetDeviceDescSpecInterface>;
+ using KeyTy = ArrayRef<TargetDeviceSpecInterface>;
- TargetSystemDescSpecStorage(ArrayRef<TargetDeviceDescSpecInterface> entries)
+ TargetSystemSpecStorage(ArrayRef<TargetDeviceSpecInterface> entries)
: entries(entries) {}
bool operator==(const KeyTy &key) const { return key == entries; }
- static TargetSystemDescSpecStorage *
+ static TargetSystemSpecStorage *
construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
- return new (allocator.allocate<TargetSystemDescSpecStorage>())
- TargetSystemDescSpecStorage(allocator.copyInto(key));
+ return new (allocator.allocate<TargetSystemSpecStorage>())
+ TargetSystemSpecStorage(allocator.copyInto(key));
}
- ArrayRef<TargetDeviceDescSpecInterface> entries;
+ ArrayRef<TargetDeviceSpecInterface> entries;
};
-struct CustomTargetSystemDescSpec
- : public Attribute::AttrBase<CustomTargetSystemDescSpec, Attribute,
- TargetSystemDescSpecStorage,
- TargetSystemDescSpecInterface::Trait> {
+struct CustomTargetSystemSpec
+ : public Attribute::AttrBase<CustomTargetSystemSpec, Attribute,
+ TargetSystemSpecStorage,
+ TargetSystemSpecInterface::Trait> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec)
using Base::Base;
- static constexpr StringLiteral name = "test.custom_target_system_desc_spec";
+ static constexpr StringLiteral name = "test.custom_target_system_spec";
- static CustomTargetSystemDescSpec
- get(MLIRContext *ctx, ArrayRef<TargetDeviceDescSpecInterface> entries) {
+ static CustomTargetSystemSpec
+ get(MLIRContext *ctx, ArrayRef<TargetDeviceSpecInterface> entries) {
return Base::get(ctx, entries);
}
- TargetDeviceDescSpecListRef getEntries() const { return getImpl()->entries; }
+ TargetDeviceSpecListRef getEntries() const { return getImpl()->entries; }
LogicalResult verifySpec(Location loc) { return success(); }
- std::optional<TargetDeviceDescSpecInterface>
- getDeviceDescForDeviceID(uint32_t deviceID) {
- for (TargetDeviceDescSpecInterface entry : getEntries()) {
+ std::optional<TargetDeviceSpecInterface>
+ getDeviceSpecForDeviceID(TargetDeviceSpecInterface::DeviceID deviceID) {
+ for (TargetDeviceSpecInterface entry : getEntries()) {
if (entry.getDeviceID() == deviceID)
return entry;
}
return std::nullopt;
}
- StringAttr getDeviceIDIdentifier() {
- return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceIDKey);
- }
- StringAttr getDeviceTypeIdentifier() {
- return Builder(getContext())
- .getStringAttr(DLTIDialect::kTargetDeviceTypeKey);
- }
- StringAttr getMaxVectorOpWidthIdentifier() {
- return Builder(getContext())
- .getStringAttr(DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
- }
- StringAttr getL1CacheSizeInBytesIdentifier() {
- return Builder(getContext())
- .getStringAttr(DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
- }
};
/// A type subject to data layout that exits the program if it is queried more
@@ -259,8 +244,8 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
}
- TargetSystemDescSpecInterface getTargetSystemDescSpec() {
- return getOperation()->getAttrOfType<TargetSystemDescSpecInterface>(
+ TargetSystemSpecInterface getTargetSystemSpec() {
+ return getOperation()->getAttrOfType<TargetSystemSpecInterface>(
kTargetSystemDescAttrName);
}
@@ -311,8 +296,8 @@ struct OpWith7BitByte
return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
}
- TargetSystemDescSpecInterface getTargetSystemDescSpec() {
- return getOperation()->getAttrOfType<TargetSystemDescSpecInterface>(
+ TargetSystemSpecInterface getTargetSystemSpec() {
+ return getOperation()->getAttrOfType<TargetSystemSpecInterface>(
kTargetSystemDescAttrName);
}
@@ -387,27 +372,27 @@ struct DLTargetSystemDescTestDialect : Dialect {
: Dialect(getDialectNamespace(), ctx,
TypeID::get<DLTargetSystemDescTestDialect>()) {
ctx->getOrLoadDialect<DLTIDialect>();
- addAttributes<CustomTargetSystemDescSpec>();
+ addAttributes<CustomTargetSystemSpec>();
}
static StringRef getDialectNamespace() { return "dl_target_sys_desc_test"; }
void printAttribute(Attribute attr,
DialectAsmPrinter &printer) const override {
- printer << "target_system_desc_spec<";
- llvm::interleaveComma(cast<CustomTargetSystemDescSpec>(attr).getEntries(),
+ printer << "target_system_spec<";
+ llvm::interleaveComma(cast<CustomTargetSystemSpec>(attr).getEntries(),
printer);
printer << ">";
}
Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
- bool ok = succeeded(parser.parseKeyword("target_system_desc_spec")) &&
+ bool ok = succeeded(parser.parseKeyword("target_system_spec")) &&
succeeded(parser.parseLess());
(void)ok;
assert(ok);
if (succeeded(parser.parseOptionalGreater()))
- return CustomTargetSystemDescSpec::get(parser.getContext(), {});
+ return CustomTargetSystemSpec::get(parser.getContext(), {});
- SmallVector<TargetDeviceDescSpecInterface> entries;
+ SmallVector<TargetDeviceSpecInterface> entries;
ok = succeeded(parser.parseCommaSeparatedList([&]() {
entries.emplace_back();
ok = succeeded(parser.parseAttribute(entries.back()));
@@ -417,7 +402,7 @@ struct DLTargetSystemDescTestDialect : Dialect {
assert(ok);
ok = succeeded(parser.parseGreater());
assert(ok);
- return CustomTargetSystemDescSpec::get(parser.getContext(), entries);
+ return CustomTargetSystemSpec::get(parser.getContext(), entries);
}
};
@@ -570,9 +555,9 @@ TEST(DataLayout, SpecWithEntries) {
TEST(DataLayout, SpecWithTargetSystemDescEntries) {
const char *ir = R"MLIR(
- module attributes { dl_target_sys_desc_test.target_system_desc_spec =
- #dl_target_sys_desc_test.target_system_desc_spec<
- #dlti.target_device_desc_spec<
+ module attributes { dl_target_sys_desc_test.target_system_spec =
+ #dl_target_sys_desc_test.target_system_spec<
+ #dlti.target_device_spec<
#dlti.dl_entry<"dlti.device_id", 0 : ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>,
>From 187e0c24f2ac7ae29b9db36e634ab54d14bb8a16 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Mon, 3 Jun 2024 23:15:59 -0700
Subject: [PATCH 5/7] Addressing review comments
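Represent TargetSystemSpec as a set of key-value pairs, where each key is a
device ID (a string) and each value is the TargetDeviceSpec for that device.
As a result, dlti.device_id and dlti.device_type are no longer entries
inside a TargetDeviceSpec.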
---
mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td | 57 ++----
mlir/include/mlir/Dialect/DLTI/DLTIBase.td | 6 -
.../mlir/Interfaces/DataLayoutInterfaces.h | 8 +-
.../mlir/Interfaces/DataLayoutInterfaces.td | 35 +---
mlir/lib/Dialect/DLTI/DLTI.cpp | 162 ++++++++----------
mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 11 +-
mlir/test/Dialect/DLTI/invalid.mlir | 141 ++++++---------
mlir/test/Dialect/DLTI/roundtrip.mlir | 20 +--
mlir/test/Dialect/DLTI/valid.mlir | 70 ++------
.../Interfaces/DataLayoutInterfacesTest.cpp | 89 +++++++---
10 files changed, 243 insertions(+), 356 deletions(-)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
index 7f45c4acd4164..0f55558dcca27 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -116,24 +116,21 @@ def DLTI_TargetSystemSpecAttr :
}];
let description = [{
A system specification describes the overall system containing
- multiple devices, with each device having a unique ID
+ multiple devices, with each device having a unique ID (string)
and its corresponding TargetDeviceSpec object.
Example:
dlti.target_system_spec =
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">>,
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 1: ui32>,
- #dlti.dl_entry<"dlti.device_type", "GPU">,
- #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 2: ui32>,
- #dlti.dl_entry<"dlti.device_type", "XPU">>>
+ #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
+ "GPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
+ "XPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
}];
let parameters = (ins
- ArrayRefParameter<"TargetDeviceSpecInterface", "">:$entries
+ ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
);
let mnemonic = "target_system_spec";
let genVerifyDecl = 1;
@@ -142,15 +139,15 @@ def DLTI_TargetSystemSpecAttr :
/// Return the device specification that matches the given device ID
std::optional<TargetDeviceSpecInterface>
getDeviceSpecForDeviceID(
- TargetDeviceSpecInterface::DeviceID deviceID);
+ TargetSystemSpecInterface::DeviceID deviceID);
}];
let extraClassDefinition = [{
std::optional<TargetDeviceSpecInterface>
$cppClass::getDeviceSpecForDeviceID(
- TargetDeviceSpecInterface::DeviceID deviceID) {
- for (TargetDeviceSpecInterface entry : getEntries()) {
- if (entry.getDeviceID() == deviceID)
- return entry;
+ TargetSystemSpecInterface::DeviceID deviceID) {
+ for (const auto& entry : getEntries()) {
+ if (entry.first == deviceID)
+ return entry.second;
}
return std::nullopt;
}
@@ -173,15 +170,12 @@ def DLTI_TargetDeviceSpecAttr :
}];
let description = [{
Each device specification describes a single device and its
- hardware properties. Each device specification must have a device_id
- and a device_type. In addition, the specification can contain any number
+ hardware properties. Each device specification can contain any number
of optional hardware properties (e.g., max_vector_op_width below).
Example:
#dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 1: ui32>,
- #dlti.dl_entry<"dlti.device_type", "GPU">,
- #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
}];
let parameters = (ins
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
@@ -190,28 +184,12 @@ def DLTI_TargetDeviceSpecAttr :
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
let extraClassDeclaration = [{
- /// Returns the device ID identifier.
- StringAttr getDeviceIDIdentifier();
-
- /// Returns the device type identifier.
- StringAttr getDeviceTypeIdentifier();
-
/// Returns max vector op width identifier.
StringAttr getMaxVectorOpWidthIdentifier();
/// Returns L1 cache size identifier
StringAttr getL1CacheSizeInBytesIdentifier();
- /// Returns the interface spec for device ID
- /// Since we verify that the spec contains device ID the function
- /// will return a valid spec.
- DataLayoutEntryInterface getSpecForDeviceID();
-
- /// Returns the interface spec for device type
- /// Since we verify that the spec contains device type the function
- /// will return a valid spec.
- DataLayoutEntryInterface getSpecForDeviceType();
-
/// Returns the interface spec for max vector op width
/// Since max vector op width is an optional property, this function will
/// return a valid spec if the property is defined, otherwise it
@@ -223,9 +201,6 @@ def DLTI_TargetDeviceSpecAttr :
/// return a valid spec if the property is defined, otherwise it
/// will return an empty spec.
DataLayoutEntryInterface getSpecForL1CacheSizeInBytes();
-
- /// Return the value of device ID
- TargetDeviceSpecInterface::DeviceID getDeviceID();
}];
}
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
index c4e37db232ddb..bf4f9f68c3f02 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
@@ -57,12 +57,6 @@ def DLTI_Dialect : Dialect {
kDataLayoutStackAlignmentKey = "dlti.stack_alignment";
// Constants used in target description part of DLTI.
- constexpr const static ::llvm::StringLiteral
- kTargetDeviceIDKey = "dlti.device_id";
-
- constexpr const static ::llvm::StringLiteral
- kTargetDeviceTypeKey = "dlti.device_type";
-
constexpr const static ::llvm::StringLiteral
kTargetDeviceMaxVectorOpWidthKey = "dlti.max_vector_op_width";
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index f5bf63a5b5c90..7c4de95806a76 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -31,6 +31,10 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
+using DeviceIDTargetDeviceSpecPair =
+ std::pair<StringAttr, TargetDeviceSpecInterface>;
+using DeviceIDTargetDeviceSpecPairListRef =
+ llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
class DataLayoutOpInterface;
class DataLayoutSpecInterface;
class ModuleOp;
@@ -246,12 +250,12 @@ class DataLayout {
/// Returns for max vector op width if the property is defined for the given
/// device ID, otherwise return std::nullopt.
std::optional<uint32_t>
- getMaxVectorOpWidth(TargetDeviceSpecInterface::DeviceID) const;
+ getMaxVectorOpWidth(TargetSystemSpecInterface::DeviceID) const;
/// Returns for L1 cache size if the property is defined for the given
/// device ID, otherwise return std::nullopt.
std::optional<uint32_t>
- getL1CacheSizeInBytes(TargetDeviceSpecInterface::DeviceID) const;
+ getL1CacheSizeInBytes(TargetSystemSpecInterface::DeviceID) const;
private:
/// Combined layout spec at the given scope.
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 1d4ede62a337d..4a4e135077737 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -238,20 +238,6 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{ return ::mlir::success(); }]
>,
- InterfaceMethod<
- /*description=*/"Returns the entry related to Device ID. The function"
- "will crash if the entry is missing.",
- /*retTy=*/"::mlir::DataLayoutEntryInterface",
- /*methodName=*/"getSpecForDeviceID",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*description=*/"Returns the entry related to Device Type. "
- "The function will crash if the entry is missing.",
- /*retTy=*/"::mlir::DataLayoutEntryInterface",
- /*methodName=*/"getSpecForDeviceType",
- /*args=*/(ins)
- >,
InterfaceMethod<
/*description=*/"Returns the entry related to the given identifier, if "
"present. Otherwise, return empty spec.",
@@ -265,19 +251,8 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
/*retTy=*/"::mlir::DataLayoutEntryInterface",
/*methodName=*/"getSpecForL1CacheSizeInBytes",
/*args=*/(ins)
- >,
- InterfaceMethod<
- /*description=*/"Returns the entry related to the given identifier, if "
- "present.",
- /*retTy=*/"uint32_t",
- /*methodName=*/"getDeviceID",
- /*args=*/(ins)
- >,
+ >
];
-
- let extraClassDeclaration = [{
- using DeviceID = uint32_t;
- }];
}
def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
@@ -300,7 +275,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
let methods = [
InterfaceMethod<
/*description=*/"Returns the list of layout entries.",
- /*retTy=*/"llvm::ArrayRef<::mlir::TargetDeviceSpecInterface>",
+ /*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
/*methodName=*/"getEntries",
/*args=*/(ins)
>,
@@ -309,7 +284,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
"ID",
/*retTy=*/"std::optional<::mlir::TargetDeviceSpecInterface>",
/*methodName=*/"getDeviceSpecForDeviceID",
- /*args=*/(ins "TargetDeviceSpecInterface::DeviceID":$deviceID)
+ /*args=*/(ins "StringAttr":$deviceID)
>,
InterfaceMethod<
/*description=*/"Verifies the validity of the specification and "
@@ -323,6 +298,10 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface"> {
}]
>
];
+
+ let extraClassDeclaration = [{
+ using DeviceID = StringAttr;
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 23d40e7f4fd51..7c2fb94eb32a3 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -239,6 +239,35 @@ DataLayoutSpecAttr::combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
return DataLayoutSpecAttr::get(getContext(), entries);
}
+StringAttr
+DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey);
+}
+
+StringAttr
+DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
+}
+
+StringAttr DataLayoutSpecAttr::getProgramMemorySpaceIdentifier(
+ MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutProgramMemorySpaceKey);
+}
+
+StringAttr
+DataLayoutSpecAttr::getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutGlobalMemorySpaceKey);
+}
+
+StringAttr
+DataLayoutSpecAttr::getStackAlignmentIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(
+ DLTIDialect::kDataLayoutStackAlignmentKey);
+}
+
/// Parses an attribute with syntax
/// attr ::= `#target.` `dl_spec` `<` attr-list? `>`
/// attr-list ::= attr
@@ -271,14 +300,48 @@ void DataLayoutSpecAttr::print(AsmPrinter &os) const {
// TargetDeviceSpecAttr
//===----------------------------------------------------------------------===//
+namespace mlir {
+template <>
+struct FieldParser<DeviceIDTargetDeviceSpecPair> {
+ static FailureOr<DeviceIDTargetDeviceSpecPair> parse(AsmParser &parser) {
+ std::string deviceID;
+
+ if (failed(parser.parseString(&deviceID))) {
+ parser.emitError(parser.getCurrentLocation())
+ << "DeviceID is missing, or is not of string type";
+ return failure();
+ }
+
+ if (failed(parser.parseColon())) {
+ parser.emitError(parser.getCurrentLocation()) << "Missing colon";
+ return failure();
+ }
+
+ auto target_device_spec =
+ FieldParser<TargetDeviceSpecInterface>::parse(parser);
+ if (failed(target_device_spec)) {
+ parser.emitError(parser.getCurrentLocation())
+ << "Error in parsing target device spec";
+ return failure();
+ }
+
+ return std::make_pair(parser.getBuilder().getStringAttr(deviceID),
+ *target_device_spec);
+ }
+};
+
+inline AsmPrinter &operator<<(AsmPrinter &printer,
+ DeviceIDTargetDeviceSpecPair param) {
+ return printer << param.first << " : " << param.second;
+}
+
+} // namespace mlir
+
LogicalResult
TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<DataLayoutEntryInterface> entries) {
// Entries in tdd_spec can only have StringAttr as key. It does not support
// type as a key. Hence not reusing DataLayoutEntryInterface::verify.
- bool targetDeviceIDKeyPresentAndValid = false;
- bool targetDeviceTypeKeyPresentAndValid = false;
-
DenseSet<StringAttr> ids;
for (DataLayoutEntryInterface entry : entries) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
@@ -291,21 +354,9 @@ TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "repeated layout entry key: " << id.getValue();
}
- // check that Device ID and Device Type are present.
+ // Check that required keys are of right type.
StringRef entryName = entry.getKey().get<StringAttr>().strref();
- if (entryName == DLTIDialect::kTargetDeviceIDKey) {
- // Also check the type of the value.
- IntegerAttr value =
- llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
- if (value && value.getType().isUnsignedInteger(32)) {
- targetDeviceIDKeyPresentAndValid = true;
- }
- } else if (entryName == DLTIDialect::kTargetDeviceTypeKey) {
- // Also check the type of the value.
- if (auto value = llvm::dyn_cast<StringAttr>(entry.getValue())) {
- targetDeviceTypeKeyPresentAndValid = true;
- }
- } else if (entryName == DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey) {
+ if (entryName == DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey) {
IntegerAttr value =
llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
if (!value || !value.getType().isUnsignedInteger(32))
@@ -325,30 +376,9 @@ TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
}
}
- // check that both DeviceID and DeviceType are present
- // and are of correct type.
- if (!targetDeviceIDKeyPresentAndValid) {
- return emitError() << "target_device_spec requires key: "
- << DLTIDialect::kTargetDeviceIDKey
- << " and its value of ui32 type";
- }
- if (!targetDeviceTypeKeyPresentAndValid) {
- return emitError() << "target_device_spec requires key: "
- << DLTIDialect::kTargetDeviceTypeKey
- << " and its value of string type";
- }
-
return success();
}
-StringAttr TargetDeviceSpecAttr::getDeviceIDIdentifier() {
- return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceIDKey);
-}
-
-StringAttr TargetDeviceSpecAttr::getDeviceTypeIdentifier() {
- return Builder(getContext()).getStringAttr(DLTIDialect::kTargetDeviceTypeKey);
-}
-
StringAttr TargetDeviceSpecAttr::getMaxVectorOpWidthIdentifier() {
return Builder(getContext())
.getStringAttr(DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
@@ -359,14 +389,6 @@ StringAttr TargetDeviceSpecAttr::getL1CacheSizeInBytesIdentifier() {
.getStringAttr(DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
}
-DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForDeviceID() {
- return getSpecForIdentifier(getDeviceIDIdentifier());
-}
-
-DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForDeviceType() {
- return getSpecForIdentifier(getDeviceTypeIdentifier());
-}
-
DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForMaxVectorOpWidth() {
return getSpecForIdentifier(getMaxVectorOpWidthIdentifier());
}
@@ -375,27 +397,24 @@ DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForL1CacheSizeInBytes() {
return getSpecForIdentifier(getL1CacheSizeInBytesIdentifier());
}
-TargetDeviceSpecInterface::DeviceID TargetDeviceSpecAttr::getDeviceID() {
- DataLayoutEntryInterface entry = getSpecForDeviceID();
- return llvm::cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
-}
-
//===----------------------------------------------------------------------===//
// TargetSystemSpecAttr
//===----------------------------------------------------------------------===//
LogicalResult
TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<TargetDeviceSpecInterface> entries) {
- DenseSet<uint32_t> device_ids;
+ ArrayRef<DeviceIDTargetDeviceSpecPair> entries) {
+ DenseSet<TargetSystemSpecInterface::DeviceID> device_ids;
+
+ for (const auto &entry : entries) {
+ TargetDeviceSpecInterface tdd_spec = entry.second;
- for (TargetDeviceSpecInterface tdd_spec : entries) {
// First verify that a target device spec is valid.
if (failed(TargetDeviceSpecAttr::verify(emitError, tdd_spec.getEntries())))
return failure();
// Check that device IDs are unique across all entries.
- uint32_t device_id = tdd_spec.getDeviceID();
+ TargetSystemSpecInterface::DeviceID device_id = entry.first;
if (!device_ids.insert(device_id).second) {
return emitError() << "repeated Device ID in dlti.target_system_spec: "
<< device_id;
@@ -404,35 +423,6 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-StringAttr
-DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey);
-}
-
-StringAttr
-DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
-}
-
-StringAttr DataLayoutSpecAttr::getProgramMemorySpaceIdentifier(
- MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutProgramMemorySpaceKey);
-}
-
-StringAttr
-DataLayoutSpecAttr::getGlobalMemorySpaceIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutGlobalMemorySpaceKey);
-}
-
-StringAttr
-DataLayoutSpecAttr::getStackAlignmentIdentifier(MLIRContext *context) const {
- return Builder(context).getStringAttr(
- DLTIDialect::kDataLayoutStackAlignmentKey);
-}
-
//===----------------------------------------------------------------------===//
// DLTIDialect
//===----------------------------------------------------------------------===//
@@ -483,9 +473,7 @@ class SystemDescSpecInterface : public DataLayoutDialectInterface {
StringRef entryName = dl_entry.getKey().get<StringAttr>().strref();
// Check that the key name is known to us. Although, we may allow keys
// unknown to us.
- if (entryName != DLTIDialect::kTargetDeviceIDKey &&
- entryName != DLTIDialect::kTargetDeviceTypeKey &&
- entryName != DLTIDialect::kTargetDeviceMaxVectorOpWidthKey &&
+ if (entryName != DLTIDialect::kTargetDeviceMaxVectorOpWidthKey &&
entryName != DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey)
return emitError(loc) << "unknown target desc key name: " << entryName;
}
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 3456b804b9800..04999e8af6d09 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -678,7 +678,7 @@ uint64_t mlir::DataLayout::getStackAlignment() const {
}
std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
- TargetDeviceSpecInterface::DeviceID deviceID) const {
+ TargetSystemSpecInterface::DeviceID deviceID) const {
checkValid();
DataLayoutEntryInterface entry;
if (originalTargetSystemDesc) {
@@ -697,7 +697,7 @@ std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
}
std::optional<uint32_t> mlir::DataLayout::getL1CacheSizeInBytes(
- TargetDeviceSpecInterface::DeviceID deviceID) const {
+ TargetSystemSpecInterface::DeviceID deviceID) const {
checkValid();
DataLayoutEntryInterface entry;
if (originalTargetSystemDesc) {
@@ -823,15 +823,16 @@ LogicalResult
mlir::detail::verifyTargetSystemSpec(TargetSystemSpecInterface spec,
Location loc) {
DenseMap<StringAttr, DataLayoutEntryInterface> device_desc_keys;
- DenseSet<TargetDeviceSpecInterface::DeviceID> device_ids;
- for (TargetDeviceSpecInterface tdd_spec : spec.getEntries()) {
+ DenseSet<TargetSystemSpecInterface::DeviceID> device_ids;
+ for (const auto &entry : spec.getEntries()) {
+ TargetDeviceSpecInterface tdd_spec = entry.second;
// First, verify individual target device desc specs.
if (failed(tdd_spec.verifyEntry(loc)))
return failure();
// Check that device IDs are unique across all entries.
MLIRContext *context = tdd_spec.getContext();
- TargetDeviceSpecInterface::DeviceID device_id = tdd_spec.getDeviceID();
+ TargetSystemSpecInterface::DeviceID device_id = entry.first;
if (!device_ids.insert(device_id).second) {
return failure();
}
diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir
index 28be8d37a50bf..0f1558100da71 100644
--- a/mlir/test/Dialect/DLTI/invalid.mlir
+++ b/mlir/test/Dialect/DLTI/invalid.mlir
@@ -98,135 +98,100 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"unknown.unknown
// -----
-// expected-error at below {{invalid kind of attribute specified}}
-// expected-error at below {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+// expected-error at below {{expected string}}
+// expected-error at below {{DeviceID is missing, or is not of string type}}
+// expected-error at below {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
"test.unknown_op"() { dlti.target_system_spec = #dlti.target_system_spec<[]> } : () -> ()
// -----
module attributes {
- // expected-error at +3 {{target_device_spec requires key: dlti.device_id and its value of ui32 type}}
- // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
- dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_type", "CPU">>
- >} {}
-
-// -----
-
-module attributes {
- // expected-error at +3 {{target_device_spec requires key: dlti.device_type and its value of string type}}
- // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
- dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>>
- >} {}
-
-// -----
-
-module attributes {
- // expected-error at +3 {{target_device_spec requires key: dlti.device_id and its value of ui32 type}}
- // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
- dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: i32>>
- >} {}
-
-// -----
-
-module attributes {
- // expected-error at +3 {{target_device_spec requires key: dlti.device_type and its value of string type}}
- // expected-error at +5 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ // Device ID is missing
+ //
+ // expected-error at +4 {{expected string}}
+ // expected-error at +3 {{DeviceID is missing, or is not of string type}}
+ // expected-error at +2 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
- #dlti.dl_entry<"dlti.device_type", 0: i32>>
+ : #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>
>} {}
// -----
module attributes {
- // expected-error at +3 {{repeated layout entry key: dlti.device_id}}
- // expected-error at +7 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ // Device ID is wrong type
+ //
+ // expected-error at +4 {{expected string}}
+ // expected-error at +3 {{DeviceID is missing, or is not of string type}}
+ // expected-error at +2 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
- #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
- #dlti.dl_entry<"dlti.L1_cache_size", 4096 : i32>>
+ 0: #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: i32>>
>} {}
// -----
module attributes {
- // expected-error at +3 {{repeated layout entry key: dlti.device_type}}
- // expected-error at +7 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ // Repeated Device ID
+ //
+ // expected-error at below {{repeated Device ID in dlti.target_system_spec: "CPU"}}
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
- #dlti.dl_entry<"dlti.device_type", "GPU">,
- #dlti.dl_entry<"dlti.L1_cache_size", 4096 : i32>>
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : ui32>>
>} {}
// -----
module attributes {
- // expected-error at +3 {{target_device_spec requires value of key: dlti.L1_cache_size_in_bytes to be of ui32 type}}
- // expected-error at +6 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ // L1_cache_size_in_bytes is of incorrect type
+ //
+ // expected-error at +4 {{target_device_spec requires value of key: dlti.L1_cache_size_in_bytes to be of ui32 type}}
+ // expected-error at +5 {{Error in parsing target device spec}}
+ // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096.1 : f32>>
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096.1 : f32>>
>} {}
// -----
module attributes {
- // expected-error at +3 {{target_device_spec requires value of key: dlti.max_vector_op_width to be of ui32 type}}
- // expected-error at +6 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ // max_vector_op_width is of incorrect type
+ //
+ // expected-error at +4 {{target_device_spec requires value of key: dlti.max_vector_op_width to be of ui32 type}}
+ // expected-error at +5 {{Error in parsing target device spec}}
+ // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
- #dlti.dl_entry<"dlti.max_vector_op_width", 4096.1 : f32>>
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 4096.1 : f32>>
>} {}
// -----
module attributes {
- // expected-error at +3 {{unknown target device spec key name: dlti.L2_cache_size_in_bytes}}
- // expected-error at +6 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ // Repeated DLTI entry
+ //
+ // expected-error at +4 {{repeated layout entry key: dlti.L1_cache_size_in_bytes}}
+ // expected-error at +6 {{Error in parsing target device spec}}
+ // expected-error at +5 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
- #dlti.dl_entry<"dlti.L2_cache_size_in_bytes", 4096 : i32>>
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : ui32>>
>} {}
// -----
module attributes {
- // expected-error at +3 {{unknown target device spec key name: dlti.unknown_key}}
- // expected-error at +6 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<TargetDeviceSpecInterface>`}}
+ // Unsupported dlti key
+ //
+ // expected-error at +4 {{unknown target device spec key name: dlti.unknown_key}}
+ // expected-error at +5 {{Error in parsing target device spec}}
+ // expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
- #dlti.dl_entry<"dlti.unknown_key", 42>>
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.unknown_key", 42>>
>} {}
-// -----
-
-module attributes {
- // unexpected-error at below {{repeated Device ID in dlti.target_system_spec: 0}}
- dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">>,
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">>
- >} {}
-
\ No newline at end of file
diff --git a/mlir/test/Dialect/DLTI/roundtrip.mlir b/mlir/test/Dialect/DLTI/roundtrip.mlir
index ccd80cda6f75f..277187a2b73c0 100644
--- a/mlir/test/Dialect/DLTI/roundtrip.mlir
+++ b/mlir/test/Dialect/DLTI/roundtrip.mlir
@@ -57,25 +57,17 @@
// A valid target system description
// CHECK: module attributes {
// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
+// CHECK-SAME: "CPU" : #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
-// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
+// CHECK-SAME: "GPU" : #dlti.target_device_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
+ "CPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 1: ui32>,
- #dlti.dl_entry<"dlti.device_type", "GPU">,
+ "GPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
>} {}
-
\ No newline at end of file
+
diff --git a/mlir/test/Dialect/DLTI/valid.mlir b/mlir/test/Dialect/DLTI/valid.mlir
index dff1434a26f48..6643c32ae5324 100644
--- a/mlir/test/Dialect/DLTI/valid.mlir
+++ b/mlir/test/Dialect/DLTI/valid.mlir
@@ -4,47 +4,17 @@
// CHECK: module attributes {
// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">>,
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">>
-// CHECK-SAME: >} {
-// CHECK: }
-module attributes {
- dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">>,
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 1: ui32>,
- #dlti.dl_entry<"dlti.device_type", "GPU">>
- >} {}
-
-// -----
-
-// CHECK: module attributes {
-// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
+// CHECK-SAME: "CPU" : #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
+// CHECK-SAME: "GPU" : #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
+ "CPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 1: ui32>,
- #dlti.dl_entry<"dlti.device_type", "GPU">,
+ "GPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
>} {}
@@ -52,25 +22,17 @@ module attributes {
// CHECK: module attributes {
// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
+// CHECK-SAME: "CPU" : #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
+// CHECK-SAME: "GPU" : #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : ui32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
+ "CPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 1: ui32>,
- #dlti.dl_entry<"dlti.device_type", "GPU">,
+ "GPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192: ui32>>
>} {}
@@ -78,24 +40,16 @@ module attributes {
// CHECK: module attributes {
// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "CPU">,
+// CHECK-SAME: "CPU" : #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
-// CHECK-SAME: #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_id", 1 : ui32>,
-// CHECK-SAME: #dlti.dl_entry<"dlti.device_type", "GPU">,
+// CHECK-SAME: "GPU" : #dlti.target_device_spec<
// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0: ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
+ "CPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 64: ui32>>,
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 1: ui32>,
- #dlti.dl_entry<"dlti.device_type", "GPU">,
+ "GPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
>} {}
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index a0ba54f32a9e3..753cab04621f1 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -96,9 +96,9 @@ struct CustomDataLayoutSpec
class TargetSystemSpecStorage : public AttributeStorage {
public:
- using KeyTy = ArrayRef<TargetDeviceSpecInterface>;
+ using KeyTy = ArrayRef<DeviceIDTargetDeviceSpecPair>;
- TargetSystemSpecStorage(ArrayRef<TargetDeviceSpecInterface> entries)
+ TargetSystemSpecStorage(ArrayRef<DeviceIDTargetDeviceSpecPair> entries)
: entries(entries) {}
bool operator==(const KeyTy &key) const { return key == entries; }
@@ -109,7 +109,7 @@ class TargetSystemSpecStorage : public AttributeStorage {
TargetSystemSpecStorage(allocator.copyInto(key));
}
- ArrayRef<TargetDeviceSpecInterface> entries;
+ ArrayRef<DeviceIDTargetDeviceSpecPair> entries;
};
struct CustomTargetSystemSpec
@@ -123,16 +123,18 @@ struct CustomTargetSystemSpec
static constexpr StringLiteral name = "test.custom_target_system_spec";
static CustomTargetSystemSpec
- get(MLIRContext *ctx, ArrayRef<TargetDeviceSpecInterface> entries) {
+ get(MLIRContext *ctx, ArrayRef<DeviceIDTargetDeviceSpecPair> entries) {
return Base::get(ctx, entries);
}
- TargetDeviceSpecListRef getEntries() const { return getImpl()->entries; }
+ DeviceIDTargetDeviceSpecPairListRef getEntries() const {
+ return getImpl()->entries;
+ }
LogicalResult verifySpec(Location loc) { return success(); }
std::optional<TargetDeviceSpecInterface>
- getDeviceSpecForDeviceID(TargetDeviceSpecInterface::DeviceID deviceID) {
- for (TargetDeviceSpecInterface entry : getEntries()) {
- if (entry.getDeviceID() == deviceID)
- return entry;
+ getDeviceSpecForDeviceID(TargetSystemSpecInterface::DeviceID deviceID) {
+ for (const auto &entry : getEntries()) {
+ if (entry.first == deviceID)
+ return entry.second;
}
return std::nullopt;
}
@@ -365,7 +367,7 @@ struct DLTestDialect : Dialect {
}
};
-struct DLTargetSystemDescTestDialect : Dialect {
+struct DLTargetSystemDescTestDialect : public Dialect {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTargetSystemDescTestDialect)
explicit DLTargetSystemDescTestDialect(MLIRContext *ctx)
@@ -379,8 +381,9 @@ struct DLTargetSystemDescTestDialect : Dialect {
void printAttribute(Attribute attr,
DialectAsmPrinter &printer) const override {
printer << "target_system_spec<";
- llvm::interleaveComma(cast<CustomTargetSystemSpec>(attr).getEntries(),
- printer);
+ llvm::interleaveComma(
+ cast<CustomTargetSystemSpec>(attr).getEntries(), printer,
+ [&](const auto &it) { printer << it.first << ":" << it.second; });
printer << ">";
}
@@ -392,11 +395,36 @@ struct DLTargetSystemDescTestDialect : Dialect {
if (succeeded(parser.parseOptionalGreater()))
return CustomTargetSystemSpec::get(parser.getContext(), {});
- SmallVector<TargetDeviceSpecInterface> entries;
+ auto parseDeviceIDTargetDeviceSpecPair =
+ [&](AsmParser &parser) -> FailureOr<DeviceIDTargetDeviceSpecPair> {
+ std::string deviceID;
+ if (failed(parser.parseString(&deviceID))) {
+ parser.emitError(parser.getCurrentLocation())
+ << "DeviceID is missing, or is not of string type";
+ return failure();
+ }
+ if (failed(parser.parseColon())) {
+ parser.emitError(parser.getCurrentLocation()) << "Missing colon";
+ return failure();
+ }
+
+ TargetDeviceSpecInterface target_device_spec;
+ if (failed(parser.parseAttribute(target_device_spec))) {
+ parser.emitError(parser.getCurrentLocation())
+ << "Error in parsing target device spec";
+ return failure();
+ }
+ return std::make_pair(parser.getBuilder().getStringAttr(deviceID),
+ target_device_spec);
+ };
+
+ SmallVector<DeviceIDTargetDeviceSpecPair> entries;
ok = succeeded(parser.parseCommaSeparatedList([&]() {
- entries.emplace_back();
- ok = succeeded(parser.parseAttribute(entries.back()));
+ auto deviceID_target_device_spec =
+ parseDeviceIDTargetDeviceSpecPair(parser);
+ ok = succeeded(deviceID_target_device_spec);
assert(ok);
+ entries.push_back(*deviceID_target_device_spec);
return success();
}));
assert(ok);
@@ -466,8 +494,12 @@ TEST(DataLayout, NullSpec) {
EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
EXPECT_EQ(layout.getStackAlignment(), 0u);
- EXPECT_EQ(layout.getL1CacheSizeInBytes(0 /* device ID*/), std::nullopt);
- EXPECT_EQ(layout.getMaxVectorOpWidth(0 /* device ID*/), std::nullopt);
+ EXPECT_EQ(layout.getL1CacheSizeInBytes(
+ Builder(&ctx).getStringAttr("CPU" /* device ID*/)),
+ std::nullopt);
+ EXPECT_EQ(layout.getMaxVectorOpWidth(
+ Builder(&ctx).getStringAttr("CPU" /* device ID*/)),
+ std::nullopt);
}
TEST(DataLayout, EmptySpec) {
@@ -500,8 +532,12 @@ TEST(DataLayout, EmptySpec) {
EXPECT_EQ(layout.getGlobalMemorySpace(), Attribute());
EXPECT_EQ(layout.getStackAlignment(), 0u);
- EXPECT_EQ(layout.getL1CacheSizeInBytes(0 /* device ID*/), std::nullopt);
- EXPECT_EQ(layout.getMaxVectorOpWidth(0 /* device ID*/), std::nullopt);
+ EXPECT_EQ(layout.getL1CacheSizeInBytes(
+ Builder(&ctx).getStringAttr("CPU" /* device ID*/)),
+ std::nullopt);
+ EXPECT_EQ(layout.getMaxVectorOpWidth(
+ Builder(&ctx).getStringAttr("CPU" /* device ID*/)),
+ std::nullopt);
}
TEST(DataLayout, SpecWithEntries) {
@@ -557,12 +593,9 @@ TEST(DataLayout, SpecWithTargetSystemDescEntries) {
const char *ir = R"MLIR(
module attributes { dl_target_sys_desc_test.target_system_spec =
#dl_target_sys_desc_test.target_system_spec<
- #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.device_id", 0 : ui32>,
- #dlti.dl_entry<"dlti.device_type", "CPU">,
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>,
- #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>
- >
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>,
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
> } {}
)MLIR";
@@ -572,9 +605,11 @@ TEST(DataLayout, SpecWithTargetSystemDescEntries) {
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
DataLayout layout(*module);
- EXPECT_EQ(layout.getL1CacheSizeInBytes(0 /* device ID*/),
+ EXPECT_EQ(layout.getL1CacheSizeInBytes(
+ Builder(&ctx).getStringAttr("CPU") /* device ID*/),
std::optional<uint32_t>(4096));
- EXPECT_EQ(layout.getMaxVectorOpWidth(0 /* device ID*/),
+ EXPECT_EQ(layout.getMaxVectorOpWidth(
+ Builder(&ctx).getStringAttr("CPU") /* device ID*/),
std::optional<uint32_t>(128));
}
>From 87d1b209b211eaeff45f38994a6115e6f6205380 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Thu, 6 Jun 2024 13:42:51 -0700
Subject: [PATCH 6/7] Addressing review comments
---
mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td | 12 ---------
.../mlir/Interfaces/DataLayoutInterfaces.td | 22 ++++++++--------
mlir/lib/Dialect/DLTI/DLTI.cpp | 25 ++++++++-----------
mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 10 +++++---
mlir/test/Dialect/DLTI/roundtrip.mlir | 12 ++++-----
.../Interfaces/DataLayoutInterfacesTest.cpp | 1 +
6 files changed, 35 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
index 0f55558dcca27..fef003c675281 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
@@ -189,18 +189,6 @@ def DLTI_TargetDeviceSpecAttr :
/// Returns L1 cache size identifier
StringAttr getL1CacheSizeInBytesIdentifier();
-
- /// Returns the interface spec for max vector op width
- /// Since max vector op width is an optional property, this function will
- /// return a valid spec if the property is defined, otherwise it
- /// will return an empty spec.
- DataLayoutEntryInterface getSpecForMaxVectorOpWidth();
-
- /// Returns the interface spec for L1 cache size
- /// Since L1 cache size is an optional property, this function will
- /// return a valid spec if the property is defined, otherwise it
- /// will return an empty spec.
- DataLayoutEntryInterface getSpecForL1CacheSizeInBytes();
}];
}
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 4a4e135077737..49aec01bceae1 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -239,19 +239,19 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface"> {
/*defaultImplementation=*/[{ return ::mlir::success(); }]
>,
InterfaceMethod<
- /*description=*/"Returns the entry related to the given identifier, if "
- "present. Otherwise, return empty spec.",
- /*retTy=*/"::mlir::DataLayoutEntryInterface",
- /*methodName=*/"getSpecForMaxVectorOpWidth",
- /*args=*/(ins)
+ /*description=*/"Returns max vector op width identifier. ",
+ /*retTy=*/"::mlir::StringAttr",
+ /*methodName=*/"getMaxVectorOpWidthIdentifier",
+ /*args=*/(ins),
+ /*methodBody=*/""
>,
InterfaceMethod<
- /*description=*/"Returns the entry related to the given identifier, if "
- "present. Otherwise, return empty spec.",
- /*retTy=*/"::mlir::DataLayoutEntryInterface",
- /*methodName=*/"getSpecForL1CacheSizeInBytes",
- /*args=*/(ins)
- >
+ /*description=*/"Returns L1 cache size identifier identifier. ",
+ /*retTy=*/"::mlir::StringAttr",
+ /*methodName=*/"getL1CacheSizeInBytesIdentifier",
+ /*args=*/(ins),
+ /*methodBody=*/""
+ >,
];
}
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 7c2fb94eb32a3..f598b6220be9f 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -24,7 +24,7 @@ using namespace mlir;
#include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"
#define GET_ATTRDEF_CLASSES
-#include <mlir/Dialect/DLTI/DLTIAttrs.cpp.inc>
+#include "mlir/Dialect/DLTI/DLTIAttrs.cpp.inc"
#define DEBUG_TYPE "dlti"
@@ -301,6 +301,8 @@ void DataLayoutSpecAttr::print(AsmPrinter &os) const {
//===----------------------------------------------------------------------===//
namespace mlir {
+/// A FieldParser for key-value pairs of DeviceID-target device spec pairs that
+/// make up a target system spec.
template <>
struct FieldParser<DeviceIDTargetDeviceSpecPair> {
static FailureOr<DeviceIDTargetDeviceSpecPair> parse(AsmParser &parser) {
@@ -340,8 +342,9 @@ inline AsmPrinter &operator<<(AsmPrinter &printer,
LogicalResult
TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<DataLayoutEntryInterface> entries) {
- // Entries in tdd_spec can only have StringAttr as key. It does not support
- // type as a key. Hence not reusing DataLayoutEntryInterface::verify.
+ // Entries in a target device spec can only have StringAttr as key. It does
+ // not support type as a key. Hence not reusing
+ // DataLayoutEntryInterface::verify.
DenseSet<StringAttr> ids;
for (DataLayoutEntryInterface entry : entries) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
@@ -389,14 +392,6 @@ StringAttr TargetDeviceSpecAttr::getL1CacheSizeInBytesIdentifier() {
.getStringAttr(DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
}
-DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForMaxVectorOpWidth() {
- return getSpecForIdentifier(getMaxVectorOpWidthIdentifier());
-}
-
-DataLayoutEntryInterface TargetDeviceSpecAttr::getSpecForL1CacheSizeInBytes() {
- return getSpecForIdentifier(getL1CacheSizeInBytesIdentifier());
-}
-
//===----------------------------------------------------------------------===//
// TargetSystemSpecAttr
//===----------------------------------------------------------------------===//
@@ -407,10 +402,11 @@ TargetSystemSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
DenseSet<TargetSystemSpecInterface::DeviceID> device_ids;
for (const auto &entry : entries) {
- TargetDeviceSpecInterface tdd_spec = entry.second;
+ TargetDeviceSpecInterface target_device_spec = entry.second;
// First verify that a target device spec is valid.
- if (failed(TargetDeviceSpecAttr::verify(emitError, tdd_spec.getEntries())))
+ if (failed(TargetDeviceSpecAttr::verify(emitError,
+ target_device_spec.getEntries())))
return failure();
// Check that device IDs are unique across all entries.
@@ -462,6 +458,7 @@ class TargetDataLayoutInterface : public DataLayoutDialectInterface {
} // namespace
namespace {
+/// An interface to check entries of a target device spec.
class SystemDescSpecInterface : public DataLayoutDialectInterface {
public:
using DataLayoutDialectInterface::DataLayoutDialectInterface;
@@ -485,7 +482,7 @@ class SystemDescSpecInterface : public DataLayoutDialectInterface {
void DLTIDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
-#include <mlir/Dialect/DLTI/DLTIAttrs.cpp.inc>
+#include "mlir/Dialect/DLTI/DLTIAttrs.cpp.inc"
>();
addInterfaces<TargetDataLayoutInterface, SystemDescSpecInterface>();
}
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 04999e8af6d09..16776ee511590 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -684,7 +684,8 @@ std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
if (originalTargetSystemDesc) {
if (auto device =
originalTargetSystemDesc.getDeviceSpecForDeviceID(deviceID))
- entry = device->getSpecForMaxVectorOpWidth();
+ entry =
+ device->getSpecForIdentifier(device->getMaxVectorOpWidthIdentifier());
}
// Currently I am not caching the results because we do not return
// default values of these properties. Instead if the property is
@@ -703,7 +704,8 @@ std::optional<uint32_t> mlir::DataLayout::getL1CacheSizeInBytes(
if (originalTargetSystemDesc) {
if (auto device =
originalTargetSystemDesc.getDeviceSpecForDeviceID(deviceID))
- entry = device->getSpecForL1CacheSizeInBytes();
+ entry = device->getSpecForIdentifier(
+ device->getL1CacheSizeInBytesIdentifier());
}
// Currently I am not caching the results because we do not return
// default values of these properties. Instead if the property is
@@ -831,7 +833,6 @@ mlir::detail::verifyTargetSystemSpec(TargetSystemSpecInterface spec,
return failure();
// Check that device IDs are unique across all entries.
- MLIRContext *context = tdd_spec.getContext();
TargetSystemSpecInterface::DeviceID device_id = entry.first;
if (!device_ids.insert(device_id).second) {
return failure();
@@ -842,8 +843,9 @@ mlir::detail::verifyTargetSystemSpec(TargetSystemSpecInterface spec,
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
// tdd_spec does not support Type as a key.
return failure();
- } else
+ } else {
device_desc_keys[entry.getKey().get<StringAttr>()] = entry;
+ }
}
}
diff --git a/mlir/test/Dialect/DLTI/roundtrip.mlir b/mlir/test/Dialect/DLTI/roundtrip.mlir
index 277187a2b73c0..7b8255ad71d4d 100644
--- a/mlir/test/Dialect/DLTI/roundtrip.mlir
+++ b/mlir/test/Dialect/DLTI/roundtrip.mlir
@@ -56,12 +56,12 @@
// A valid target system description
// CHECK: module attributes {
-// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
-// CHECK-SAME: "CPU" : #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
-// CHECK-SAME: "GPU" : #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
-// CHECK-SAME: >} {
+// CHECK: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK: "CPU" : #dlti.target_device_spec<
+// CHECK: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
+// CHECK: "GPU" : #dlti.target_device_spec<
+// CHECK: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
+// CHECK: >} {
// CHECK: }
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 753cab04621f1..4a4900147f7b7 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -367,6 +367,7 @@ struct DLTestDialect : Dialect {
}
};
+/// A dialect to test DLTI's target system spec and related attributes
struct DLTargetSystemDescTestDialect : public Dialect {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DLTargetSystemDescTestDialect)
>From 856426f0563ef7ef7820cd3e0410ab3c8423576d Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Fri, 14 Jun 2024 19:13:16 -0700
Subject: [PATCH 7/7] Addressing review comments
---
.../mlir/Interfaces/DataLayoutInterfaces.h | 8 +-
.../mlir/Interfaces/DataLayoutInterfaces.td | 4 +-
mlir/lib/Dialect/DLTI/DLTI.cpp | 8 +-
mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 20 ++---
mlir/test/Dialect/DLTI/invalid.mlir | 14 ++--
mlir/test/Dialect/DLTI/valid.mlir | 78 ++++++++++++++++---
.../Interfaces/DataLayoutInterfacesTest.cpp | 4 +-
7 files changed, 95 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 7c4de95806a76..b66c8ed3446ec 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -93,11 +93,11 @@ uint64_t getDefaultStackAlignment(DataLayoutEntryInterface entry);
/// Return max vector op width from the specified DataLayoutEntry. If the
/// property is missing from the entry, then return std::nullopt.
-std::optional<uint32_t> getMaxVectorOpWidth(DataLayoutEntryInterface entry);
+std::optional<int64_t> getMaxVectorOpWidth(DataLayoutEntryInterface entry);
/// Return L1 cache size in bytes from the specified DataLayoutEntry. If the
/// property is missing from the entry, then return std::nullopt.
-std::optional<uint32_t> getL1CacheSizeInBytes(DataLayoutEntryInterface entry);
+std::optional<int64_t> getL1CacheSizeInBytes(DataLayoutEntryInterface entry);
/// Given a list of data layout entries, returns a new list containing the
/// entries with keys having the given type ID, i.e. belonging to the same type
@@ -249,12 +249,12 @@ class DataLayout {
/// Returns for max vector op width if the property is defined for the given
/// device ID, otherwise return std::nullopt.
- std::optional<uint32_t>
+ std::optional<int64_t>
getMaxVectorOpWidth(TargetSystemSpecInterface::DeviceID) const;
/// Returns for L1 cache size if the property is defined for the given
/// device ID, otherwise return std::nullopt.
- std::optional<uint32_t>
+ std::optional<int64_t>
getL1CacheSizeInBytes(TargetSystemSpecInterface::DeviceID) const;
private:
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 49aec01bceae1..2ae9d29c2d33c 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -482,7 +482,7 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
StaticInterfaceMethod<
/*description=*/"Returns the max vector op width, if the property is "
"defined. Otherwise, it returns std::nullopt.",
- /*retTy=*/"std::optional<uint32_t>",
+ /*retTy=*/"std::optional<int64_t>",
/*methodName=*/"getMaxVectorOpWidth",
/*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
/*methodBody=*/"",
@@ -493,7 +493,7 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
StaticInterfaceMethod<
/*description=*/"Returns the L1 cache size in bytes, if the property is "
"defined. Otherwise, it returns std::nullopt.",
- /*retTy=*/"std::optional<uint32_t>",
+ /*retTy=*/"std::optional<int64_t>",
/*methodName=*/"getL1CacheSizeInBytes",
/*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
/*methodBody=*/"",
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index f598b6220be9f..3ea6e57df3acb 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -362,17 +362,17 @@ TargetDeviceSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
if (entryName == DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey) {
IntegerAttr value =
llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
- if (!value || !value.getType().isUnsignedInteger(32))
+ if (!value || !value.getType().isInteger())
return emitError() << "target_device_spec requires value of key: "
<< DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey
- << " to be of ui32 type";
+ << " to be of integer type";
} else if (entryName == DLTIDialect::kTargetDeviceMaxVectorOpWidthKey) {
IntegerAttr value =
llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
- if (!value || !value.getType().isUnsignedInteger(32))
+ if (!value || !value.getType().isInteger())
return emitError() << "target_device_spec requires value of key: "
<< DLTIDialect::kTargetDeviceMaxVectorOpWidthKey
- << " to be of ui32 type";
+ << " to be of integer type";
} else {
return emitError() << "unknown target device spec key name: "
<< entryName;
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 16776ee511590..fd5bb186172c7 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -295,7 +295,7 @@ mlir::detail::getDefaultStackAlignment(DataLayoutEntryInterface entry) {
// Returns the max vector op width if specified in the given entry. If the entry
// is empty (meaning the spec is missing), returns std::nullopt.
-std::optional<uint32_t>
+std::optional<int64_t>
mlir::detail::getMaxVectorOpWidth(DataLayoutEntryInterface entry) {
if (entry == DataLayoutEntryInterface())
return std::nullopt;
@@ -306,7 +306,7 @@ mlir::detail::getMaxVectorOpWidth(DataLayoutEntryInterface entry) {
// Returns the L1 cache size if specified in the given entry. If the entry
// is empty (meaning the spec is missing), returns std::nullopt.
-std::optional<uint32_t>
+std::optional<int64_t>
mlir::detail::getL1CacheSizeInBytes(DataLayoutEntryInterface entry) {
if (entry == DataLayoutEntryInterface())
return std::nullopt;
@@ -468,10 +468,10 @@ void checkMissingLayout(DataLayoutSpecInterface originalLayout, OpTy op) {
mlir::DataLayout::DataLayout() : DataLayout(ModuleOp()) {}
mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
- : originalLayout(getCombinedDataLayout(op)), scope(op),
+ : originalLayout(getCombinedDataLayout(op)),
+ originalTargetSystemDesc(getTargetSystemSpec(op)), scope(op),
allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
- globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
- originalTargetSystemDesc(getTargetSystemSpec(op)) {
+ globalMemorySpace(std::nullopt), stackAlignment(std::nullopt) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
checkMissingLayout(originalLayout, op);
collectParentLayouts(op, layoutStack);
@@ -479,10 +479,10 @@ mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
}
mlir::DataLayout::DataLayout(ModuleOp op)
- : originalLayout(getCombinedDataLayout(op)), scope(op),
+ : originalLayout(getCombinedDataLayout(op)),
+ originalTargetSystemDesc(getTargetSystemSpec(op)), scope(op),
allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
- globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
- originalTargetSystemDesc(getTargetSystemSpec(op)) {
+ globalMemorySpace(std::nullopt), stackAlignment(std::nullopt) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
checkMissingLayout(originalLayout, op);
collectParentLayouts(op, layoutStack);
@@ -677,7 +677,7 @@ uint64_t mlir::DataLayout::getStackAlignment() const {
return *stackAlignment;
}
-std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
+std::optional<int64_t> mlir::DataLayout::getMaxVectorOpWidth(
TargetSystemSpecInterface::DeviceID deviceID) const {
checkValid();
DataLayoutEntryInterface entry;
@@ -697,7 +697,7 @@ std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
return detail::getMaxVectorOpWidth(entry);
}
-std::optional<uint32_t> mlir::DataLayout::getL1CacheSizeInBytes(
+std::optional<int64_t> mlir::DataLayout::getL1CacheSizeInBytes(
TargetSystemSpecInterface::DeviceID deviceID) const {
checkValid();
DataLayoutEntryInterface entry;
diff --git a/mlir/test/Dialect/DLTI/invalid.mlir b/mlir/test/Dialect/DLTI/invalid.mlir
index 0f1558100da71..b2cf8753eaaf2 100644
--- a/mlir/test/Dialect/DLTI/invalid.mlir
+++ b/mlir/test/Dialect/DLTI/invalid.mlir
@@ -113,7 +113,7 @@ module attributes {
// expected-error at +2 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
: #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : i32>>
>} {}
// -----
@@ -137,9 +137,9 @@ module attributes {
// expected-error at below {{repeated Device ID in dlti.target_system_spec: "CPU"}}
dlti.target_system_spec = #dlti.target_system_spec<
"CPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096>>,
"CPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : ui32>>
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192>>
>} {}
// -----
@@ -147,7 +147,7 @@ module attributes {
module attributes {
// L1_cache_size_in_bytes is of incorrect type
//
- // expected-error at +4 {{target_device_spec requires value of key: dlti.L1_cache_size_in_bytes to be of ui32 type}}
+ // expected-error at +4 {{target_device_spec requires value of key: dlti.L1_cache_size_in_bytes to be of integer type}}
// expected-error at +5 {{Error in parsing target device spec}}
// expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
@@ -160,7 +160,7 @@ module attributes {
module attributes {
// max_vector_op_width is of incorrect type
//
- // expected-error at +4 {{target_device_spec requires value of key: dlti.max_vector_op_width to be of ui32 type}}
+ // expected-error at +4 {{target_device_spec requires value of key: dlti.max_vector_op_width to be of integer type}}
// expected-error at +5 {{Error in parsing target device spec}}
// expected-error at +4 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
@@ -178,8 +178,8 @@ module attributes {
// expected-error at +5 {{failed to parse DLTI_TargetSystemSpecAttr parameter 'entries' which is to be a `::llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>`}}
dlti.target_system_spec = #dlti.target_system_spec<
"CPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>,
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : ui32>>
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096>,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192>>
>} {}
// -----
diff --git a/mlir/test/Dialect/DLTI/valid.mlir b/mlir/test/Dialect/DLTI/valid.mlir
index 6643c32ae5324..175368f306822 100644
--- a/mlir/test/Dialect/DLTI/valid.mlir
+++ b/mlir/test/Dialect/DLTI/valid.mlir
@@ -5,17 +5,17 @@
// CHECK: module attributes {
// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
// CHECK-SAME: "CPU" : #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : i32>>,
// CHECK-SAME: "GPU" : #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : i32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
"CPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: i32>>,
"GPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128: i32>>
>} {}
// -----
@@ -23,17 +23,17 @@ module attributes {
// CHECK: module attributes {
// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
// CHECK-SAME: "CPU" : #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : ui32>>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : i32>>,
// CHECK-SAME: "GPU" : #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : ui32>>
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : i32>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
"CPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: i32>>,
"GPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192: ui32>>
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192: i32>>
>} {}
// -----
@@ -41,15 +41,69 @@ module attributes {
// CHECK: module attributes {
// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
// CHECK-SAME: "CPU" : #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096 : i64>>,
// CHECK-SAME: "GPU" : #dlti.target_device_spec<
-// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : ui32>>
+// CHECK-SAME: #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192 : i64>>
// CHECK-SAME: >} {
// CHECK: }
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
"CPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.max_vector_op_width", 64: ui32>>,
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: i64>>,
"GPU": #dlti.target_device_spec<
- #dlti.dl_entry<"dlti.max_vector_op_width", 128: ui32>>
+ #dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 8192: i64>>
+ >} {}
+
+// -----
+
+// CHECK: module attributes {
+// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK-SAME: "CPU" : #dlti.target_device_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 64 : i32>>,
+// CHECK-SAME: "GPU" : #dlti.target_device_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : i32>>
+// CHECK-SAME: >} {
+// CHECK: }
+module attributes {
+ dlti.target_system_spec = #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64: i32>>,
+ "GPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128: i32>>
+ >} {}
+
+// -----
+
+// CHECK: module attributes {
+// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK-SAME: "CPU" : #dlti.target_device_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 64 : i64>>,
+// CHECK-SAME: "GPU" : #dlti.target_device_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : i64>>
+// CHECK-SAME: >} {
+// CHECK: }
+module attributes {
+ dlti.target_system_spec = #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64: i64>>,
+ "GPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128: i64>>
+ >} {}
+
+// -----
+
+// CHECK: module attributes {
+// CHECK-SAME: dlti.target_system_spec = #dlti.target_system_spec<
+// CHECK-SAME: "CPU" : #dlti.target_device_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 64 : i64>>,
+// CHECK-SAME: "GPU" : #dlti.target_device_spec<
+// CHECK-SAME: #dlti.dl_entry<"dlti.max_vector_op_width", 128 : i64>>
+// CHECK-SAME: >} {
+// CHECK: }
+module attributes {
+ dlti.target_system_spec = #dlti.target_system_spec<
+ "CPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 64>>,
+ "GPU": #dlti.target_device_spec<
+ #dlti.dl_entry<"dlti.max_vector_op_width", 128>>
>} {}
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 4a4900147f7b7..321084cfbacf7 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -608,10 +608,10 @@ TEST(DataLayout, SpecWithTargetSystemDescEntries) {
DataLayout layout(*module);
EXPECT_EQ(layout.getL1CacheSizeInBytes(
Builder(&ctx).getStringAttr("CPU") /* device ID*/),
- std::optional<uint32_t>(4096));
+ std::optional<int64_t>(4096));
EXPECT_EQ(layout.getMaxVectorOpWidth(
Builder(&ctx).getStringAttr("CPU") /* device ID*/),
- std::optional<uint32_t>(128));
+ std::optional<int64_t>(128));
}
TEST(DataLayout, Caching) {
More information about the Mlir-commits mailing list