[Mlir-commits] [mlir] Reimplementing target description concept using DLTI attribute (PR #92138)

Niranjan Hasabnis llvmlistbot at llvm.org
Tue May 14 21:51:06 PDT 2024


https://github.com/nhasabni updated https://github.com/llvm/llvm-project/pull/92138

>From 029f35dccb2fcf252a2a717565725098ee0d1c4f Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Tue, 14 May 2024 07:53:49 -0700
Subject: [PATCH] Reimplementing target description concept using DLTI
 attribute

---
 mlir/include/mlir/Dialect/DLTI/DLTI.h         | 146 +++++++
 mlir/include/mlir/Dialect/DLTI/DLTIBase.td    |  44 +++
 mlir/include/mlir/Dialect/DLTI/Traits.h       |   7 +
 mlir/include/mlir/IR/BuiltinOps.td            |   1 +
 .../mlir/Interfaces/DataLayoutInterfaces.h    |  65 ++++
 .../mlir/Interfaces/DataLayoutInterfaces.td   | 228 +++++++++++
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  13 +-
 mlir/lib/Dialect/DLTI/DLTI.cpp                | 367 +++++++++++++++++-
 mlir/lib/Dialect/DLTI/Traits.cpp              |   6 +
 .../Linalg/Transforms/BlockPackMatmul.cpp     |  22 +-
 mlir/lib/IR/BuiltinDialect.cpp                |  11 +
 mlir/lib/Interfaces/DataLayoutInterfaces.cpp  | 185 ++++++++-
 mlir/lib/Transforms/Canonicalizer.cpp         |  51 +++
 .../Interfaces/DataLayoutInterfacesTest.cpp   |  10 +
 14 files changed, 1148 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/DLTI/DLTI.h b/mlir/include/mlir/Dialect/DLTI/DLTI.h
index 5ac7c11e6ffee..f78e8bdc5eb98 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTI.h
+++ b/mlir/include/mlir/Dialect/DLTI/DLTI.h
@@ -21,6 +21,8 @@ namespace mlir {
 namespace impl {
 class DataLayoutEntryStorage;
 class DataLayoutSpecStorage;
+class TargetSystemDescSpecAttrStorage;
+class TargetDeviceDescSpecAttrStorage;
 } // namespace impl
 
 //===----------------------------------------------------------------------===//
@@ -124,6 +126,150 @@ class DataLayoutSpecAttr
   static constexpr StringLiteral name = "builtin.data_layout_spec";
 };
 
+//===----------------------------------------------------------------------===//
+// TargetSystemDescSpecAttr
+//===----------------------------------------------------------------------===//
+
+/// A system description attribute is a list of device descriptors, each
+/// having a unique device ID.
+class TargetSystemDescSpecAttr
+    : public Attribute::AttrBase<TargetSystemDescSpecAttr, Attribute,
+                                 impl::TargetSystemDescSpecAttrStorage,
+                                 TargetSystemDescSpecInterface::Trait> {
+public:
+  using Base::Base;
+
+  /// The keyword used for this attribute in custom syntax.
+  constexpr const static StringLiteral kAttrKeyword = "tsd_spec";
+
+  /// Returns a system descriptor attribute holding the given device specs.
+  static TargetSystemDescSpecAttr
+  get(MLIRContext *context, ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+  /// Returns the list of device descriptors.
+  TargetDeviceDescSpecListRef getEntries() const;
+
+  /// Returns the device descriptor that matches the given device ID.
+  TargetDeviceDescSpecInterface getDeviceDescForDeviceID(uint32_t deviceID);
+
+  /// Returns the system description containing the given device descriptors.
+  /// If the list contains duplicate keys or is otherwise invalid, reports
+  /// errors using the given callback and returns null.
+  static TargetSystemDescSpecAttr
+  getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+             ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+  /// Checks that the given list of entries does not contain duplicate keys.
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<TargetDeviceDescSpecInterface> entries);
+
+  /// Parses an instance of this attribute.
+  static TargetSystemDescSpecAttr parse(AsmParser &parser);
+
+  /// Prints this attribute.
+  void print(AsmPrinter &os) const;
+
+  static constexpr StringLiteral name = "builtin.target_system_description";
+};
+
+//===----------------------------------------------------------------------===//
+// TargetDeviceDescSpecAttr
+//===----------------------------------------------------------------------===//
+/// A device description attribute is a list of device property entries.
+class TargetDeviceDescSpecAttr
+    : public Attribute::AttrBase<TargetDeviceDescSpecAttr, Attribute,
+                                 impl::TargetDeviceDescSpecAttrStorage,
+                                 TargetDeviceDescSpecInterface::Trait> {
+public:
+  using Base::Base;
+
+  /// The keyword used for this attribute in custom syntax.
+  constexpr const static StringLiteral kAttrKeyword = "tdd_spec";
+
+  /// Returns a device descriptor attribute holding the given entries.
+  static TargetDeviceDescSpecAttr
+  get(MLIRContext *context, ArrayRef<DataLayoutEntryInterface> entries);
+
+  /// Returns the specification containing the given list of keys. If the list
+  /// contains duplicate keys or is otherwise invalid, reports errors using the
+  /// given callback and returns null.
+  static TargetDeviceDescSpecAttr
+  getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+             ArrayRef<DataLayoutEntryInterface> entries);
+
+  /// Checks that the given list of entries does not contain duplicate keys.
+  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<DataLayoutEntryInterface> entries);
+
+  /// Returns the list of device property entries.
+  DataLayoutEntryListRef getEntries() const;
+
+  /// Parses an instance of this attribute.
+  static TargetDeviceDescSpecAttr parse(AsmParser &parser);
+
+  /// Prints this attribute.
+  void print(AsmPrinter &os) const;
+
+  /// Returns the device ID identifier.
+  StringAttr getDeviceIDIdentifier(MLIRContext *context);
+
+  /// Returns the device type identifier.
+  StringAttr getDeviceTypeIdentifier(MLIRContext *context);
+
+  /// Returns the max vector op width identifier.
+  StringAttr getMaxVectorOpWidthIdentifier(MLIRContext *context);
+
+  /// Returns the canonicalizer max iterations identifier.
+  StringAttr getCanonicalizerMaxIterationsIdentifier(MLIRContext *context);
+
+  /// Returns the canonicalizer max num rewrites identifier.
+  StringAttr getCanonicalizerMaxNumRewritesIdentifier(MLIRContext *context);
+
+  /// Returns the L1 cache size identifier.
+  StringAttr getL1CacheSizeInBytesIdentifier(MLIRContext *context);
+
+  /// Returns the interface spec for device ID.
+  /// Since we verify that the spec contains device ID, the function
+  /// will return a valid spec.
+  DataLayoutEntryInterface getSpecForDeviceID(MLIRContext *context);
+
+  /// Returns the interface spec for device type.
+  /// Since we verify that the spec contains device type, the function
+  /// will return a valid spec.
+  DataLayoutEntryInterface getSpecForDeviceType(MLIRContext *context);
+
+  /// Returns the interface spec for max vector op width.
+  /// Since max vector op width is an optional property, this function will
+  /// return a valid spec if the property is defined, otherwise it
+  /// will return an empty spec.
+  DataLayoutEntryInterface getSpecForMaxVectorOpWidth(MLIRContext *context);
+
+  /// Returns the interface spec for L1 cache size.
+  /// Since L1 cache size is an optional property, this function will
+  /// return a valid spec if the property is defined, otherwise it
+  /// will return an empty spec.
+  DataLayoutEntryInterface getSpecForL1CacheSizeInBytes(MLIRContext *context);
+
+  /// Returns the interface spec for canonicalizer max iterations.
+  /// Since this is an optional property, this function will
+  /// return a valid spec if the property is defined, otherwise it
+  /// will return an empty spec.
+  DataLayoutEntryInterface
+  getSpecForCanonicalizerMaxIterations(MLIRContext *context);
+
+  /// Returns the interface spec for canonicalizer max num rewrites.
+  /// Since this is an optional property, this function will
+  /// return a valid spec if the property is defined, otherwise it
+  /// will return an empty spec.
+  DataLayoutEntryInterface
+  getSpecForCanonicalizerMaxNumRewrites(MLIRContext *context);
+
+  /// Returns the value of the device ID.
+  uint32_t getDeviceID(MLIRContext *context);
+
+  static constexpr StringLiteral name = "builtin.target_device_description";
+};
+
 } // namespace mlir
 
 #include "mlir/Dialect/DLTI/DLTIDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
index 3572a99fad874..c9a054b3c1e51 100644
--- a/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
+++ b/mlir/include/mlir/Dialect/DLTI/DLTIBase.td
@@ -27,6 +27,13 @@ def DLTI_Dialect : Dialect {
     constexpr const static ::llvm::StringLiteral
     kDataLayoutAttrName = "dlti.dl_spec";
 
+    // Top level attribute name for target system description
+    constexpr const static ::llvm::StringLiteral
+    kTargetSystemDescAttrName = "dlti.tsd_spec";
+
+    constexpr const static ::llvm::StringLiteral
+    kTargetDeviceDescAttrName = "dlti.tdd_spec";
+
     // Constants used in entries.
     constexpr const static ::llvm::StringLiteral
     kDataLayoutEndiannessKey = "dlti.endianness";
@@ -48,6 +55,25 @@ def DLTI_Dialect : Dialect {
 
     constexpr const static ::llvm::StringLiteral
     kDataLayoutStackAlignmentKey = "dlti.stack_alignment";
+
+    // Constants used in target description part of DLTI
+    constexpr const static ::llvm::StringLiteral
+    kTargetDeviceIDKey = "dlti.device_id";
+
+    constexpr const static ::llvm::StringLiteral
+    kTargetDeviceTypeKey = "dlti.device_type";
+
+    constexpr const static ::llvm::StringLiteral
+    kTargetDeviceMaxVectorOpWidthKey = "dlti.max_vector_op_width";
+
+    constexpr const static ::llvm::StringLiteral
+    kTargetDeviceCanonicalizerMaxIterationsKey = "dlti.canonicalizer_max_iterations";
+
+    constexpr const static ::llvm::StringLiteral
+    kTargetDeviceCanonicalizerMaxNumRewritesKey = "dlti.canonicalizer_max_num_rewrites";
+
+    constexpr const static ::llvm::StringLiteral
+    kTargetDeviceL1CacheSizeInBytesKey = "dlti.L1_cache_size_in_bytes";
   }];
 
   let useDefaultAttributePrinterParser = 1;
@@ -71,6 +97,24 @@ def DLTI_DataLayoutSpecAttr : DialectAttr<
   let convertFromStorage = "$_self";
 }
 
+def DLTI_TargetSystemDescSpecAttr : DialectAttr<
+    DLTI_Dialect,
+    CPred<"::llvm::isa<::mlir::TargetSystemDescSpecAttr>($_self)">,
+    "Target system description part of DLTI"> {
+  let storageType = "::mlir::TargetSystemDescSpecAttr";
+  let returnType = "::mlir::TargetSystemDescSpecAttr";
+  let convertFromStorage = "$_self";
+}
+
+def DLTI_TargetDeviceDescSpecAttr : DialectAttr<
+    DLTI_Dialect,
+    CPred<"::llvm::isa<::mlir::TargetDeviceDescSpecAttr>($_self)">,
+    "Target device description part of DLTI"> {
+  let storageType = "::mlir::TargetDeviceDescSpecAttr";
+  let returnType = "::mlir::TargetDeviceDescSpecAttr";
+  let convertFromStorage = "$_self";
+}
+
 def HasDefaultDLTIDataLayout : NativeOpTrait<"HasDefaultDLTIDataLayout"> {
   let cppNamespace = "::mlir";
 }
diff --git a/mlir/include/mlir/Dialect/DLTI/Traits.h b/mlir/include/mlir/Dialect/DLTI/Traits.h
index 5d86195305a95..44083d54c4cad 100644
--- a/mlir/include/mlir/Dialect/DLTI/Traits.h
+++ b/mlir/include/mlir/Dialect/DLTI/Traits.h
@@ -18,6 +18,7 @@ class DataLayoutSpecAttr;
 namespace impl {
 LogicalResult verifyHasDefaultDLTIDataLayoutTrait(Operation *op);
 DataLayoutSpecInterface getDataLayoutSpec(Operation *op);
+TargetSystemDescSpecInterface getTargetSystemDescSpec(Operation *op);
 } // namespace impl
 
 /// Trait to be used by operations willing to use the implementation of the
@@ -37,6 +38,12 @@ class HasDefaultDLTIDataLayout
   DataLayoutSpecInterface getDataLayoutSpec() {
     return impl::getDataLayoutSpec(this->getOperation());
   }
+
+  /// Returns the target system description specification as provided by DLTI
+  /// dialect
+  TargetSystemDescSpecInterface getTargetSystemDescSpec() {
+    return impl::getTargetSystemDescSpec(this->getOperation());
+  }
 };
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
index eda24615c71ea..bdb4ce3ddfe20 100644
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -78,6 +78,7 @@ def ModuleOp : Builtin_Op<"module", [
     //===------------------------------------------------------------------===//
 
     DataLayoutSpecInterface getDataLayoutSpec();
+    TargetSystemDescSpecInterface getTargetSystemDescSpec();
 
     //===------------------------------------------------------------------===//
     // OpAsmOpInterface Methods
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 76bf33e89a716..1584a13247dff 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -23,11 +23,17 @@
 namespace mlir {
 class DataLayout;
 class DataLayoutEntryInterface;
+class TargetDeviceDescSpecInterface;
+class TargetSystemDescSpecInterface;
 using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
 // Using explicit SmallVector size because we cannot infer the size from the
 // forward declaration, and we need the typedef in the actual declaration.
 using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
 using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
+// Non-owning view of a list of target device description specs, as returned
+// by TargetSystemDescSpecInterface::getEntries().
+using TargetDeviceDescSpecListRef =
+    llvm::ArrayRef<TargetDeviceDescSpecInterface>;
 class DataLayoutOpInterface;
 class DataLayoutSpecInterface;
 class ModuleOp;
@@ -84,6 +90,24 @@ Attribute getDefaultGlobalMemorySpace(DataLayoutEntryInterface entry);
 /// DataLayoutInterface if specified, otherwise returns the default.
 uint64_t getDefaultStackAlignment(DataLayoutEntryInterface entry);
 
+/// Returns max vector op width from the specified DataLayoutEntry. If the
+/// property is missing from the entry, returns std::nullopt.
+std::optional<uint32_t> getMaxVectorOpWidth(DataLayoutEntryInterface entry);
+
+/// Returns L1 cache size in bytes from the specified DataLayoutEntry. If the
+/// property is missing from the entry, returns std::nullopt.
+std::optional<uint32_t> getL1CacheSizeInBytes(DataLayoutEntryInterface entry);
+
+/// Returns canonicalizer max iterations from the specified DataLayoutEntry.
+/// If the property is missing from the entry, returns std::nullopt.
+std::optional<int64_t>
+getCanonicalizerMaxIterations(DataLayoutEntryInterface entry);
+
+/// Returns canonicalizer max num rewrites from the specified DataLayoutEntry.
+/// If the property is missing from the entry, returns std::nullopt.
+std::optional<int64_t>
+getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry);
+
 /// Given a list of data layout entries, returns a new list containing the
 /// entries with keys having the given type ID, i.e. belonging to the same type
 /// class.
@@ -95,6 +119,11 @@ DataLayoutEntryList filterEntriesForType(DataLayoutEntryListRef entries,
 DataLayoutEntryInterface
 filterEntryForIdentifier(DataLayoutEntryListRef entries, StringAttr id);
 
+/// Given a list of target device description specs, returns the spec that has
+/// the given identifier as key, if such a spec exists in the list.
+TargetDeviceDescSpecInterface
+filterEntryForIdentifier(TargetDeviceDescSpecListRef entries, StringAttr id);
+
 /// Verifies that the operation implementing the data layout interface, or a
 /// module operation, is valid. This calls the verifier of the spec attribute
 /// and checks if the layout is compatible with specs attached to the enclosing
@@ -106,6 +135,12 @@ LogicalResult verifyDataLayoutOp(Operation *op);
 /// and dialect interfaces for type and identifier keys respectively.
 LogicalResult verifyDataLayoutSpec(DataLayoutSpecInterface spec, Location loc);
 
+/// Verifies that a target system description spec is valid. This dispatches
+/// to individual entry verifiers, and then to the verifiers implemented by the
+/// relevant dialect interfaces for identifier keys.
+LogicalResult verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
+                                         Location loc);
+
 /// Divides the known min value of the numerator by the denominator and rounds
 /// the result up to the next integer. Preserves the scalable flag.
 llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator);
@@ -137,6 +172,13 @@ class DataLayoutDialectInterface
     return success();
   }
 
+  /// Checks whether the given target device description is valid and reports
+  /// any errors at `loc`. The default accepts everything; override to check.
+  virtual LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
+                                    Location loc) const {
+    return success();
+  }
+
   /// Default implementation of entry combination that combines identical
   /// entries and returns null otherwise.
   static DataLayoutEntryInterface
@@ -214,10 +256,33 @@ class DataLayout {
   /// unspecified.
   uint64_t getStackAlignment() const;
 
+  /// Returns the max vector op width if the property is defined for the given
+  /// device ID, otherwise returns std::nullopt.
+  std::optional<uint32_t>
+  getMaxVectorOpWidth(TargetDeviceDescSpecInterface::DeviceID deviceID) const;
+
+  /// Returns the L1 cache size in bytes if the property is defined for the
+  /// given device ID, otherwise returns std::nullopt.
+  std::optional<uint32_t>
+  getL1CacheSizeInBytes(TargetDeviceDescSpecInterface::DeviceID deviceID) const;
+
+  /// Returns the canonicalizer max iterations if the property is defined for
+  /// the given device ID, otherwise returns std::nullopt.
+  std::optional<int64_t> getCanonicalizerMaxIterations(
+      TargetDeviceDescSpecInterface::DeviceID deviceID) const;
+
+  /// Returns the canonicalizer max num rewrites if the property is defined for
+  /// the given device ID, otherwise returns std::nullopt.
+  std::optional<int64_t> getCanonicalizerMaxNumRewrites(
+      TargetDeviceDescSpecInterface::DeviceID deviceID) const;
+
 private:
   /// Combined layout spec at the given scope.
   const DataLayoutSpecInterface originalLayout;
 
+  /// Combined target system description spec at the given scope.
+  const TargetSystemDescSpecInterface originalTargetSystemDesc;
+
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// List of enclosing layout specs.
   SmallVector<DataLayoutSpecInterface, 2> layoutStack;
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 9edc885b9c5a9..75e609dde8fcf 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -194,6 +194,182 @@ def DataLayoutSpecInterface : AttrInterface<"DataLayoutSpecInterface"> {
   }];
 }
 
+def TargetDeviceDescSpecInterface : AttrInterface<"TargetDeviceDescSpecInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    Attribute interface describing a target device description specification.
+
+    A target device description specification is a list of device properties (key)
+    and their values for a specific device. The device is identified using "device_id"
+    (as a key and ui32 value) and "device_type" key which must have a string value.
+    Both "device_id" and "device_type" are mandatory keys. As an example, L1 cache
+    size could be a device property, and its value would be a device specific size.
+
+    A target device description specification is attached to a module as a module level
+    attribute.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the list of device property entries.",
+      /*retTy=*/"::mlir::DataLayoutEntryListRef",
+      /*methodName=*/"getEntries",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to the given identifier, if "
+                      "present.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForIdentifier",
+      /*args=*/(ins "::mlir::StringAttr":$identifier),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::filterEntryForIdentifier($_attr.getEntries(),
+                                                        identifier);
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/"Checks that the entry is well-formed, reports errors "
+                      "at the provided location.",
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"verifyEntry",
+      /*args=*/(ins "::mlir::Location":$loc),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return ::mlir::success(); }]
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the device ID identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getDeviceIDIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the device type identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getDeviceTypeIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the max vector op width identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getMaxVectorOpWidthIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns canonicalizer max iterations identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getCanonicalizerMaxIterationsIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns canonicalizer max num rewrites identifier.",
+      /*retTy=*/"::mlir::StringAttr",
+      /*methodName=*/"getCanonicalizerMaxNumRewritesIdentifier",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to Device ID. The function "
+                      "will crash if the entry is missing.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForDeviceID",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to the device type. "
+                      "The function will crash if the entry is missing.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForDeviceType",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to the given identifier, if "
+                      "present. Otherwise, return empty spec.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForMaxVectorOpWidth",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to the given identifier, if "
+                      "present. Otherwise, return empty spec.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForL1CacheSizeInBytes",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to the given identifier, if "
+                      "present. Otherwise, return empty spec.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForCanonicalizerMaxIterations",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the entry related to the given identifier, if "
+                      "present. Otherwise, return empty spec.",
+      /*retTy=*/"::mlir::DataLayoutEntryInterface",
+      /*methodName=*/"getSpecForCanonicalizerMaxNumRewrites",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the value of the device ID property of this "
+                      "device description.",
+      /*retTy=*/"uint32_t",
+      /*methodName=*/"getDeviceID",
+      /*args=*/(ins "::mlir::MLIRContext *":$context)
+    >,
+  ];
+
+  let extraClassDeclaration = [{
+    using DeviceID = uint32_t;
+  }];
+}
+
+def TargetSystemDescSpecInterface : AttrInterface<"TargetSystemDescSpecInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    Attribute interface describing a target system description specification.
+
+    A target system description specification is a list of target device
+    specifications, with one device specification for a device in the system. As
+    such, a target system description specification allows specifying a heterogeneous
+    system, with devices of different types (e.g., CPU, GPU).
+
+    The only requirement on a valid target system description specification is that
+    the "device_id" in every target device description specification needs to be
+    unique. This is because, ultimately, this "device_id" is used by the user to
+    query a value of a device property.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the list of target device description specs.",
+      /*retTy=*/"::mlir::TargetDeviceDescSpecListRef",
+      /*methodName=*/"getEntries",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*description=*/"Returns the device description spec for the given "
+                      "device ID.",
+      /*retTy=*/"::mlir::TargetDeviceDescSpecInterface",
+      /*methodName=*/"getDeviceDescForDeviceID",
+      /*args=*/(ins "uint32_t":$deviceID)
+    >,
+    InterfaceMethod<
+      /*description=*/"Verifies the validity of the specification and "
+                      "reports "
+                      "any errors at the given location.",
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"verifySpec",
+      /*args=*/(ins "::mlir::Location":$loc),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::verifyTargetSystemDescSpec($_attr, loc);
+      }]
+    >
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Operation interface
 //===----------------------------------------------------------------------===//
@@ -227,6 +403,14 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
       /*methodName=*/"getDataLayoutSpec",
       /*args=*/(ins)
     >,
+    // NOTE: no default implementation; every implementing op must define it.
+    InterfaceMethod<
+      /*description=*/"Returns the target system desc specification for this "
+                      "op, or null if it does not exist.",
+      /*retTy=*/"::mlir::TargetSystemDescSpecInterface",
+      /*methodName=*/"getTargetSystemDescSpec",
+      /*args=*/(ins)
+    >,
     StaticInterfaceMethod<
       /*description=*/"Returns the size of the given type computed using the "
                       "relevant entries. The data layout object can be used "
@@ -362,6 +546,50 @@ def DataLayoutOpInterface : OpInterface<"DataLayoutOpInterface"> {
         return ::mlir::detail::getDefaultStackAlignment(entry);
       }]
     >,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the max vector op width, if the property is "
+                      "defined. Otherwise, it returns std::nullopt.",
+      /*retTy=*/"std::optional<uint32_t>",
+      /*methodName=*/"getMaxVectorOpWidth",
+      /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getMaxVectorOpWidth(entry);
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the L1 cache size in bytes, if the property is "
+                      "defined. Otherwise, it returns std::nullopt.",
+      /*retTy=*/"std::optional<uint32_t>",
+      /*methodName=*/"getL1CacheSizeInBytes",
+      /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getL1CacheSizeInBytes(entry);
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the canonicalizer max iterations, if the "
+                      "property is defined. Otherwise, it returns std::nullopt.",
+      /*retTy=*/"std::optional<int64_t>",
+      /*methodName=*/"getCanonicalizerMaxIterations",
+      /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getCanonicalizerMaxIterations(entry);
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*description=*/"Returns the canonicalizer max num rewrites, if the "
+                      "property is defined. Otherwise, it returns std::nullopt.",
+      /*retTy=*/"std::optional<int64_t>",
+      /*methodName=*/"getCanonicalizerMaxNumRewrites",
+      /*args=*/(ins "::mlir::DataLayoutEntryInterface":$entry),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getCanonicalizerMaxNumRewrites(entry);
+      }]
+    >,
   ];
 
   let verify = [{ return ::mlir::detail::verifyDataLayoutOp($_op); }];
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 033e66c6118f3..8ff2b2104b998 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 
@@ -29,6 +30,8 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+#define DEBUG_TYPE "amd-gpu-to-rocdl"
+
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
   Type llvmI32 = rewriter.getI32Type();
@@ -49,7 +52,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
       : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
 
   Chipset chipset;
-  static constexpr uint32_t maxVectorOpWidth = 128;
 
   LogicalResult
   matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
@@ -111,6 +113,15 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t elemBits = dataVector.getElementTypeBitWidth();
       uint32_t totalBits = elemBits * dataVector.getNumElements();
+      uint32_t maxVectorOpWidth = 128; // default value
+      if (std::optional<uint32_t> v =
+              DataLayout(gpuOp->template getParentOfType<mlir::ModuleOp>())
+                  .getMaxVectorOpWidth(1 /* gpu ID*/)) {
+        maxVectorOpWidth = *v;
+      }
+      LLVM_DEBUG(llvm::dbgs() << "[CostModel] GPU MaxVectorWidth:"
+                              << maxVectorOpWidth << "\n");
+
       if (totalBits > maxVectorOpWidth)
         return gpuOp.emitOpError(
             "Total width of loads or stores must be no more than " +
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 98a8865ef4da3..a8518469c7824 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -10,14 +10,21 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
 using namespace mlir;
 
 #include "mlir/Dialect/DLTI/DLTIDialect.cpp.inc"
 
+#define DEBUG_TYPE "dlti"
+
 //===----------------------------------------------------------------------===//
 // DataLayoutEntryAttr
 //===----------------------------------------------------------------------===//
@@ -337,6 +344,320 @@ void DataLayoutSpecAttr::print(AsmPrinter &os) const {
   os << ">";
 }
 
+//===----------------------------------------------------------------------===//
+// TargetDeviceDescSpecAttr
+//===----------------------------------------------------------------------===//
+constexpr const StringLiteral mlir::TargetDeviceDescSpecAttr::kAttrKeyword;
+// Definitions of the DLTIDialect key-name constants for tdd_spec entries.
+constexpr const StringLiteral mlir::DLTIDialect::kTargetDeviceIDKey;
+constexpr const StringLiteral mlir::DLTIDialect::kTargetDeviceTypeKey;
+constexpr const StringLiteral
+    mlir::DLTIDialect::kTargetDeviceMaxVectorOpWidthKey;
+constexpr const StringLiteral
+    mlir::DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey;
+constexpr const StringLiteral
+    mlir::DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey;
+constexpr const StringLiteral
+    mlir::DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey;
+// Uniqued storage for TargetDeviceDescSpecAttr: its list of entries.
+namespace mlir {
+namespace impl {
+class TargetDeviceDescSpecAttrStorage : public AttributeStorage {
+public:
+  using KeyTy = ArrayRef<DataLayoutEntryInterface>;
+
+  TargetDeviceDescSpecAttrStorage(KeyTy entries) : entries(entries) {}
+  // Uniqued by comparing the entry lists element-wise.
+  bool operator==(const KeyTy &key) const { return key == entries; }
+  // Copies `key` into allocator-owned memory before building the storage.
+  static TargetDeviceDescSpecAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<TargetDeviceDescSpecAttrStorage>())
+        TargetDeviceDescSpecAttrStorage(allocator.copyInto(key));
+  }
+  // Entries in the order they were provided.
+  ArrayRef<DataLayoutEntryInterface> entries;
+};
+} // namespace impl
+} // namespace mlir
+/// Returns the uniqued TargetDeviceDescSpecAttr holding `entries`.
+TargetDeviceDescSpecAttr
+TargetDeviceDescSpecAttr::get(MLIRContext *ctx,
+                              ArrayRef<DataLayoutEntryInterface> entries) {
+  return Base::get(ctx, entries);
+}
+/// Returns the entries in the order they were provided.
+DataLayoutEntryListRef TargetDeviceDescSpecAttr::getEntries() const {
+  return getImpl()->entries;
+}
+/// Verifying variant of get(); reports failures through `emitError`.
+TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+    ArrayRef<DataLayoutEntryInterface> entries) {
+  return Base::getChecked(emitError, context, entries);
+}
+
+/// Verifies the entries of a target device description spec:
+///  - every key is an identifier (type keys are rejected) and is unique;
+///  - the device ID key is present with a ui32 IntegerAttr value;
+///  - the device type key is present with a StringAttr value.
+LogicalResult
+TargetDeviceDescSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<DataLayoutEntryInterface> entries) {
+  // Entries in tdd_spec can only have StringAttr as key. It does not support
+  // type as a key. Hence not reusing DataLayoutEntryInterface::verify.
+  bool targetDeviceIDKeyPresentAndValid = false;
+  bool targetDeviceTypeKeyPresentAndValid = false;
+
+  DenseSet<StringAttr> ids;
+  for (DataLayoutEntryInterface entry : entries) {
+    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
+      return emitError()
+             << "dlti.tdd_spec attribute does not allow type as a key: "
+             << type;
+
+    auto id = entry.getKey().get<StringAttr>();
+    if (!ids.insert(id).second)
+      return emitError() << "repeated layout entry key: " << id.getValue();
+
+    // Record whether the mandatory keys are present with values of the
+    // expected kind. Both checks are null-tolerant so that an entry without a
+    // value yields the diagnostic below rather than an assertion in dyn_cast.
+    StringRef entryName = id.strref();
+    if (entryName == DLTIDialect::kTargetDeviceIDKey) {
+      auto value = llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue());
+      if (value && value.getType().isUnsignedInteger(32))
+        targetDeviceIDKeyPresentAndValid = true;
+    } else if (entryName == DLTIDialect::kTargetDeviceTypeKey) {
+      if (llvm::isa_and_present<StringAttr>(entry.getValue()))
+        targetDeviceTypeKeyPresentAndValid = true;
+    }
+  }
+
+  // Both the device ID and the device type keys are mandatory.
+  if (!targetDeviceIDKeyPresentAndValid) {
+    return emitError() << "tdd_spec requires key: "
+                       << DLTIDialect::kTargetDeviceIDKey
+                       << " and its value of ui32 type";
+  }
+  if (!targetDeviceTypeKeyPresentAndValid) {
+    return emitError() << "tdd_spec requires key: "
+                       << DLTIDialect::kTargetDeviceTypeKey
+                       << " and its value of string type";
+  }
+
+  return success();
+}
+
+/// Parses an attribute with syntax
+///   tdd_spec_attr ::= `#dlti.` `tdd_spec` `<` dl-entry-attr-list? `>`
+///   dl-entry-attr-list ::= dl-entry-attr
+///                         | dl-entry-attr `,` dl-entry-attr-list
+TargetDeviceDescSpecAttr TargetDeviceDescSpecAttr::parse(AsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  // Even an empty spec goes through getChecked(): verify() rejects specs
+  // lacking the device ID/type keys, and get() only asserts on that.
+  SmallVector<DataLayoutEntryInterface> entries;
+  if (failed(parser.parseOptionalGreater())) {
+    if (parser.parseCommaSeparatedList(
+            [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
+        parser.parseGreater())
+      return {};
+  }
+
+  return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
+                    parser.getContext(), entries);
+}
+/// Prints the keyword followed by the comma-separated entries in `<...>`.
+void TargetDeviceDescSpecAttr::print(AsmPrinter &os) const {
+  os << TargetDeviceDescSpecAttr::kAttrKeyword << "<";
+  llvm::interleaveComma(getEntries(), os);
+  os << ">";
+}
+
+// ---------------------------------------------------------------------------//
+//            Identifier and entry accessors for the tdd_spec keys
+// ---------------------------------------------------------------------------//
+// get<Key>Identifier returns the key name as a StringAttr in `context`.
+StringAttr
+TargetDeviceDescSpecAttr::getDeviceIDIdentifier(MLIRContext *context) {
+  return Builder(context).getStringAttr(DLTIDialect::kTargetDeviceIDKey);
+}
+
+StringAttr
+TargetDeviceDescSpecAttr::getDeviceTypeIdentifier(MLIRContext *context) {
+  return Builder(context).getStringAttr(DLTIDialect::kTargetDeviceTypeKey);
+}
+
+StringAttr
+TargetDeviceDescSpecAttr::getMaxVectorOpWidthIdentifier(MLIRContext *context) {
+  return Builder(context).getStringAttr(
+      DLTIDialect::kTargetDeviceMaxVectorOpWidthKey);
+}
+
+StringAttr TargetDeviceDescSpecAttr::getL1CacheSizeInBytesIdentifier(
+    MLIRContext *context) {
+  return Builder(context).getStringAttr(
+      DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey);
+}
+
+StringAttr TargetDeviceDescSpecAttr::getCanonicalizerMaxIterationsIdentifier(
+    MLIRContext *context) {
+  return Builder(context).getStringAttr(
+      DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey);
+}
+
+StringAttr TargetDeviceDescSpecAttr::getCanonicalizerMaxNumRewritesIdentifier(
+    MLIRContext *context) {
+  return Builder(context).getStringAttr(
+      DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey);
+}
+// getSpecFor<Key> returns the entry for that key via getSpecForIdentifier.
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForDeviceID(MLIRContext *context) {
+  return getSpecForIdentifier(getDeviceIDIdentifier(context));
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForDeviceType(MLIRContext *context) {
+  return getSpecForIdentifier(getDeviceTypeIdentifier(context));
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForMaxVectorOpWidth(MLIRContext *context) {
+  return getSpecForIdentifier(getMaxVectorOpWidthIdentifier(context));
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForL1CacheSizeInBytes(MLIRContext *context) {
+  return getSpecForIdentifier(getL1CacheSizeInBytesIdentifier(context));
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForCanonicalizerMaxIterations(
+    MLIRContext *context) {
+  return getSpecForIdentifier(getCanonicalizerMaxIterationsIdentifier(context));
+}
+
+DataLayoutEntryInterface
+TargetDeviceDescSpecAttr::getSpecForCanonicalizerMaxNumRewrites(
+    MLIRContext *context) {
+  return getSpecForIdentifier(
+      getCanonicalizerMaxNumRewritesIdentifier(context));
+}
+// Returns the device ID; the cast assumes verify() ensured the entry exists.
+uint32_t TargetDeviceDescSpecAttr::getDeviceID(MLIRContext *context) {
+  DataLayoutEntryInterface entry = getSpecForDeviceID(context);
+  return llvm::cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
+}
+
+//===----------------------------------------------------------------------===//
+// TargetSystemDescSpecAttr
+//===----------------------------------------------------------------------===//
+
+constexpr const StringLiteral mlir::TargetSystemDescSpecAttr::kAttrKeyword;
+
+namespace mlir {
+namespace impl {
+class TargetSystemDescSpecAttrStorage : public AttributeStorage {
+public:
+  using KeyTy = ArrayRef<TargetDeviceDescSpecInterface>;
+
+  TargetSystemDescSpecAttrStorage(KeyTy entries) : entries(entries) {}
+
+  bool operator==(const KeyTy &key) const { return key == entries; }
+
+  static TargetSystemDescSpecAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<TargetSystemDescSpecAttrStorage>())
+        TargetSystemDescSpecAttrStorage(allocator.copyInto(key));
+  }
+
+  // This could be a map of DeviceID to DeviceDesc for faster lookup.
+  ArrayRef<TargetDeviceDescSpecInterface> entries;
+};
+} // namespace impl
+} // namespace mlir
+
+TargetSystemDescSpecAttr
+TargetSystemDescSpecAttr::get(MLIRContext *context,
+                              ArrayRef<TargetDeviceDescSpecInterface> entries) {
+  return Base::get(context, entries);
+}
+
+TargetDeviceDescSpecListRef TargetSystemDescSpecAttr::getEntries() const {
+  return getImpl()->entries;
+}
+
+TargetDeviceDescSpecInterface
+TargetSystemDescSpecAttr::getDeviceDescForDeviceID(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) {
+  TargetDeviceDescSpecListRef specs = getEntries(); // Linear scan.
+  auto it = llvm::find_if(specs, [&](TargetDeviceDescSpecInterface spec) {
+    return spec.getDeviceID(getContext()) == deviceID;
+  });
+  return it == specs.end() ? TargetDeviceDescSpecInterface() : *it;
+}
+
+TargetSystemDescSpecAttr TargetSystemDescSpecAttr::getChecked(
+    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
+    ArrayRef<TargetDeviceDescSpecInterface> entries) {
+  return Base::getChecked(emitError, context, entries);
+}
+
+/// Verifies every device spec in `entries` and rejects duplicate device IDs.
+LogicalResult TargetSystemDescSpecAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<TargetDeviceDescSpecInterface> entries) {
+  DenseSet<uint32_t> seenDeviceIDs;
+  for (TargetDeviceDescSpecInterface deviceSpec : entries) {
+    // A device spec must be well-formed on its own before its ID is read.
+    LogicalResult deviceResult =
+        TargetDeviceDescSpecAttr::verify(emitError, deviceSpec.getEntries());
+    if (failed(deviceResult))
+      return failure();
+
+    // Device IDs must be unique across the whole system description.
+    uint32_t deviceID = deviceSpec.getDeviceID(deviceSpec.getContext());
+    bool inserted = seenDeviceIDs.insert(deviceID).second;
+    if (!inserted)
+      return emitError() << "repeated Device ID in dlti.tsd_spec: "
+                         << deviceID;
+  }
+  return success();
+}
+
+/// Parses an attribute with syntax
+///   tsd_spec_attr ::= `#dlti.` `tsd_spec` `<` tdd-spec-attr-list? `>`
+///   tdd-spec-attr-list ::= tdd-spec-attr
+///                         | tdd-spec-attr `,` tdd-spec-attr-list
+TargetSystemDescSpecAttr TargetSystemDescSpecAttr::parse(AsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  // Empty spec: a system description with no devices (valid).
+  if (succeeded(parser.parseOptionalGreater()))
+    return get(parser.getContext(), {});
+
+  SmallVector<TargetDeviceDescSpecInterface> entries;
+  if (parser.parseCommaSeparatedList(
+          [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
+      parser.parseGreater())
+    return {};
+
+  return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
+                    parser.getContext(), entries);
+}
+/// Prints the keyword followed by the comma-separated device specs.
+void TargetSystemDescSpecAttr::print(AsmPrinter &os) const {
+  os << TargetSystemDescSpecAttr::kAttrKeyword << "<";
+  llvm::interleaveComma(getEntries(), os);
+  os << ">";
+}
+
 //===----------------------------------------------------------------------===//
 // DLTIDialect
 //===----------------------------------------------------------------------===//
@@ -375,9 +696,50 @@ class TargetDataLayoutInterface : public DataLayoutDialectInterface {
 };
 } // namespace
 
+namespace {
+/// Extends TargetDataLayoutInterface with verification of target device
+/// description entries.
+///
+/// Both classes implement DataLayoutDialectInterface, and
+/// Dialect::addInterface ignores a repeated registration for the same
+/// interface ID. Registering them side by side would therefore silently drop
+/// this one, so it derives from TargetDataLayoutInterface and is registered
+/// in its place.
+class SystemDescSpecInterface : public TargetDataLayoutInterface {
+public:
+  using TargetDataLayoutInterface::TargetDataLayoutInterface;
+  // Keep the inherited DataLayoutEntryInterface overload visible.
+  using TargetDataLayoutInterface::verifyEntry;
+
+  /// Rejects device description entries keyed by a type, or by a name that is
+  /// not one of the target device keys known to the DLTI dialect.
+  LogicalResult verifyEntry(TargetDeviceDescSpecInterface entry,
+                            Location loc) const final {
+    for (DataLayoutEntryInterface dlEntry : entry.getEntries()) {
+      auto key = llvm::dyn_cast_if_present<StringAttr>(dlEntry.getKey());
+      if (!key)
+        return emitError(loc)
+               << "dlti.tdd_spec attribute does not allow type as a key";
+      StringRef entryName = key.strref();
+      if (entryName != DLTIDialect::kTargetDeviceIDKey &&
+          entryName != DLTIDialect::kTargetDeviceTypeKey &&
+          entryName != DLTIDialect::kTargetDeviceMaxVectorOpWidthKey &&
+          entryName != DLTIDialect::kTargetDeviceL1CacheSizeInBytesKey &&
+          entryName !=
+              DLTIDialect::kTargetDeviceCanonicalizerMaxIterationsKey &&
+          entryName != DLTIDialect::kTargetDeviceCanonicalizerMaxNumRewritesKey)
+        return emitError(loc) << "unknown target desc key name: " << entryName;
+    }
+    return success();
+  }
+};
+} // namespace
+
 void DLTIDialect::initialize() {
-  addAttributes<DataLayoutEntryAttr, DataLayoutSpecAttr>();
-  addInterfaces<TargetDataLayoutInterface>();
+  addAttributes<DataLayoutEntryAttr, DataLayoutSpecAttr,
+                TargetSystemDescSpecAttr, TargetDeviceDescSpecAttr>();
+  // SystemDescSpecInterface also provides TargetDataLayoutInterface's hooks.
+  addInterfaces<SystemDescSpecInterface>();
 }
 
 Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
@@ -390,6 +738,10 @@ Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
     return DataLayoutEntryAttr::parse(parser);
   if (attrKind == DataLayoutSpecAttr::kAttrKeyword)
     return DataLayoutSpecAttr::parse(parser);
+  if (attrKind == TargetSystemDescSpecAttr::kAttrKeyword)
+    return TargetSystemDescSpecAttr::parse(parser);
+  if (attrKind == TargetDeviceDescSpecAttr::kAttrKeyword)
+    return TargetDeviceDescSpecAttr::parse(parser);
 
   parser.emitError(parser.getNameLoc(), "unknown attrribute type: ")
       << attrKind;
@@ -398,8 +750,8 @@ Attribute DLTIDialect::parseAttribute(DialectAsmParser &parser,
 
 void DLTIDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
   llvm::TypeSwitch<Attribute>(attr)
-      .Case<DataLayoutEntryAttr, DataLayoutSpecAttr>(
-          [&](auto a) { a.print(os); })
+      .Case<DataLayoutEntryAttr, DataLayoutSpecAttr, TargetSystemDescSpecAttr,
+            TargetDeviceDescSpecAttr>([&](auto a) { a.print(os); })
       .Default([](Attribute) { llvm_unreachable("unknown attribute kind"); });
 }
 
@@ -413,6 +765,17 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
     if (isa<ModuleOp>(op))
       return detail::verifyDataLayoutOp(op);
     return success();
+  }
+  if (attr.getName() == DLTIDialect::kTargetSystemDescAttrName) {
+    if (!llvm::isa<TargetSystemDescSpecAttr>(attr.getValue())) {
+      return op->emitError()
+             << "'" << DLTIDialect::kTargetSystemDescAttrName
+             << "' is expected to be a #dlti.tsd_spec attribute";
+    }
+    // NOTE(review): only the attribute kind is checked here; nothing in this
+    // function verifies the spec's contents (e.g. with
+    // detail::verifyTargetSystemDescSpec) -- confirm that happens elsewhere.
+    return success();
   }
 
   return op->emitError() << "attribute '" << attr.getName().getValue()
diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp
index 85acbee46defd..ead656774a27c 100644
--- a/mlir/lib/Dialect/DLTI/Traits.cpp
+++ b/mlir/lib/Dialect/DLTI/Traits.cpp
@@ -27,3 +27,12 @@ DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) {
   return op->getAttrOfType<DataLayoutSpecAttr>(
       DLTIDialect::kDataLayoutAttrName);
 }
+
+/// Returns the TargetSystemDescSpecAttr stored on `op` under
+/// DLTIDialect::kTargetSystemDescAttrName, or a null interface when that
+/// attribute is absent or of a different kind.
+TargetSystemDescSpecInterface
+mlir::impl::getTargetSystemDescSpec(Operation *op) {
+  StringRef attrName = DLTIDialect::kTargetSystemDescAttrName;
+  return op->getAttrOfType<TargetSystemDescSpecAttr>(attrName);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index c07d1387ec753..7f3bef694474c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -18,6 +18,8 @@
 
 #include <optional>
 
+#define DEBUG_TYPE "block-pack-matmul"
+
 namespace mlir {
 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
 #include "mlir/Dialect/Linalg/Passes.h.inc"
@@ -134,6 +136,32 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
   return packTransposedMatmul;
 }
 
+/// Returns default {m, n, k} block factors for `linalgOp`, using the same
+/// block size for all three dimensions.
+///
+/// The block size is the largest power of two whose square does not exceed
+/// the L1 cache size (in bytes) that the enclosing module's data layout
+/// reports for device 0. A missing or zero cache size falls back to 4096
+/// bytes, i.e. a block size of 64.
+static SmallVector<int64_t> getDefaultBlockFactors(linalg::LinalgOp linalgOp) {
+  constexpr TargetDeviceDescSpecInterface::DeviceID kCpuDeviceID = 0;
+  uint64_t l1CacheSize = 4096;
+  ModuleOp moduleOp = linalgOp->getParentOfType<ModuleOp>();
+  if (std::optional<uint32_t> v =
+          DataLayout(moduleOp).getL1CacheSizeInBytes(kCpuDeviceID);
+      v && *v > 0)
+    l1CacheSize = *v;
+
+  // Integer form of 2^floor(log2(sqrt(l1CacheSize))): keep doubling while the
+  // doubled block still satisfies blockSize^2 <= l1CacheSize.
+  uint64_t blockSize = 1;
+  while ((2 * blockSize) * (2 * blockSize) <= l1CacheSize)
+    blockSize *= 2;
+  LLVM_DEBUG(llvm::dbgs() << "block_size:" << blockSize << "\n");
+  auto factor = static_cast<int64_t>(blockSize);
+  return {factor, factor, factor};
+}
+
 /// Pack a matmul operation into blocked 4D layout.
 FailureOr<PackResult>
 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
@@ -146,7 +174,11 @@ linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
     return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
 
+  // Derive block factors from the target description only when none were
+  // requested; an explicit request of the wrong length is still rejected.
+  if (options->blockFactors.empty())
+    options->blockFactors = getDefaultBlockFactors(linalgOp);
   if (options->blockFactors.size() != 3)
     return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
 
   SmallVector<OpFoldResult> mnkTiles =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index dcb1119fe5207..1d57e0bdef187 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -155,6 +155,17 @@ DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
   return {};
 }
 
+TargetSystemDescSpecInterface ModuleOp::getTargetSystemDescSpec() {
+  // Return the first attribute implementing TargetSystemDescSpecInterface, or
+  // a null interface if there is none. This linearly scans the attributes;
+  // DataLayout calls it when constructed and stores the result.
+  for (NamedAttribute attr : getOperation()->getAttrs())
+    if (auto spec =
+            llvm::dyn_cast<TargetSystemDescSpecInterface>(attr.getValue()))
+      return spec;
+  return {};
+}
+
 LogicalResult ModuleOp::verify() {
   // Check that none of the attributes are non-dialect attributes, except for
   // the symbol related attributes.
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 15cfb3dbaf745..857ccf03ed8c0 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -293,6 +293,56 @@ mlir::detail::getDefaultStackAlignment(DataLayoutEntryInterface entry) {
   return value.getValue().getZExtValue();
 }
 
+// Returns the integer value stored in `entry`, or std::nullopt if the entry
+// is empty (meaning the spec is missing) or its value is not an IntegerAttr.
+// The tdd_spec verifier only type-checks the device ID and device type
+// values, so other properties may hold arbitrary attributes; those are
+// treated as missing instead of being passed to an invalid cast<>.
+static std::optional<APInt>
+getTargetDeviceIntegerValue(DataLayoutEntryInterface entry) {
+  if (entry == DataLayoutEntryInterface())
+    return std::nullopt;
+  if (auto value = llvm::dyn_cast_if_present<IntegerAttr>(entry.getValue()))
+    return value.getValue();
+  return std::nullopt;
+}
+
+// Returns the max vector op width if specified in the given entry, or
+// std::nullopt if the entry is empty or does not hold an integer.
+std::optional<uint32_t>
+mlir::detail::getMaxVectorOpWidth(DataLayoutEntryInterface entry) {
+  if (std::optional<APInt> value = getTargetDeviceIntegerValue(entry))
+    return value->getZExtValue();
+  return std::nullopt;
+}
+
+// Returns the L1 cache size if specified in the given entry, or std::nullopt
+// if the entry is empty or does not hold an integer.
+std::optional<uint32_t>
+mlir::detail::getL1CacheSizeInBytes(DataLayoutEntryInterface entry) {
+  if (std::optional<APInt> value = getTargetDeviceIntegerValue(entry))
+    return value->getZExtValue();
+  return std::nullopt;
+}
+
+// Returns the canonicalizer max iterations if specified in the given entry,
+// or std::nullopt if the entry is empty or does not hold an integer.
+std::optional<int64_t>
+mlir::detail::getCanonicalizerMaxIterations(DataLayoutEntryInterface entry) {
+  if (std::optional<APInt> value = getTargetDeviceIntegerValue(entry))
+    return value->getSExtValue();
+  return std::nullopt;
+}
+
+// Returns the canonicalizer max num rewrites if specified in the given entry,
+// or std::nullopt if the entry is empty or does not hold an integer.
+std::optional<int64_t>
+mlir::detail::getCanonicalizerMaxNumRewrites(DataLayoutEntryInterface entry) {
+  if (std::optional<APInt> value = getTargetDeviceIntegerValue(entry))
+    return value->getSExtValue();
+  return std::nullopt;
+}
+
 DataLayoutEntryList
 mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
                                    TypeID typeID) {
@@ -324,6 +368,22 @@ static DataLayoutSpecInterface getSpec(Operation *operation) {
       });
 }
 
+/// Returns the target system description of `operation` if it is a module,
+/// or of its closest enclosing module otherwise. Returns a null interface
+/// when `operation` is null or is not nested in any module.
+static TargetSystemDescSpecInterface
+getTargetSystemDescSpec(Operation *operation) {
+  if (!operation)
+    return {};
+  auto moduleOp = llvm::dyn_cast<ModuleOp>(operation);
+  if (!moduleOp)
+    moduleOp = operation->getParentOfType<ModuleOp>();
+  // E.g. a top-level op that is not a module has no enclosing module.
+  if (!moduleOp)
+    return {};
+  return moduleOp.getTargetSystemDescSpec();
+}
+
 /// Populates `opsWithLayout` with the list of proper ancestors of `leaf` that
 /// are either modules or implement the `DataLayoutOpInterface`.
 static void
@@ -435,7 +495,8 @@ mlir::DataLayout::DataLayout() : DataLayout(ModuleOp()) {}
 mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
     : originalLayout(getCombinedDataLayout(op)), scope(op),
       allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
-      globalMemorySpace(std::nullopt), stackAlignment(std::nullopt) {
+      globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
+      originalTargetSystemDesc(getTargetSystemDescSpec(op)) {
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   checkMissingLayout(originalLayout, op);
   collectParentLayouts(op, layoutStack);
@@ -445,7 +506,8 @@ mlir::DataLayout::DataLayout(DataLayoutOpInterface op)
 mlir::DataLayout::DataLayout(ModuleOp op)
     : originalLayout(getCombinedDataLayout(op)), scope(op),
       allocaMemorySpace(std::nullopt), programMemorySpace(std::nullopt),
-      globalMemorySpace(std::nullopt), stackAlignment(std::nullopt) {
+      globalMemorySpace(std::nullopt), stackAlignment(std::nullopt),
+      originalTargetSystemDesc(getTargetSystemDescSpec(op)) {
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   checkMissingLayout(originalLayout, op);
   collectParentLayouts(op, layoutStack);
@@ -640,6 +700,80 @@ uint64_t mlir::DataLayout::getStackAlignment() const {
   return *stackAlignment;
 }
 
+/// Returns the entry selected by `getSpec` from the device description with
+/// `deviceID` in `systemDesc`. Returns a null entry when there is no system
+/// description or it has no device with that ID, so that a missing device is
+/// reported as a missing property instead of querying a null interface.
+static DataLayoutEntryInterface getTargetDeviceEntry(
+    TargetSystemDescSpecInterface systemDesc,
+    TargetDeviceDescSpecInterface::DeviceID deviceID,
+    function_ref<DataLayoutEntryInterface(TargetDeviceDescSpecInterface,
+                                          MLIRContext *)>
+        getSpec) {
+  if (!systemDesc)
+    return {};
+  TargetDeviceDescSpecInterface deviceDesc =
+      systemDesc.getDeviceDescForDeviceID(deviceID);
+  if (!deviceDesc)
+    return {};
+  return getSpec(deviceDesc, systemDesc.getContext());
+}
+
+// The getters below do not cache results: these properties have no built-in
+// defaults, so a missing property yields std::nullopt and callers choose
+// their own default.
+std::optional<uint32_t> mlir::DataLayout::getMaxVectorOpWidth(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  DataLayoutEntryInterface entry = getTargetDeviceEntry(
+      originalTargetSystemDesc, deviceID,
+      [](TargetDeviceDescSpecInterface desc, MLIRContext *ctx) {
+        return desc.getSpecForMaxVectorOpWidth(ctx);
+      });
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getMaxVectorOpWidth(entry);
+  return detail::getMaxVectorOpWidth(entry);
+}
+
+std::optional<uint32_t> mlir::DataLayout::getL1CacheSizeInBytes(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  DataLayoutEntryInterface entry = getTargetDeviceEntry(
+      originalTargetSystemDesc, deviceID,
+      [](TargetDeviceDescSpecInterface desc, MLIRContext *ctx) {
+        return desc.getSpecForL1CacheSizeInBytes(ctx);
+      });
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getL1CacheSizeInBytes(entry);
+  return detail::getL1CacheSizeInBytes(entry);
+}
+
+std::optional<int64_t> mlir::DataLayout::getCanonicalizerMaxIterations(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  DataLayoutEntryInterface entry = getTargetDeviceEntry(
+      originalTargetSystemDesc, deviceID,
+      [](TargetDeviceDescSpecInterface desc, MLIRContext *ctx) {
+        return desc.getSpecForCanonicalizerMaxIterations(ctx);
+      });
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getCanonicalizerMaxIterations(entry);
+  return detail::getCanonicalizerMaxIterations(entry);
+}
+
+std::optional<int64_t> mlir::DataLayout::getCanonicalizerMaxNumRewrites(
+    TargetDeviceDescSpecInterface::DeviceID deviceID) const {
+  checkValid();
+  DataLayoutEntryInterface entry = getTargetDeviceEntry(
+      originalTargetSystemDesc, deviceID,
+      [](TargetDeviceDescSpecInterface desc, MLIRContext *ctx) {
+        return desc.getSpecForCanonicalizerMaxNumRewrites(ctx);
+      });
+  if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+    return iface.getCanonicalizerMaxNumRewrites(entry);
+  return detail::getCanonicalizerMaxNumRewrites(entry);
+}
+
 //===----------------------------------------------------------------------===//
 // DataLayoutSpecInterface
 //===----------------------------------------------------------------------===//
@@ -744,6 +876,54 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
   return success();
 }
 
+/// Verifies a target system description: each device spec must pass its own
+/// verifyEntry, device IDs must be unique, keys must be identifiers rather
+/// than types, and every identifier key owned by a loaded dialect is checked
+/// by that dialect's DataLayoutDialectInterface.
+LogicalResult
+mlir::detail::verifyTargetSystemDescSpec(TargetSystemDescSpecInterface spec,
+                                         Location loc) {
+  DenseSet<uint32_t> deviceIDs;
+  for (TargetDeviceDescSpecInterface tddSpec : spec.getEntries()) {
+    // First, verify individual target device desc specs.
+    if (failed(tddSpec.verifyEntry(loc)))
+      return failure();
+
+    // Check that device IDs are unique across all entries.
+    uint32_t deviceID = tddSpec.getDeviceID(tddSpec.getContext());
+    if (!deviceIDs.insert(deviceID).second)
+      return emitError(loc)
+             << "repeated Device ID in dlti.tsd_spec: " << deviceID;
+
+    // Verify every entry of every device (not just one entry per key name),
+    // so that differing values for the same key are all checked.
+    for (DataLayoutEntryInterface entry : tddSpec.getEntries()) {
+      if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
+        return emitError(loc)
+               << "dlti.tdd_spec attribute does not allow type as a key: "
+               << type;
+
+      StringAttr identifier = entry.getKey().get<StringAttr>();
+      Dialect *dialect = identifier.getReferencedDialect();
+      // Ignore attributes that belong to an unknown dialect, the dialect may
+      // actually implement the relevant interface but we don't know about that.
+      if (!dialect)
+        continue;
+
+      const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
+      if (!iface) {
+        return emitError(loc)
+               << "the '" << dialect->getNamespace()
+               << "' dialect does not support identifier data layout entries";
+      }
+      if (failed(iface->verifyEntry(entry, loc)))
+        return failure();
+    }
+  }
+
+  return success();
+}
+
 #include "mlir/Interfaces/DataLayoutAttrInterface.cpp.inc"
 #include "mlir/Interfaces/DataLayoutOpInterface.cpp.inc"
 #include "mlir/Interfaces/DataLayoutTypeInterface.cpp.inc"
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee5..2948804b8f92a 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -23,6 +23,8 @@ namespace mlir {
 
 using namespace mlir;
 
+#define DEBUG_TYPE "canonicalizer"
+
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
@@ -48,6 +50,13 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     config.maxIterations = maxIterations;
     config.maxNumRewrites = maxNumRewrites;
 
+    LLVM_DEBUG(llvm::dbgs()
+               << "[CostModel] Canonicalizer MaxIterations (default):"
+               << config.maxIterations << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "[CostModel] Canonicalizer MaxNumRewrites (default):"
+               << config.maxNumRewrites << "\n");
+
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
       dialect->getCanonicalizationPatterns(owningPatterns);
@@ -59,6 +68,27 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     return success();
   }
   void runOnOperation() override {
+    // Apply target-specific overrides to a per-run copy of the config so that
+    // values read from one module's target description do not leak into later
+    // runs of this pass instance. Device ID 0 is queried for these properties.
+    constexpr TargetDeviceDescSpecInterface::DeviceID kCpuDeviceID = 0;
+    GreedyRewriteConfig runConfig = config;
+    Operation *op = getOperation();
+    auto moduleOp = dyn_cast<ModuleOp>(op);
+    if (!moduleOp)
+      moduleOp = op->getParentOfType<ModuleOp>();
+    DataLayout layout(moduleOp);
+    if (std::optional<int64_t> v =
+            layout.getCanonicalizerMaxIterations(kCpuDeviceID))
+      runConfig.maxIterations = *v;
+    if (std::optional<int64_t> v =
+            layout.getCanonicalizerMaxNumRewrites(kCpuDeviceID))
+      runConfig.maxNumRewrites = *v;
+    LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxIterations (new):"
+                            << runConfig.maxIterations << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxNumRewrites (new):"
+                            << runConfig.maxNumRewrites << "\n");
+
     LogicalResult converged =
-        applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
+        applyPatternsAndFoldGreedily(op, *patterns, runConfig);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 5f484294268ab..38f0ed0ed8da3 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -197,6 +197,13 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
   }
 
+  // NOTE(review): reads kAttrName, the same attribute as getDataLayoutSpec();
+  // yields null unless it implements TargetSystemDescSpecInterface.
+  TargetSystemDescSpecInterface getTargetSystemDescSpec() {
+    Operation *op = getOperation();
+    return op->getAttrOfType<TargetSystemDescSpecInterface>(kAttrName);
+  }
+
   static llvm::TypeSize getTypeSizeInBits(Type type,
                                           const DataLayout &dataLayout,
                                           DataLayoutEntryListRef params) {
@@ -244,6 +251,13 @@ struct OpWith7BitByte
     return getOperation()->getAttrOfType<DataLayoutSpecInterface>(kAttrName);
   }
 
+  // NOTE(review): reads kAttrName, the same attribute as getDataLayoutSpec();
+  // yields null unless it implements TargetSystemDescSpecInterface.
+  TargetSystemDescSpecInterface getTargetSystemDescSpec() {
+    Operation *op = getOperation();
+    return op->getAttrOfType<TargetSystemDescSpecInterface>(kAttrName);
+  }
+
   // Bytes are assumed to be 7-bit here.
   static llvm::TypeSize getTypeSize(Type type, const DataLayout &dataLayout,
                                     DataLayoutEntryListRef params) {



More information about the Mlir-commits mailing list