[Mlir-commits] [mlir] Reimplementing target description concept using DLTI Attributes (PR #91670)
Niranjan Hasabnis
llvmlistbot at llvm.org
Thu May 9 15:11:03 PDT 2024
https://github.com/nhasabni created https://github.com/llvm/llvm-project/pull/91670
This PR reimplements the target description concept using DLTI attributes and interfaces. It is a newer implementation of PR #85141 and the [RFC](https://discourse.llvm.org/t/rfc-target-description-and-cost-model-in-mlir/76990), incorporating reviews and comments on the original PR.
As an example, this commit supports attributes such as the following:
```
module attributes {
dlti.tsd_spec =
#dlti.tsd_spec<
#dlti.tdd_spec<#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.canonicalizer_max_iterations", 100 : i32>,
#dlti.dl_entry<"dlti.canonicalizer_max_num_rewrites", -5 : i32>>,
#dlti.tdd_spec<
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
#dlti.dl_entry<"dlti.device_type", "GPU">,
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
#dlti.tdd_spec<
#dlti.dl_entry<"dlti.device_id", 2: ui32>,
#dlti.dl_entry<"dlti.device_type", "XPU">>>
}
```
>From d2260b344bc10ac0f9f8195c2921c2834f35e069 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Thu, 9 May 2024 15:06:58 -0700
Subject: [PATCH] Reimplementing target description concept using Attributes
and Interfaces
As an example, this commit supports attributes such as the following:
```
module attributes {
dlti.tsd_spec =
#dlti.tsd_spec<
#dlti.tdd_spec<#dlti.dl_entry<"dlti.device_id", 0: ui32>,
#dlti.dl_entry<"dlti.device_type", "CPU">,
#dlti.dl_entry<"dlti.canonicalizer_max_iterations", 100 : i32>,
#dlti.dl_entry<"dlti.canonicalizer_max_num_rewrites", -5 : i32>>,
#dlti.tdd_spec<
#dlti.dl_entry<"dlti.device_id", 1: ui32>,
#dlti.dl_entry<"dlti.device_type", "GPU">,
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
#dlti.tdd_spec<
#dlti.dl_entry<"dlti.device_id", 2: ui32>,
#dlti.dl_entry<"dlti.device_type", "XPU">>>
}
```
---
mlir/include/mlir/Dialect/DLTI/DLTI.h | 137 +++++++
mlir/include/mlir/Dialect/DLTI/DLTIBase.td | 41 ++
mlir/include/mlir/Dialect/DLTI/Traits.h | 7 +
mlir/include/mlir/IR/BuiltinOps.td | 1 +
.../mlir/Interfaces/DataLayoutInterfaces.h | 56 +++
.../mlir/Interfaces/DataLayoutInterfaces.td | 210 +++++++++++
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 +-
mlir/lib/Dialect/DLTI/DLTI.cpp | 353 +++++++++++++++++-
mlir/lib/Dialect/DLTI/Traits.cpp | 6 +
mlir/lib/IR/BuiltinDialect.cpp | 11 +
mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 164 +++++++-
mlir/lib/Transforms/Canonicalizer.cpp | 51 +++
12 files changed, 1043 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index 5ac7c11e6ffee..9aad0d19819f5 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -21,6 +21,8 @@ namespace mlir {
namespace impl {
class DataLayoutEntryStorage;
class DataLayoutSpecStorage;
+class TargetSystemDescSpecAttrStorage;
+class TargetDeviceDescSpecAttrStorage;
} // namespace impl
//===----------------------------------------------------------------------===//
@@ -124,6 +126,141 @@ class DataLayoutSpecAttr
static constexpr StringLiteral name = "builtin.data_layout_spec";
};
+//===----------------------------------------------------------------------===//
+// TargetSystemDescSpecAttr
+//===----------------------------------------------------------------------===//
+
+/// A system description attribute is a list of device descriptions, each
+/// having a unique device ID.
+class TargetSystemDescSpecAttr
+ : public Attribute::AttrBase<TargetSystemDescSpecAttr, Attribute,
+ impl::TargetSystemDescSpecAttrStorage,
+ TargetSystemDescSpecInterface::Trait> {
+public:
+ using Base::Base;
+
+ /// The keyword used for this attribute in custom syntax.
+ constexpr const static StringLiteral kAttrKeyword = "tsd_spec";
+
+ /// Returns a system description attribute built from the given device specs.
+ static TargetSystemDescSpecAttr
+ get(MLIRContext *context, ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+ /// Returns the list of device description entries.
+ TargetDeviceDescSpecListRef getEntries() const;
+
+ /// Returns the device description that matches the given device ID.
+ TargetDeviceDescSpecInterface getDeviceDescForDeviceID(uint32_t deviceID);
+
+ /// Returns the specification containing the given device descriptions. If
+ /// the list is invalid (e.g. device IDs are not unique), reports errors using
+ /// the given callback and returns null.
+ static TargetSystemDescSpecAttr
+ getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+ ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+ /// Checks that the given list of device descriptions is valid.
+ static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+ /// Parses an instance of this attribute.
+ static TargetSystemDescSpecAttr parse(AsmParser &parser);
+
+ /// Prints this attribute.
+ void print(AsmPrinter &os) const;
+
+ static constexpr StringLiteral name = "builtin.target_system_description";
+};
+
+//===----------------------------------------------------------------------===//
+// TargetDeviceDescSpecAttr
+//===----------------------------------------------------------------------===//
+/// A device description attribute is a list of DLTI entries for one device.
+class TargetDeviceDescSpecAttr
+ : public Attribute::AttrBase<TargetDeviceDescSpecAttr, Attribute,
+ impl::TargetDeviceDescSpecAttrStorage,
+ TargetDeviceDescSpecInterface::Trait> {
+public:
+ using Base::Base;
+
+ /// The keyword used for this attribute in custom syntax.
+ constexpr const static StringLiteral kAttrKeyword = "tdd_spec";
+
+ /// Returns a device description attribute built from the given entries.
+ static TargetDeviceDescSpecAttr
+ get(MLIRContext *context, ArrayRef<DataLayoutEntryInterface> entries);
+
+ /// Returns the specification containing the given list of keys. If the list
+ /// contains duplicate keys or is otherwise invalid, reports errors using the
+ /// given callback and returns null.
+ static TargetDeviceDescSpecAttr
+ getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+ ArrayRef<DataLayoutEntryInterface> entries);
+
+ /// Checks that the given list of entries does not contain duplicate keys.
+ static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<DataLayoutEntryInterface> entries);
+
+ /// Returns the list of entries.
+ DataLayoutEntryListRef getEntries() const;
+
+ /// Parses an instance of this attribute.
+ static TargetDeviceDescSpecAttr parse(AsmParser &parser);
+
+ /// Prints this attribute.
+ void print(AsmPrinter &os) const;
+
+ /// Returns the device ID identifier.
+ StringAttr getDeviceIDIdentifier(MLIRContext *context);
+
+ /// Returns the device type identifier.
+ StringAttr getDeviceTypeIdentifier(MLIRContext *context);
+
+ /// Returns the max vector op width identifier.
+ StringAttr getMaxVectorOpWidthIdentifier(MLIRContext *context);
+
+ /// Returns the canonicalizer max iterations identifier.
+ StringAttr getCanonicalizerMaxIterationsIdentifier(MLIRContext *context);
+
+ /// Returns the canonicalizer max num rewrites identifier.
+ StringAttr getCanonicalizerMaxNumRewritesIdentifier(MLIRContext *context);
+
+ /// Returns the entry for the device ID. The device ID is a mandatory key
+ /// and the spec is verified to contain it, so this is expected to return
+ /// a valid entry.
+ DataLayoutEntryInterface getSpecForDeviceID(MLIRContext *context);
+
+ /// Returns the entry for the device type. The device type is a mandatory
+ /// key and the spec is verified to contain it, so this is expected to
+ /// return a valid entry.
+ DataLayoutEntryInterface getSpecForDeviceType(MLIRContext *context);
+
+ /// Returns the interface spec for max vector op width.
+ /// Since max vector op width is an optional property, this function will
+ /// return a valid spec if the property is defined, otherwise it
+ /// will return an empty spec.
+ DataLayoutEntryInterface getSpecForMaxVectorOpWidth(MLIRContext *context);
+
+ /// Returns the interface spec for canonicalizer max iterations.
+ /// Since this is an optional property, this function will
+ /// return a valid spec if the property is defined, otherwise it
+ /// will return an empty spec.
+ DataLayoutEntryInterface
+ getSpecForCanonicalizerMaxIterations(MLIRContext *context);
+
+ /// Returns the interface spec for canonicalizer max num rewrites.
+ /// Since this is an optional property, this function will
+ /// return a valid spec if the property is defined, otherwise it
+ /// will return an empty spec.
+ DataLayoutEntryInterface
+ getSpecForCanonicalizerMaxNumRewrites(MLIRContext *context);
+
+ /// Returns the value of the device ID entry.
+ uint32_t getDeviceID(MLIRContext *context);
+
+ static constexpr StringLiteral name = "builtin.target_device_description";
+};
+
} // namespace mlir
#include "mlir/Dialect/DLTI/DLTIDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
index 3572a99fad874..d9802ef4d4f13 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
@@ -27,6 +27,13 @@ def DLTI_Dialect : Dialect {
constexpr const static ::llvm::StringLiteral
kDataLayoutAttrName = "dlti.dl_spec";
+ // Top level attribute name for target system description
+ constexpr const static ::llvm::StringLiteral
+ kTargetSystemDescAttrName = "dlti.tsd_spec";
+ // Attribute name for target device description
+ constexpr const static ::llvm::StringLiteral
+ kTargetDeviceDescAttrName = "dlti.tdd_spec";
+
// Constants used in entries.
constexpr const static ::llvm::StringLiteral
kDataLayoutEndiannessKey = "dlti.endianness";
@@ -48,6 +55,22 @@ def DLTI_Dialect : Dialect {
constexpr const static ::llvm::StringLiteral
kDataLayoutStackAlignmentKey = "dlti.stack_alignment";
+
+ // Keys used in the entries of target device descriptions.
+ constexpr const static ::llvm::StringLiteral
+ kTargetDeviceIDKey = "dlti.device_id";
+
+ constexpr const static ::llvm::StringLiteral
+ kTargetDeviceTypeKey = "dlti.device_type";
+
+ constexpr const static ::llvm::StringLiteral
+ kTargetDeviceMaxVectorOpWidthKey = "dlti.max_vector_op_width";
+
+ constexpr const static ::llvm::StringLiteral
+ kTargetDeviceCanonicalizerMaxIterationsKey = "dlti.canonicalizer_max_iterations";
+
+ constexpr const static ::llvm::StringLiteral
+ kTargetDeviceCanonicalizerMaxNumRewritesKey = "dlti.canonicalizer_max_num_rewrites";
}];
let useDefaultAttributePrinterParser = 1;
@@ -71,6 +94,24 @@ def DLTI_DataLayoutSpecAttr : DialectAttr<
let convertFromStorage = "$_self";
}
+def DLTI_TargetSystemDescSpecAttr : DialectAttr<
+ DLTI_Dialect,
+ CPred<"::llvm::isa<::mlir::TargetSystemDescSpecAttr>($_self)">,
+ "Target system description part of DLTI"> {
+ let storageType = "::mlir::TargetSystemDescSpecAttr";
+ let returnType = "::mlir::TargetSystemDescSpecAttr";
+ let convertFromStorage = "$_self";
+}
+
+def DLTI_TargetDeviceDescSpecAttr : DialectAttr<
+ DLTI_Dialect,
+ CPred<"::llvm::isa<::mlir::TargetDeviceDescSpecAttr>($_self)">,
+ "Target device description part of DLTI"> {
+ let storageType = "::mlir::TargetDeviceDescSpecAttr";
+ let returnType = "::mlir::TargetDeviceDescSpecAttr";
+ let convertFromStorage = "$_self";
+}
+
def HasDefaultDLTIDataLayout : NativeOpTrait<"HasDefaultDLTIDataLayout"> {
let cppNamespace = "::mlir";
}
diff --git a/mlir/include/mlir/Dialect/DLTI/Traits.h b/mlir/include/mlir/Dialect/DLTI/Traits.h
index 5d86195305a95..44083d54c4cad 100644
--- a/mlir/include/mlir/Dialect/DLTI/Traits.h
+++ b/mlir/include/mlir/Dialect/DLTI/Traits.h
@@ -18,6 +18,7 @@ class DataLayoutSpecAttr;
namespace impl {
LogicalResult verifyHasDefaultDLTIDataLayoutTrait(Operation *op);
DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
+TargetSystemDescSpecInterface getTargetSystemDescSpec(Operation *op);
} // namespace impl
/// Trait to be used by operations willing to use the implementation of the
@@ -37,6 +38,12 @@ class HasDefaultDLTIDataLayout
DataLayoutSpecInterface getDataLayoutSpec() {
return impl::getDataLayoutSpec(this->getOperation());
}
+
+ /// Returns the target system description specification as provided by the
+ /// DLTI dialect; expected to be null when no such spec is attached.
+ TargetSystemDescSpecInterface getTargetSystemDescSpec() {
+ return impl::getTargetSystemDescSpec(this->getOperation());
+ }
};
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index eda24615c71ea..bdb4ce3ddfe20 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -78,6 +78,7 @@ def ModuleOp : Builtin_Op<"module", [
//===------------------------------------------------------------------===//
DataLayoutSpecInterface getDataLayoutSpec();
+ TargetSystemDescSpecInterface getTargetSystemDescSpec();
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 76bf33e89a716..57b8636690afb 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -23,11 +23,17 @@
namespace mlir {
class DataLayout;
class DataLayoutEntryInterface;
+class TargetDeviceDescSpecInterface;
+class TargetSystemDescSpecInterface;
using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
// Using explicit SmallVector size because we cannot infer the size from the
// forward declaration, and we need the typedef in the actual declaration.
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
+// using TargetDeviceDescSpecList =
+// llvm::SmallVector<TargetDeviceDescSpecInterface, 4>;
+using TargetDeviceDescSpecListRef =
+ llvm::ArrayRef<TargetDeviceDescSpecInterface>;
class DataLayoutOpInterface;
class DataLayoutSpecInterface;
class ModuleOp;
@@ -84,6 +90,20 @@ Attribute getDefaultGlobalMemorySpace(DataLayoutEntryInterface entry);
/// DataLayoutInterface if specified, otherwise returns the default.
uint64_t getDefaultStackAlignment(DataLayoutEntryInterface entry);
+/// Returns the max vector op width from the specified DataLayoutEntry. If the
+/// property is missing from the entry, then returns std::nullopt.
+std::optional<uint32_t> getMaxVectorOpWidth(DataLayoutEntryInterface entry);
+
+/// Returns the canonicalizer max iterations from the specified DataLayoutEntry.
+/// If the property is missing from the entry, then returns std::nullopt.
+std::optional<int64_t>
+getCanonicalizerMaxIterations(DataLayoutEntryInterface entry);
+
+/// Returns the canonicalizer max num rewrites from the specified
+/// DataLayoutEntry. If the property is missing, then returns std::nullopt.
+std::optional<int64_t>
+getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry);
+
/// Given a list of data layout entries, returns a new list containing the
/// entries with keys having the given type ID, i.e. belonging to the same type
/// class.
@@ -95,6 +115,11 @@ DataLayoutEntryList filterEntriesForType(DataLayoutEntryListRef entries,
DataLayoutEntryInterface
filterEntryForIdentifier(DataLayoutEntryListRef entries, StringAttr id);
+/// Given a list of target device entries, returns the entry that has the given
+/// identifier as key, if such an entry exists in the list.
+TargetDeviceDescSpecInterface
+filterEntryForIdentifier(TargetDeviceDescSpecListRef entries, StringAttr id);
+
/// Verifies that the operation implementing the data layout interface, or a
/// module operation, is valid. This calls the verifier of the spec attribute
/// and checks if the layout is compatible with specs attached to the enclosing
@@ -106,6 +131,12 @@ LogicalResult verifyDataLayoutOp(Operation *op);
/// and dialect interfaces for type and identifier keys respectively.
LogicalResult verifyDataLayoutSpec(DataLayoutSpecInterface spec, Location loc);
+/// Verifies that a target system description spec is valid. This dispatches to
+/// individual entry verifiers, and then to the verifiers implemented by the
+/// relevant dialect interfaces for identifier keys.
+LogicalResult verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
+ Location loc);
+
/// Divides the known min value of the numerator by the denominator and rounds
/// the result up to the next integer. Preserves the scalable flag.
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator);
@@ -137,6 +168,13 @@ class DataLayoutDialectInterface
return success();
}
+ /// Checks whether the given target device description is valid and reports
+ /// any errors at the provided location. Derived classes should override this.
+ virtual LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
+ Location loc) const {
+ return success();
+ }
+
/// Default implementation of entry combination that combines identical
/// entries and returns null otherwise.
static DataLayoutEntryInterface
@@ -214,10 +252,28 @@ class DataLayout {
/// unspecified.
uint64_t getStackAlignment() const;
+ /// Returns the max vector op width if the property is defined for the given
+ /// device ID, otherwise returns std::nullopt.
+ std::optional<uint32_t>
+ getMaxVectorOpWidth(TargetDeviceDescSpecInterface::DeviceID) const;
+
+ /// Returns the canonicalizer max iterations if the property is defined for
+ /// the given device ID, otherwise returns std::nullopt.
+ std::optional<int64_t> getCanonicalizerMaxIterations(
+ TargetDeviceDescSpecInterface::DeviceID) const;
+
+ /// Returns the canonicalizer max num rewrites if the property is defined for
+ /// the given device ID, otherwise returns std::nullopt.
+ std::optional<int64_t> getCanonicalizerMaxNumRewrites(
+ TargetDeviceDescSpecInterface::DeviceID) const;
+
private:
/// Combined layout spec at the given scope.
const DataLayoutSpecInterface originalLayout;
+ /// Combined target system desc spec at the given scope.
+ const TargetSystemDescSpecInterface originalTargetSystemDesc;
+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// List of enclosing layout specs.
SmallVector<DataLayoutSpecInterface, 2> layoutStack;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 9edc885b9c5a9..318991d65dbdb 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -194,6 +194,175 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
}];
}
+def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface"> {
+ let cppNamespace = "::mlir";
+
+ let description = [{
+ Attribute interface describing a target device description specification.
+
+ A target device description specification is a list of device properties (key)
+ and their values for a specific device. The device is identified using "device_id"
+ (as a key and ui32 value) and "device_type" key which must have a string value.
+ Both "device_id" and "device_type" are mandatory keys. As an example, L1 cache
+ size could be a device property, and its value would be a device specific size.
+
+ A target device description specification is nested inside a target system
+ description specification, which is attached to a module as a module-level attribute.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*description=*/"Returns the list of device property entries.",
+ /*retTy=*/"::mlir::DataLayoutEntryListRef",
+ /*methodName=*/"getEntries",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the entry related to the given identifier, if "
+ "present.",
+ /*retTy=*/"::mlir::DataLayoutEntryInterface",
+ /*methodName=*/"getSpecForIdentifier",
+ /*args=*/(ins "::mlir::StringAttr":$identifier),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::filterEntryForIdentifier($_attr.getEntries(),
+ identifier);
+ }]
+ >,
+ InterfaceMethod<
+ /*description=*/"Checks that the entry is well-formed, reports errors "
+ "at the provided location.",
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"verifyEntry",
+ /*args=*/(ins "::mlir::Location":$loc),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return ::mlir::success(); }]
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the device ID identifier.",
+ /*retTy=*/"::mlir::StringAttr",
+ /*methodName=*/"getDeviceIDIdentifier",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the device type identifier.",
+ /*retTy=*/"::mlir::StringAttr",
+ /*methodName=*/"getDeviceTypeIdentifier",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the max vector op width identifier.",
+ /*retTy=*/"::mlir::StringAttr",
+ /*methodName=*/"getMaxVectorOpWidthIdentifier",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns canonicalizer max iterations identifier.",
+ /*retTy=*/"::mlir::StringAttr",
+ /*methodName=*/"getCanonicalizerMaxIterationsIdentifier",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns canonicalizer max num rewrites identifier.",
+ /*retTy=*/"::mlir::StringAttr",
+ /*methodName=*/"getCanonicalizerMaxNumRewritesIdentifier",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the entry related to the device ID. The "
+ "function will crash if the entry is missing.",
+ /*retTy=*/"::mlir::DataLayoutEntryInterface",
+ /*methodName=*/"getSpecForDeviceID",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the entry related to the device type. "
+ "The function will crash if the entry is missing.",
+ /*retTy=*/"::mlir::DataLayoutEntryInterface",
+ /*methodName=*/"getSpecForDeviceType",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the entry related to max vector op width, "
+ "if present. Otherwise, returns an empty spec.",
+ /*retTy=*/"::mlir::DataLayoutEntryInterface",
+ /*methodName=*/"getSpecForMaxVectorOpWidth",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the entry related to canonicalizer max "
+ "iterations, if present. Otherwise, returns an empty spec.",
+ /*retTy=*/"::mlir::DataLayoutEntryInterface",
+ /*methodName=*/"getSpecForCanonicalizerMaxIterations",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the entry related to canonicalizer max num "
+ "rewrites, if present. Otherwise, returns an empty spec.",
+ /*retTy=*/"::mlir::DataLayoutEntryInterface",
+ /*methodName=*/"getSpecForCanonicalizerMaxNumRewrites",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the value of the device ID entry, which is "
+ "a mandatory property.",
+ /*retTy=*/"uint32_t",
+ /*methodName=*/"getDeviceID",
+ /*args=*/(ins "::mlir::MLIRContext *":$context)
+ >,
+ ];
+
+ let extraClassDeclaration = [{
+ using DeviceID = uint32_t;
+ }];
+}
+
+def TargetSystemDescSpecInterface : AttrInterface<"TargetSystemDescSpecInterface"> {
+ let cppNamespace = "::mlir";
+
+ let description = [{
+ Attribute interface describing a target system description specification.
+
+ A target system description specification is a list of target device
+ specifications, with one device specification for a device in the system. As
+ such, a target system description specification allows specifying a heterogeneous
+ system, with devices of different types (e.g., CPU, GPU, etc.).
+
+ The only requirement on a valid target system description specification is that
+ the "device_id" in every target device description specification needs to be
+ unique. This is because, ultimately, this "device_id" is used by the user to
+ query a value of a device property.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*description=*/"Returns the list of target device description specs.",
+ /*retTy=*/"::mlir::TargetDeviceDescSpecListRef",
+ /*methodName=*/"getEntries",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*description=*/"Returns the device description spec for the given "
+ "device ID.",
+ /*retTy=*/"::mlir::TargetDeviceDescSpecInterface",
+ /*methodName=*/"getDeviceDescForDeviceID",
+ /*args=*/(ins "uint32_t":$deviceID)
+ >,
+ InterfaceMethod<
+ /*description=*/"Verifies the validity of the specification and "
+ "reports any errors at the given "
+ "location.",
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"verifySpec",
+ /*args=*/(ins "::mlir::Location":$loc),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::verifyTargetSystemDescSpec($_attr, loc);
+ }]
+ >
+ ];
+}
+
//===----------------------------------------------------------------------===//
// Operation interface
//===----------------------------------------------------------------------===//
@@ -227,6 +396,14 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
/*methodName=*/"getDataLayoutSpec",
/*args=*/(ins)
>,
+ InterfaceMethod<
+ /*description=*/"Returns the target system description "
+ "specification for this op, or null if it does "
+ "not exist.",
+ /*retTy=*/"::mlir::TargetSystemDescSpecInterface",
+ /*methodName=*/"getTargetSystemDescSpec",
+ /*args=*/(ins)
+ >,
StaticInterfaceMethod<
/*description=*/"Returns the size of the given type computed using the "
"relevant entries. The data layout object can be used "
@@ -362,6 +539,39 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
return ::mlir::detail::getDefaultStackAlignment(entry);
}]
>,
+ StaticInterfaceMethod<
+ /*description=*/"Returns the max vector op width, if the property is "
+ "defined. Otherwise, it returns std::nullopt.",
+ /*retTy=*/"std::optional<uint32_t>",
+ /*methodName=*/"getMaxVectorOpWidth",
+ /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::getMaxVectorOpWidth(entry);
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*description=*/"Returns the canonicalizer max iterations, if the "
+ "property is defined. Otherwise, it returns std::nullopt.",
+ /*retTy=*/"std::optional<int64_t>",
+ /*methodName=*/"getCanonicalizerMaxIterations",
+ /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::getCanonicalizerMaxIterations(entry);
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*description=*/"Returns the canonicalizer max num rewrites, if the "
+ "property is defined. Otherwise, it returns std::nullopt.",
+ /*retTy=*/"std::optional<int64_t>",
+ /*methodName=*/"getCanonicalizerMaxNumRewrites",
+ /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::detail::getCanonicalizerMaxNumRewrites(entry);
+ }]
+ >,
];
let verify = [{ return ::mlir::detail::verifyDataLayoutOp($_op); }];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 033e66c6118f3..8ff2b2104b998 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -29,6 +30,8 @@ namespace mlir {
using namespace mlir;
using namespace mlir::amdgpu;
+#define DEBUG_TYPE "amd-gpu-to-rocdl"
+
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type llvmI32 = rewriter.getI32Type();
@@ -49,7 +52,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
: ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
Chipset chipset;
- static constexpr uint32_t maxVectorOpWidth = 128;
LogicalResult
matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
@@ -111,6 +113,15 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
uint32_t elemBits = dataVector.getElementTypeBitWidth();
uint32_t totalBits = elemBits * dataVector.getNumElements();
+ // Default limit, used unless a target device description overrides it.
+ uint32_t maxVectorOpWidth = 128;
+ // NOTE: assumes the GPU is described with device ID 1 in the enclosing
+ // module's target system description; no enclosing module keeps the default.
+ if (auto moduleOp = gpuOp->template getParentOfType<mlir::ModuleOp>())
+ if (std::optional<uint32_t> v = DataLayout(moduleOp).getMaxVectorOpWidth(1))
+ maxVectorOpWidth = *v;
+ LLVM_DEBUG(llvm::dbgs() << "[CostModel] GPU MaxVectorWidth:"
+ << maxVectorOpWidth << "\n");
if (totalBits > maxVectorOpWidth)
return gpuOp.emitOpError(
"Total width of loads or stores must be no more than " +
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 98a8865ef4da3..4bbffdd6b2f8b 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -10,14 +10,21 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
using namespace mlir;
#include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"
+#define DEBUG_TYPE "dlti"
+
//===----------------------------------------------------------------------===//
// DataLayoutEntryAttr
//===----------------------------------------------------------------------===//
@@ -337,6 +344,307 @@ void DataLayoutSpecAttr::print(AsmPrinter &os) const {
os << ">";
}
+//===----------------------------------------------------------------------===//
+// TargetDeviceDescSpecAttr
+//===----------------------------------------------------------------------===//
+// Out-of-line definitions of the constexpr static data members used by the
+// target description attributes. Before C++17 such members are not implicitly
+// inline, so an ODR-use needs one namespace-scope definition.
+constexpr const StringLiteral mlir::DLTIDialect::kTargetDeviceIDKey;
+constexpr const StringLiteral mlir::DLTIDialect::kTargetDeviceTypeKey;
+constexpr const StringLiteral
+    mlir::DLTIDialect::kTargetDeviceMaxVectorOpWidthKey;
+constexpr const StringLiteral
+    mlir::DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey;
+constexpr const StringLiteral
+    mlir::DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey;
+
+constexpr const StringLiteral mlir::TargetDeviceDescSpecAttr::kAttrKeyword;
+
+namespace mlir {
+namespace impl {
+/// Uniqued storage for TargetDeviceDescSpecAttr: the list of DLTI entries that
+/// describe one target device. The key array is copied into the storage
+/// allocator (copyInto) so the storage never aliases the caller's array.
+class TargetDeviceDescSpecAttrStorage : public AttributeStorage {
+public:
+  using KeyTy = ArrayRef<DataLayoutEntryInterface>;
+
+  TargetDeviceDescSpecAttrStorage(KeyTy entries) : entries(entries) {}
+
+  /// Two specs are the same attribute iff their entry lists compare equal.
+  bool operator==(const KeyTy &key) const { return entries == key; }
+
+  static TargetDeviceDescSpecAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    KeyTy ownedEntries = allocator.copyInto(key);
+    auto *memory = allocator.allocate<TargetDeviceDescSpecAttrStorage>();
+    return new (memory) TargetDeviceDescSpecAttrStorage(ownedEntries);
+  }
+
+  /// Entries in the order they were provided.
+  ArrayRef<DataLayoutEntryInterface> entries;
+};
+} // namespace impl
+} // namespace mlir
+
+/// Returns the uniqued device description spec holding `entries` in `context`.
+TargetDeviceDescSpecAttr
+TargetDeviceDescSpecAttr::get(MLIRContext *context,
+                              ArrayRef<DataLayoutEntryInterface> entries) {
+  return Base::get(context, entries);
+}
+
+/// Returns the entries of this spec, in insertion order.
+DataLayoutEntryListRef TargetDeviceDescSpecAttr::getEntries() const {
+  auto *storage = getImpl();
+  return storage->entries;
+}
+
+/// Like get(), but runs verify() first and reports problems through
+/// `emitError`, returning a null attribute on failure.
+TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, MLIRContext *ctx,
+    ArrayRef<DataLayoutEntryInterface> deviceEntries) {
+  return Base::getChecked(emitError, ctx, deviceEntries);
+}
+
+/// Verifies a target device description spec:
+///  - every key must be a StringAttr (Type keys are rejected), and keys must
+///    be unique;
+///  - `dlti.device_id` must be present with a ui32 IntegerAttr value;
+///  - `dlti.device_type` must be present with a StringAttr value.
+/// DataLayoutEntryInterface::verify is not reused because it accepts Type keys.
+LogicalResult
+TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<DataLayoutEntryInterface> entries) {
+  bool targetDeviceIDKeyPresentAndValid = false;
+  bool targetDeviceTypeKeyPresentAndValid = false;
+
+  DenseSet<StringAttr> ids;
+  for (DataLayoutEntryInterface entry : entries) {
+    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
+      return emitError()
+             << "dlti.tdd_spec attribute does not allow type as a key: "
+             << type;
+
+    auto id = entry.getKey().get<StringAttr>();
+    if (!ids.insert(id).second)
+      return emitError() << "repeated layout entry key: " << id.getValue();
+
+    // Record whether the mandatory keys carry a value of the expected kind.
+    // A null value is treated as invalid for both keys; plain dyn_cast would
+    // assert on it.
+    StringRef entryName = id.strref();
+    Attribute value = entry.getValue();
+    if (entryName == DLTIDialect::kTargetDeviceIDKey) {
+      auto intValue = llvm::dyn_cast_if_present<IntegerAttr>(value);
+      if (intValue && intValue.getType().isUnsignedInteger(32))
+        targetDeviceIDKeyPresentAndValid = true;
+    } else if (entryName == DLTIDialect::kTargetDeviceTypeKey) {
+      if (llvm::isa_and_present<StringAttr>(value))
+        targetDeviceTypeKeyPresentAndValid = true;
+    }
+  }
+
+  // Both DeviceID and DeviceType must be present and of the correct type.
+  if (!targetDeviceIDKeyPresentAndValid)
+    return emitError() << "tdd_spec requires key: "
+                       << DLTIDialect::kTargetDeviceIDKey
+                       << " and its value of ui32 type";
+  if (!targetDeviceTypeKeyPresentAndValid)
+    return emitError() << "tdd_spec requires key: "
+                       << DLTIDialect::kTargetDeviceTypeKey
+                       << " and its value of string type";
+
+  return success();
+}
+
+/// Parses an attribute with syntax
+///   tdd-spec-attr      ::= `#dlti.` `tdd_spec` `<` dl-entry-attr-list? `>`
+///   dl-entry-attr-list ::= dl-entry-attr
+///                        | dl-entry-attr `,` dl-entry-attr-list
+/// The result is always built through getChecked so that verify() runs. This
+/// includes the empty form `#dlti.tdd_spec<>`, which is rejected because the
+/// device ID and device type keys are mandatory.
+TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::parse(AsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  auto emitParseError = [&] { return parser.emitError(parser.getNameLoc()); };
+
+  // Empty spec: still verified rather than built with the unchecked get().
+  if (succeeded(parser.parseOptionalGreater()))
+    return getChecked(emitParseError, parser.getContext(), {});
+
+  SmallVector<DataLayoutEntryInterface> entries;
+  if (parser.parseCommaSeparatedList(
+          [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
+      parser.parseGreater())
+    return {};
+
+  return getChecked(emitParseError, parser.getContext(), entries);
+}
+
+/// Prints `tdd_spec<entry, entry, ...>`; the dialect prefix is printed by the
+/// caller.
+void TargetDeviceDescSpecAttr::print(AsmPrinter &os) const {
+  os << kAttrKeyword << "<";
+  llvm::interleave(
+      getEntries(), os,
+      [&](DataLayoutEntryInterface entry) { os << entry; }, ", ");
+  os << ">";
+}
+
+// ---------------------------------------------------------------------------//
+// Support for specific keys
+// ---------------------------------------------------------------------------//
+
+/// Interned StringAttr for the `dlti.device_id` key.
+StringAttr
+TargetDeviceDescSpecAttr::getDeviceIDIdentifier(MLIRContext *context) {
+  return StringAttr::get(context, DLTIDialect::kTargetDeviceIDKey);
+}
+
+/// Interned StringAttr for the device-type key.
+StringAttr
+TargetDeviceDescSpecAttr::getDeviceTypeIdentifier(MLIRContext *context) {
+  return StringAttr::get(context, DLTIDialect::kTargetDeviceTypeKey);
+}
+
+/// Interned StringAttr for the max-vector-op-width key.
+StringAttr
+TargetDeviceDescSpecAttr::getMaxVectorOpWidthIdentifier(MLIRContext *context) {
+  return StringAttr::get(context,
+                         DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
+}
+
+/// Interned StringAttr for the canonicalizer max-iterations key.
+StringAttr TargetDeviceDescSpecAttr::getCanonicalizerMaxIterationsIdentifier(
+    MLIRContext *context) {
+  return StringAttr::get(
+      context, DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey);
+}
+
+/// Interned StringAttr for the canonicalizer max-num-rewrites key.
+StringAttr TargetDeviceDescSpecAttr::getCanonicalizerMaxNumRewritesIdentifier(
+    MLIRContext *context) {
+  return StringAttr::get(
+      context, DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey);
+}
+
+/// Entry for the device-ID key, or a null entry if absent.
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForDeviceID(MLIRContext *context) {
+  StringAttr key = getDeviceIDIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+/// Entry for the device-type key, or a null entry if absent.
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForDeviceType(MLIRContext *context) {
+  StringAttr key = getDeviceTypeIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+/// Entry for the max-vector-op-width key, or a null entry if absent.
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForMaxVectorOpWidth(MLIRContext *context) {
+  StringAttr key = getMaxVectorOpWidthIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+/// Entry for the canonicalizer max-iterations key, or a null entry if absent.
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForCanonicalizerMaxIterations(
+    MLIRContext *context) {
+  StringAttr key = getCanonicalizerMaxIterationsIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+/// Entry for the canonicalizer max-num-rewrites key, or a null entry if absent.
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForCanonicalizerMaxNumRewrites(
+    MLIRContext *context) {
+  StringAttr key = getCanonicalizerMaxNumRewritesIdentifier(context);
+  return getSpecForIdentifier(key);
+}
+
+/// Returns the device ID of this spec. For a verified spec the device-ID entry
+/// exists and holds a ui32 IntegerAttr, so the cast below cannot fail.
+uint32_t TargetDeviceDescSpecAttr::getDeviceID(MLIRContext *context) {
+  Attribute rawValue = getSpecForDeviceID(context).getValue();
+  APInt idBits = llvm::cast<IntegerAttr>(rawValue).getValue();
+  return idBits.getZExtValue();
+}
+
+//===----------------------------------------------------------------------===//
+// TargetSystemDescSpecAttr
+//===----------------------------------------------------------------------===//
+
+// Out-of-line definition of the attribute keyword (needed for ODR-uses before
+// C++17, where constexpr static members are not implicitly inline).
+constexpr const StringLiteral mlir::TargetSystemDescSpecAttr::kAttrKeyword;
+
+namespace mlir {
+namespace impl {
+/// Uniqued storage for TargetSystemDescSpecAttr: one device description per
+/// element. The key array is copied into the storage allocator.
+class TargetSystemDescSpecAttrStorage : public AttributeStorage {
+public:
+  using KeyTy = ArrayRef<TargetDeviceDescSpecInterface>;
+
+  TargetSystemDescSpecAttrStorage(KeyTy entries) : entries(entries) {}
+
+  /// Two specs are the same attribute iff their device lists compare equal.
+  bool operator==(const KeyTy &key) const { return entries == key; }
+
+  static TargetSystemDescSpecAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    KeyTy ownedEntries = allocator.copyInto(key);
+    auto *memory = allocator.allocate<TargetSystemDescSpecAttrStorage>();
+    return new (memory) TargetSystemDescSpecAttrStorage(ownedEntries);
+  }
+
+  /// Device descriptions in insertion order. Lookups by device ID scan this
+  /// list linearly; a DeviceID -> description map would make them faster.
+  ArrayRef<TargetDeviceDescSpecInterface> entries;
+};
+} // namespace impl
+} // namespace mlir
+
+/// Returns the uniqued system description spec holding `devices` in `ctx`.
+TargetSystemDescSpecAttr
+TargetSystemDescSpecAttr::get(MLIRContext *ctx,
+                              ArrayRef<TargetDeviceDescSpecInterface> devices) {
+  return Base::get(ctx, devices);
+}
+
+/// Returns the device descriptions of this system, in insertion order.
+TargetDeviceDescSpecListRef TargetSystemDescSpecAttr::getEntries() const {
+  auto *storage = getImpl();
+  return storage->entries;
+}
+
+/// Returns the first device description whose device ID equals `deviceID`,
+/// or a null interface if no device carries that ID. Callers must check the
+/// result before using it.
+TargetDeviceDescSpecInterface
+TargetSystemDescSpecAttr::getDeviceDescForDeviceID(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) {
+  MLIRContext *context = getContext();
+  TargetDeviceDescSpecListRef devices = getEntries();
+  auto it = llvm::find_if(devices, [&](TargetDeviceDescSpecInterface device) {
+    return device.getDeviceID(context) == deviceID;
+  });
+  return it == devices.end() ? TargetDeviceDescSpecInterface() : *it;
+}
+
+/// Like get(), but runs verify() first and reports problems through
+/// `emitError`, returning a null attribute on failure.
+TargetSystemDescSpecAttr TargetSystemDescSpecAttr::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, MLIRContext *ctx,
+    ArrayRef<TargetDeviceDescSpecInterface> devices) {
+  return Base::getChecked(emitError, ctx, devices);
+}
+
+/// Verifies a target system description: each device description must pass
+/// TargetDeviceDescSpecAttr::verify, and device IDs must be unique.
+LogicalResult TargetSystemDescSpecAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<TargetDeviceDescSpecInterface> entries) {
+  DenseSet<uint32_t> seenDeviceIDs;
+
+  for (TargetDeviceDescSpecInterface deviceSpec : entries) {
+    LogicalResult deviceOk =
+        TargetDeviceDescSpecAttr::verify(emitError, deviceSpec.getEntries());
+    if (failed(deviceOk))
+      return failure();
+
+    uint32_t deviceID = deviceSpec.getDeviceID(deviceSpec.getContext());
+    bool isNewID = seenDeviceIDs.insert(deviceID).second;
+    if (!isNewID)
+      return emitError() << "repeated Device ID in dlti.tsd_spec: "
+                         << deviceID;
+  }
+  return success();
+}
+
+/// Parses an attribute with syntax
+///   tsd-spec-attr       ::= `#dlti.` `tsd_spec` `<` tdd-spec-attr-list? `>`
+///   tdd-spec-attr-list  ::= tdd-spec-attr
+///                         | tdd-spec-attr `,` tdd-spec-attr-list
+/// An empty system description (`#dlti.tsd_spec<>`) is accepted as-is.
+TargetSystemDescSpecAttr TargetSystemDescSpecAttr::parse(AsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  // Empty spec.
+  if (succeeded(parser.parseOptionalGreater()))
+    return get(parser.getContext(), {});
+
+  SmallVector<TargetDeviceDescSpecInterface> entries;
+  if (parser.parseCommaSeparatedList(
+          [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
+      parser.parseGreater())
+    return {};
+
+  return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
+                    parser.getContext(), entries);
+}
+
+/// Prints `tsd_spec<tdd_spec, tdd_spec, ...>`; the dialect prefix is printed
+/// by the caller.
+void TargetSystemDescSpecAttr::print(AsmPrinter &os) const {
+  os << kAttrKeyword << "<";
+  llvm::interleave(
+      getEntries(), os,
+      [&](TargetDeviceDescSpecInterface device) { os << device; }, ", ");
+  os << ">";
+}
+
//===----------------------------------------------------------------------===//
// DLTIDialect
//===----------------------------------------------------------------------===//
@@ -375,9 +683,35 @@ class TargetDataLayoutInterface : public DataLayoutDialectInterface {
};
} // namespace
+namespace {
+/// Dialect interface that checks the keys used inside a target device
+/// description spec.
+/// NOTE(review): this derives from DataLayoutDialectInterface, as does
+/// TargetDataLayoutInterface; if the dialect keeps only one implementation per
+/// interface ID, this one may never be consulted. Confirm it is reachable.
+class SystemDescSpecInterface : public DataLayoutDialectInterface {
+public:
+  using DataLayoutDialectInterface::DataLayoutDialectInterface;
+
+  /// Fails on the first entry of `entry` whose key is not one of the target
+  /// description keys defined by DLTIDialect. Unknown keys are rejected.
+  LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
+                            Location loc) const final {
+    const StringRef knownKeys[] = {
+        DLTIDialect::kTargetDeviceIDKey,
+        DLTIDialect::kTargetDeviceTypeKey,
+        DLTIDialect::kTargetDeviceMaxVectorOpWidthKey,
+        DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey,
+        DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey,
+    };
+
+    for (DataLayoutEntryInterface dlEntry : entry.getEntries()) {
+      StringRef keyName = dlEntry.getKey().get<StringAttr>().strref();
+      if (!llvm::is_contained(knownKeys, keyName))
+        return emitError(loc) << "unknown target desc key name: " << keyName;
+    }
+    return success();
+  }
+};
+} // namespace
+
void DLTIDialect::initialize() {
- addAttributes<DataLayoutEntryAttr, DataLayoutSpecAttr>();
- addInterfaces<TargetDataLayoutInterface>();
+ // Register the data layout attributes together with the target system and
+ // target device description specs.
+ addAttributes<DataLayoutEntryAttr, DataLayoutSpecAttr,
+ TargetSystemDescSpecAttr, TargetDeviceDescSpecAttr>();
+ // NOTE(review): both interfaces below derive from DataLayoutDialectInterface
+ // and presumably share one interface ID; if only the first registration per
+ // ID is kept, SystemDescSpecInterface is silently dropped. Confirm.
+ addInterfaces<TargetDataLayoutInterface, SystemDescSpecInterface>();
}
Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
@@ -390,6 +724,10 @@ Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
return DataLayoutEntryAttr::parse(parser);
if (attrKind == DataLayoutSpecAttr::kAttrKeyword)
return DataLayoutSpecAttr::parse(parser);
+ if (attrKind == TargetSystemDescSpecAttr::kAttrKeyword)
+ return TargetSystemDescSpecAttr::parse(parser);
+ if (attrKind == TargetDeviceDescSpecAttr::kAttrKeyword)
+ return TargetDeviceDescSpecAttr::parse(parser);
parser.emitError(parser.getNameLoc(), "unknown attrribute type: ")
<< attrKind;
@@ -398,8 +736,8 @@ Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
void DLTIDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
llvm::TypeSwitch<Attribute>(attr)
- .Case<DataLayoutEntryAttr, DataLayoutSpecAttr>(
- [&](auto a) { a.print(os); })
+ // Every attribute kind listed here prints itself; any other kind reaching
+ // this printer is treated as unreachable.
+ .Case<DataLayoutEntryAttr, DataLayoutSpecAttr, TargetSystemDescSpecAttr,
+ TargetDeviceDescSpecAttr>([&](auto a) { a.print(os); })
.Default([](Attribute) { llvm_unreachable("unknown attribute kind"); });
}
@@ -413,6 +751,13 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
if (isa<ModuleOp>(op))
return detail::verifyDataLayoutOp(op);
return success();
+ } else if (attr.getName() == DLTIDialect::kTargetSystemDescAttrName) {
+ if (!llvm::isa<TargetSystemDescSpecAttr>(attr.getValue())) {
+ return op->emitError()
+ << "'" << DLTIDialect::kTargetSystemDescAttrName
+ << "' is expected to be a #dlti.tsd_spec attribute";
+ }
+ return success();
}
return op->emitError() << "attribute '" << attr.getName().getValue()
diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp
index 85acbee46defd..ead656774a27c 100644
--- a/mlir/lib/Dialect/DLTI/Traits.cpp
+++ b/mlir/lib/Dialect/DLTI/Traits.cpp
@@ -27,3 +27,9 @@ DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) {
return op->getAttrOfType<DataLayoutSpecAttr>(
DLTIDialect::kDataLayoutAttrName);
}
+
+/// Returns the TargetSystemDescSpecAttr stored on `op` under the DLTI
+/// target-system-description attribute name, or a null interface if that
+/// attribute is missing or has a different kind.
+TargetSystemDescSpecInterface
+mlir::impl::getTargetSystemDescSpec(Operation *op) {
+  auto spec = op->getAttrOfType<TargetSystemDescSpecAttr>(
+      DLTIDialect::kTargetSystemDescAttrName);
+  return spec;
+}
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index dcb1119fe5207..1d57e0bdef187 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -155,6 +155,17 @@ DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
return {};
}
+/// Returns the first attribute on this module whose value implements
+/// TargetSystemDescSpecInterface, or a null interface if there is none. The
+/// scan is linear in the number of attributes; DataLayout performs it once per
+/// construction and reuses the result for its queries.
+TargetSystemDescSpecInterface ModuleOp::getTargetSystemDescSpec() {
+  for (NamedAttribute namedAttr : getOperation()->getAttrs()) {
+    auto spec =
+        llvm::dyn_cast<TargetSystemDescSpecInterface>(namedAttr.getValue());
+    if (spec)
+      return spec;
+  }
+  return {};
+}
+
LogicalResult ModuleOp::verify() {
// Check that none of the attributes are non-dialect attributes, except for
// the symbol related attributes.
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 15cfb3dbaf745..22920ad8e8bb4 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -293,6 +293,39 @@ mlir::detail::getDefaultStackAlignment(DataLayoutEntryInterface entry) {
return value.getValue().getZExtValue();
}
+// Reads the max vector op width from `entry`. A null entry means the property
+// was not specified, which is reported as std::nullopt. The value is expected
+// to be an IntegerAttr and is read zero-extended.
+std::optional<uint32_t>
+mlir::detail::getMaxVectorOpWidth(DataLayoutEntryInterface entry) {
+  if (!entry)
+    return std::nullopt;
+
+  APInt width = cast<IntegerAttr>(entry.getValue()).getValue();
+  return static_cast<uint32_t>(width.getZExtValue());
+}
+
+// Reads the canonicalizer max-iterations bound from `entry`. A null entry
+// means the property was not specified, which is reported as std::nullopt.
+// The value is expected to be an IntegerAttr and is read sign-extended.
+std::optional<int64_t>
+mlir::detail::getCanonicalizerMaxIterations(DataLayoutEntryInterface entry) {
+  if (!entry)
+    return std::nullopt;
+
+  APInt iterations = cast<IntegerAttr>(entry.getValue()).getValue();
+  return iterations.getSExtValue();
+}
+
+// Reads the canonicalizer max-num-rewrites bound from `entry`. A null entry
+// means the property was not specified, which is reported as std::nullopt.
+// The value is expected to be an IntegerAttr and is read sign-extended.
+std::optional<int64_t>
+mlir::detail::getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry) {
+  if (!entry)
+    return std::nullopt;
+
+  APInt rewrites = cast<IntegerAttr>(entry.getValue()).getValue();
+  return rewrites.getSExtValue();
+}
+
DataLayoutEntryList
mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
TypeID typeID) {
@@ -324,6 +357,28 @@ static DataLayoutSpecInterface getSpec(Operation *operation) {
});
}
+/// Returns the target system description attached to `operation` when it is a
+/// module, or to its closest enclosing module otherwise. Returns a null
+/// interface when `operation` is null, when it has no enclosing module, or
+/// when the module carries no target system description.
+static TargetSystemDescSpecInterface
+getTargetSystemDescSpec(Operation *operation) {
+  if (!operation)
+    return {};
+
+  auto moduleOp = llvm::dyn_cast<ModuleOp>(operation);
+  if (!moduleOp)
+    moduleOp = operation->getParentOfType<ModuleOp>();
+  // A detached op, or one whose ancestors contain no module, has no spec to
+  // offer; calling getTargetSystemDescSpec() on a null ModuleOp would crash.
+  if (!moduleOp)
+    return {};
+
+  return moduleOp.getTargetSystemDescSpec();
+}
+
/// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that
/// are either modules or implement the `DataLayoutOpInterface`.
static void
@@ -435,7 +490,8 @@ mlir::DataLayout::DataLayout() : DataLayout(ModuleOp()) {}
mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
: originalLayout(getCombinedDataLayout(op)), scope(op),
allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
- globalMemorySpace(std::nullopt), stackAlignment(std::nullopt) {
+ globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
+ originalTargetSystemDesc(getTargetSystemDescSpec(op)) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
checkMissingLayout(originalLayout, op);
collectParentLayouts(op, layoutStack);
@@ -445,7 +501,8 @@ mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
mlir::DataLayout::DataLayout(ModuleOp op)
: originalLayout(getCombinedDataLayout(op)), scope(op),
allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
- globalMemorySpace(std::nullopt), stackAlignment(std::nullopt) {
+ globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
+ originalTargetSystemDesc(getTargetSystemDescSpec(op)) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
checkMissingLayout(originalLayout, op);
collectParentLayouts(op, layoutStack);
@@ -640,6 +697,60 @@ uint64_t mlir::DataLayout::getStackAlignment() const {
return *stackAlignment;
}
+/// Returns the max vector op width declared for device `deviceID` in the
+/// module's target system description, if any.
+std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  // `entry` stays null when there is no system description, no device with
+  // this ID, or no such key on the device. getDeviceDescForDeviceID returns a
+  // null interface for an unknown ID, so it must be checked before use.
+  DataLayoutEntryInterface entry;
+  if (originalTargetSystemDesc) {
+    if (TargetDeviceDescSpecInterface deviceDesc =
+            originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID))
+      entry = deviceDesc.getSpecForMaxVectorOpWidth(
+          originalTargetSystemDesc.getContext());
+  }
+  // Results are not cached: no default is substituted here, so a missing
+  // property is returned as std::nullopt and callers choose their fallback.
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getMaxVectorOpWidth(entry);
+  return detail::getMaxVectorOpWidth(entry);
+}
+
+/// Returns the canonicalizer max-iterations bound declared for device
+/// `deviceID` in the module's target system description, if any.
+std::optional<int64_t> mlir::DataLayout::getCanonicalizerMaxIterations(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  // `entry` stays null when there is no system description, no device with
+  // this ID, or no such key on the device. getDeviceDescForDeviceID returns a
+  // null interface for an unknown ID, so it must be checked before use.
+  DataLayoutEntryInterface entry;
+  if (originalTargetSystemDesc) {
+    if (TargetDeviceDescSpecInterface deviceDesc =
+            originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID))
+      entry = deviceDesc.getSpecForCanonicalizerMaxIterations(
+          originalTargetSystemDesc.getContext());
+  }
+  // Results are not cached: no default is substituted here, so a missing
+  // property is returned as std::nullopt and callers choose their fallback.
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getCanonicalizerMaxIterations(entry);
+  return detail::getCanonicalizerMaxIterations(entry);
+}
+
+/// Returns the canonicalizer max-num-rewrites bound declared for device
+/// `deviceID` in the module's target system description, if any.
+std::optional<int64_t> mlir::DataLayout::getCanonicalizerMaxNumRewrites(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  // `entry` stays null when there is no system description, no device with
+  // this ID, or no such key on the device. getDeviceDescForDeviceID returns a
+  // null interface for an unknown ID, so it must be checked before use.
+  DataLayoutEntryInterface entry;
+  if (originalTargetSystemDesc) {
+    if (TargetDeviceDescSpecInterface deviceDesc =
+            originalTargetSystemDesc.getDeviceDescForDeviceID(deviceID))
+      entry = deviceDesc.getSpecForCanonicalizerMaxNumRewrites(
+          originalTargetSystemDesc.getContext());
+  }
+  // Results are not cached: no default is substituted here, so a missing
+  // property is returned as std::nullopt and callers choose their fallback.
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getCanonicalizerMaxNumRewrites(entry);
+  return detail::getCanonicalizerMaxNumRewrites(entry);
+}
+
//===----------------------------------------------------------------------===//
// DataLayoutSpecInterface
//===----------------------------------------------------------------------===//
@@ -744,6 +855,55 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
return success();
}
+/// Verifies a target system description spec attached at `loc`:
+///  - each device description passes its own verifyEntry();
+///  - device IDs are unique across the system;
+///  - no entry uses a Type as its key;
+///  - every entry whose key names a loaded dialect is accepted by that
+///    dialect's DataLayoutDialectInterface. Keys of unloaded dialects are
+///    skipped, since that dialect may still support them.
+/// Every entry of every device is checked, in source order, and each failure
+/// emits a diagnostic.
+LogicalResult
+mlir::detail::verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
+                                         Location loc) {
+  DenseSet<uint32_t> deviceIDs;
+  for (TargetDeviceDescSpecInterface tddSpec : spec.getEntries()) {
+    // First, verify the individual target device desc spec.
+    if (failed(tddSpec.verifyEntry(loc)))
+      return failure();
+
+    // Device IDs must be unique across all entries.
+    MLIRContext *context = tddSpec.getContext();
+    uint32_t deviceID = tddSpec.getDeviceID(context);
+    if (!deviceIDs.insert(deviceID).second)
+      return emitError(loc)
+             << "repeated Device ID in dlti.tsd_spec: " << deviceID;
+
+    // Verify each entry individually. Different devices may legitimately use
+    // the same key with different values, so entries are not deduplicated by
+    // key (that would verify only one device's value).
+    for (DataLayoutEntryInterface entry : tddSpec.getEntries()) {
+      if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
+        return emitError(loc)
+               << "dlti.tdd_spec attribute does not allow type as a key: "
+               << type;
+
+      StringAttr identifier = entry.getKey().get<StringAttr>();
+      Dialect *dialect = identifier.getReferencedDialect();
+      if (!dialect)
+        continue;
+
+      const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
+      if (!iface)
+        return emitError(loc)
+               << "the '" << dialect->getNamespace()
+               << "' dialect does not support identifier data layout entries";
+      if (failed(iface->verifyEntry(entry, loc)))
+        return failure();
+    }
+  }
+
+  return success();
+}
+
#include "mlir/Interfaces/DataLayoutAttrInterface.cpp.inc"
#include "mlir/Interfaces/DataLayoutOpInterface.cpp.inc"
#include "mlir/Interfaces/DataLayoutTypeInterface.cpp.inc"
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee5..2948804b8f92a 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -23,6 +23,8 @@ namespace mlir {
using namespace mlir;
+#define DEBUG_TYPE "canonicalizer"
+
namespace {
/// Canonicalize operations in nested regions.
struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
@@ -48,6 +50,20 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
config.maxIterations = maxIterations;
config.maxNumRewrites = maxNumRewrites;
+ LLVM_DEBUG(llvm::dbgs()
+ << "[CostModel] Canonicalizer MaxIterations (default):"
+ << config.maxIterations << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "[CostModel] Canonicalizer MaxNumRewrites (default):"
+ << config.maxNumRewrites << "\n");
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "[CostModel] Canonicalizer MaxIterations (default):"
+ << config.maxIterations << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "[CostModel] Canonicalizer MaxNumRewrites (default):"
+ << config.maxNumRewrites << "\n");
+
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
@@ -59,6 +75,41 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
return success();
}
void runOnOperation() override {
+ Operation *op = getOperation();
+ uint32_t cpuID = 0;
+
+ if (isa<ModuleOp>(op)) {
+ if (std::optional<int64_t> v =
+ DataLayout(llvm::dyn_cast<ModuleOp>(*op))
+ .getCanonicalizerMaxIterations(cpuID)) {
+ config.maxIterations = *v;
+ }
+ } else {
+ ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
+ if (std::optional<int64_t> v =
+ DataLayout(moduleOp).getCanonicalizerMaxIterations(cpuID)) {
+ config.maxIterations = *v;
+ }
+ }
+
+ if (isa<ModuleOp>(op)) {
+ if (std::optional<int64_t> v =
+ DataLayout(llvm::dyn_cast<ModuleOp>(*op))
+ .getCanonicalizerMaxNumRewrites(cpuID)) {
+ config.maxNumRewrites = *v;
+ }
+ } else {
+ ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
+ if (std::optional<int64_t> v =
+ DataLayout(moduleOp).getCanonicalizerMaxNumRewrites(cpuID)) {
+ config.maxNumRewrites = *v;
+ }
+ }
+ LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxIterations (new):"
+ << config.maxIterations << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxNumRewrites (new):"
+ << config.maxNumRewrites << "\n");
+
LogicalResult converged =
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
// Canonicalization is best-effort. Non-convergence is not a pass failure.
More information about the Mlir-commits
mailing list