[Mlir-commits] [mlir] [mlir] Target Description and Cost Model in MLIR (PR #85141)

Niranjan Hasabnis llvmlistbot at llvm.org
Wed Mar 13 15:26:23 PDT 2024


https://github.com/nhasabni created https://github.com/llvm/llvm-project/pull/85141

This pull request (PR) demonstrates one example of a target description and cost model in MLIR. See [RFC](https://discourse.llvm.org/t/rfc-target-description-and-cost-model-in-mlir/76990) for the context. **It is not a complete work by any means, and the main purpose of this PR is to initiate a discussion around this idea.**

At a high level, the PR adds a mechanism to read the system description (cost model) from a config file specified on the command line; JSON is used as the example format. If the user does not specify a description file, the cost model is populated with default values. These defaults can be known device properties (e.g., L1 cache size) or heuristics used by compiler passes (e.g., the tile size for matrix multiplication). The system description consists of one device description for each device in the system. The PR also modifies a couple of passes to show how they can query the device-specific cost models.
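The fallback behavior described above can be sketched in standalone C++ without MLIR. This is a simplified model under stated assumptions, not the PR's API: `DeviceProps`, `defaultCPUProps`, `loadProps`, and `lookupOr` are hypothetical names, a `std::map` of strings stands in for the parsed JSON file, and only integer-valued properties are handled. The default values mirror `DefaultCPUDeviceDesc` in the patch. As in the patch, defaults are used only when no file is given; they are not merged with the file's contents.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-in for a device description: property name -> integer.
using DeviceProps = std::map<std::string, int64_t>;

// Defaults used when no description file is given (values mirror
// DefaultCPUDeviceDesc in the patch).
DeviceProps defaultCPUProps() {
  return {{"L1_CACHE_SIZE_IN_BYTES", 8192},
          {"CONV_AND_MATMUL_BLOCKING_FACTOR", 32}};
}

// Build a description from string key-value pairs (standing in for the
// parsed JSON file), or from the defaults when there is no file.
DeviceProps
loadProps(const std::optional<std::map<std::string, std::string>> &file) {
  if (!file)
    return defaultCPUProps();
  DeviceProps props;
  for (const auto &[key, value] : *file) {
    std::size_t pos = 0;
    int64_t parsed = std::stoll(value, &pos);
    assert(pos == value.size() && "expected an integer-valued property");
    props[key] = parsed;
  }
  return props;
}

// Pass-side query: use the description if it has the key, otherwise the
// pass's own hard-coded fallback.
int64_t lookupOr(const DeviceProps &props, const std::string &key,
                 int64_t fallback) {
  auto it = props.find(key);
  return it == props.end() ? fallback : it->second;
}
```

For example, `lookupOr(loadProps(std::nullopt), "L1_CACHE_SIZE_IN_BYTES", 0)` yields 8192, while a file entry of `"49152"` for the same key overrides it. A key missing from the description falls back to the pass's own value, much like the 128-bit width that `AMDGPUToROCDL.cpp` keeps as its fallback.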

In terms of implementation, the `SystemDesc` class implements methods to parse a system description file (building a set of device descriptions) and to add and get device descriptions. The `DeviceDesc` class represents a device description as a set of key-value pairs. It implements methods to set and get those fields, plus APIs to get the values of commonly used device properties (such as vector width) or second-order information that MLIR passes derive from those properties (e.g., tile size). **Currently, these APIs are added in an ad-hoc manner, and a better design is possible.** The PR also provides an extensible mechanism, `DefaultBaseDeviceDesc`, for specifying default values of device properties and for deriving second-order information from them.
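The extension point that `DefaultBaseDeviceDesc` provides can be illustrated with a standalone C++ sketch. The class names `DeviceDescSketch`, `DefaultDeviceDescSketch`, and `BigCacheCPUDefaults`, the helper `getProp`, and the 49152-byte cache value are hypothetical; the other defaults mirror `DefaultCPUDeviceDesc`. Unlike the patch, the base class here is not abstract and does not register anything with an `MLIRContext`. A target that differs in only one property overrides just that hook and inherits the rest.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical, simplified model of DeviceDesc: a key-value property map.
struct DeviceDescSketch {
  std::map<std::string, int64_t> props;
};

// Read a property that must have been set.
int64_t getProp(const DeviceDescSketch &desc, const std::string &key) {
  auto it = desc.props.find(key);
  assert(it != desc.props.end() && "property not set");
  return it->second;
}

// Hypothetical analogue of DefaultBaseDeviceDesc: each virtual hook sets one
// default property, and subclasses override only the hooks they need.
class DefaultDeviceDescSketch {
public:
  virtual ~DefaultDeviceDescSketch() = default;

  // Runs the hooks after construction, so overrides in any subclass apply.
  DeviceDescSketch populate() const {
    DeviceDescSketch desc;
    setL1CacheSizeInBytes(desc);
    setConvAndMatMulBlockingFactor(desc);
    return desc;
  }

protected:
  virtual void setL1CacheSizeInBytes(DeviceDescSketch &desc) const {
    desc.props["L1_CACHE_SIZE_IN_BYTES"] = 8192;
  }
  virtual void setConvAndMatMulBlockingFactor(DeviceDescSketch &desc) const {
    desc.props["CONV_AND_MATMUL_BLOCKING_FACTOR"] = 32;
  }
};

// A hypothetical CPU whose defaults differ only in L1 cache size.
class BigCacheCPUDefaults : public DefaultDeviceDescSketch {
protected:
  void setL1CacheSizeInBytes(DeviceDescSketch &desc) const override {
    desc.props["L1_CACHE_SIZE_IN_BYTES"] = 49152;
  }
};
```

The sketch runs the hooks from a separate `populate()` call rather than from a constructor. In C++, virtual calls made while a base-class constructor runs do not dispatch to overrides in more-derived classes. So if `DefaultCPUDeviceDesc`'s constructor is the only place the hooks run, a subclass of it that overrides, say, `setL1CacheSizeInBytes()` would not have that override applied by the constructor.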

From c190790d7d57145dd69434ca4ea86d4e66aa9de3 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Wed, 13 Mar 2024 15:07:06 -0700
Subject: [PATCH] [mlir] Target Description and Cost Model in MLIR

---
 mlir/include/mlir/IR/MLIRContext.h            |   4 +
 mlir/include/mlir/Support/SystemDesc.h        | 514 ++++++++++++++++++
 .../include/mlir/Tools/mlir-opt/MlirOptMain.h |   8 +
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  14 +-
 mlir/lib/IR/MLIRContext.cpp                   |   6 +
 mlir/lib/Support/CMakeLists.txt               |   2 +
 mlir/lib/Support/SystemDesc.cpp               |  79 +++
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       |  22 +
 mlir/lib/Transforms/Canonicalizer.cpp         |  25 +-
 9 files changed, 671 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Support/SystemDesc.h
 create mode 100644 mlir/lib/Support/SystemDesc.cpp

diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 11e5329f43e681..6007f19ed503a3 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -34,6 +34,7 @@ class MLIRContextImpl;
 class RegisteredOperationName;
 class StorageUniquer;
 class IRUnit;
+class SystemDesc;
 
 /// MLIRContext is the top-level object for a collection of MLIR operations. It
 /// holds immortal uniqued objects like types, and the tables used to unique
@@ -240,6 +241,9 @@ class MLIRContext {
   /// (attributes, operations, types, etc.).
   llvm::hash_code getRegistryHash();
 
+  /// Get context-specific system description
+  SystemDesc &getSystemDesc();
+
   //===--------------------------------------------------------------------===//
   // Action API
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Support/SystemDesc.h b/mlir/include/mlir/Support/SystemDesc.h
new file mode 100644
index 00000000000000..0784a5183d2061
--- /dev/null
+++ b/mlir/include/mlir/Support/SystemDesc.h
@@ -0,0 +1,514 @@
+//===- SystemDesc.h - Class to represent hardware configuration -*- C++ -*-===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Hardware configuration provides commonly used hardware information to
+// different users, such as optimization passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_SYSTEMDESC_H
+#define MLIR_SUPPORT_SYSTEMDESC_H
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/JSON.h"
+
+/// A system description file contains a list of device descriptions, each
+/// describing one device (e.g., CPU, GPU, ASIC) in the system.
+/// Example (values are JSON strings; key names are case-sensitive):
+/// [
+///  {
+///    "ID": "1",
+///    "Type": "CPU",
+///    "Description": "Intel Xeon 8480",
+///    "L1_CACHE_SIZE_IN_BYTES": "8192",
+///    ...
+///  },
+///  {
+///
+///  },
+///  ...
+/// ]
+namespace mlir {
+
+/// Describes the individual device from the system description
+class DeviceDesc {
+public:
+  /// Some typedefs
+  using DeviceID = uint32_t;
+  using DevicePropertyName = std::string;
+  struct DevicePropertyValue {
+    enum Tag { INT, DOUBLE, INT_VECTOR, DOUBLE_VECTOR } tag;
+    struct Data {
+      int64_t iValue;
+      double dValue;
+      std::vector<int64_t> ivValue;
+      std::vector<double> dvValue;
+
+      Data() : iValue(0), dValue(0.0), ivValue({0}), dvValue({0.0}) {}
+      ~Data() {}
+    } data;
+
+    DevicePropertyValue() = default;
+    DevicePropertyValue(const mlir::DeviceDesc::DevicePropertyValue &rhs) {
+      this->tag = rhs.tag;
+      if (this->tag == INT)
+        this->data.iValue = rhs.data.iValue;
+      else if (this->tag == DOUBLE)
+        this->data.dValue = rhs.data.dValue;
+      else if (this->tag == INT_VECTOR)
+        this->data.ivValue = rhs.data.ivValue;
+      else
+        this->data.dvValue = rhs.data.dvValue;
+    }
+    bool operator==(const mlir::DeviceDesc::DevicePropertyValue &rhs) const {
+      return tag == rhs.tag &&
+             ((tag == INT && data.iValue == rhs.data.iValue) ||
+              (tag == DOUBLE && data.dValue == rhs.data.dValue) ||
+              (tag == INT_VECTOR && data.ivValue == rhs.data.ivValue) ||
+              (tag == DOUBLE_VECTOR && data.dvValue == rhs.data.dvValue));
+    }
+    bool operator!=(const mlir::DeviceDesc::DevicePropertyValue &rhs) const {
+      return !(*this == rhs);
+    }
+  };
+  using DevicePropertiesMapTy =
+      std::map<DevicePropertyName, DevicePropertyValue>;
+
+  typedef enum { CPU, GPU, SPECIAL } DeviceType;
+
+  /// Basic constructor
+  DeviceDesc() = delete;
+  DeviceDesc(DeviceID id, DeviceType type) : ID(id), type(type) {}
+  bool operator==(const mlir::DeviceDesc &rhs) const {
+    return ID == rhs.getID() && type == rhs.getType() &&
+           deviceProperties == rhs.getProperties();
+  }
+  bool operator!=(const mlir::DeviceDesc &rhs) const { return !(*this == rhs); }
+
+  /// Type converters
+  static DeviceID strToDeviceID(const std::string &id_str) {
+    llvm::Expected<int64_t> id = llvm::json::parse<int64_t>(id_str);
+    if (!id)
+      llvm::report_fatal_error("Value of \"ID\" is not int");
+    return static_cast<DeviceID>(id.get());
+  }
+  static DeviceType strToDeviceType(const std::string &type_str) {
+    if (type_str == "CPU")
+      return DeviceType::CPU;
+    else if (type_str == "GPU")
+      return DeviceType::GPU;
+    else if (type_str == "SPECIAL")
+      return DeviceType::SPECIAL;
+    llvm::report_fatal_error("Value of \"Type\" is not CPU, GPU, or SPECIAL");
+  }
+
+  /// Set description
+  DeviceDesc &setDescription(std::string desc) {
+    description = desc;
+    return *this;
+  }
+  /// Set property
+  DeviceDesc &setProperty(llvm::StringRef name, int64_t iv) {
+    DevicePropertyValue value;
+    value.tag = DevicePropertyValue::Tag::INT;
+    value.data.iValue = iv;
+    auto inserted =
+        deviceProperties.insert(std::make_pair(std::string(name), value));
+    if (!inserted.second && inserted.first->second != value) {
+      llvm::report_fatal_error("Duplicate device property name found:" + name);
+    }
+    return *this;
+  }
+  DeviceDesc &setProperty(llvm::StringRef name, double dv) {
+    DevicePropertyValue value;
+    value.tag = DevicePropertyValue::Tag::DOUBLE;
+    value.data.dValue = dv;
+    auto inserted =
+        deviceProperties.insert(std::make_pair(std::string(name), value));
+    if (!inserted.second && inserted.first->second != value) {
+      llvm::report_fatal_error("Duplicate device property name found:" + name);
+    }
+    return *this;
+  }
+  DeviceDesc &setProperty(llvm::StringRef name,
+                          const std::vector<int64_t> &ivv) {
+    DevicePropertyValue value;
+    value.tag = DevicePropertyValue::Tag::INT_VECTOR;
+    value.data.ivValue = ivv;
+    auto inserted =
+        deviceProperties.insert(std::make_pair(std::string(name), value));
+    if (!inserted.second && inserted.first->second != value) {
+      llvm::report_fatal_error("Duplicate device property name found:" + name);
+    }
+    return *this;
+  }
+  DeviceDesc &setProperty(llvm::StringRef name,
+                          const std::vector<double> &idv) {
+    DevicePropertyValue value;
+    value.tag = DevicePropertyValue::Tag::DOUBLE_VECTOR;
+    value.data.dvValue = idv;
+    auto inserted =
+        deviceProperties.insert(std::make_pair(std::string(name), value));
+    if (!inserted.second && inserted.first->second != value) {
+      llvm::report_fatal_error("Duplicate device property name found:" + name);
+    }
+    return *this;
+  }
+  // Convenience overload: parse a JSON-encoded int, double, or vector value.
+  DeviceDesc &setProperty(llvm::StringRef name, const std::string &json_value) {
+    if (json_value.length() > 0 && json_value[0] == '[') {
+      // Parse as an array
+      llvm::Expected<std::vector<int64_t>> ivv =
+          llvm::json::parse<std::vector<int64_t>>(json_value);
+      if (ivv) {
+        // setProperty() already returns *this; no self-assignment needed.
+        return this->setProperty(name, ivv.get());
+      }
+      llvm::consumeError(ivv.takeError()); // Not all ints: try doubles.
+      llvm::Expected<std::vector<double>> idv =
+          llvm::json::parse<std::vector<double>>(json_value);
+      if (idv) {
+        // Store as a vector of doubles.
+        return this->setProperty(name, idv.get());
+      }
+    } else {
+      // int64_t because llvm::json has int64_t support (not int)
+      llvm::Expected<int64_t> iv = llvm::json::parse<int64_t>(json_value);
+      if (iv) {
+        // Store as an integer.
+        return this->setProperty(name, iv.get());
+      }
+      llvm::consumeError(iv.takeError()); // Unhandled Expected errors abort.
+      // Int type failed, try float now.
+      // double because llvm::json has double support (not float)
+      llvm::Expected<double> dv = llvm::json::parse<double>(json_value);
+      if (dv) {
+        // Store as a double.
+        return this->setProperty(name, dv.get());
+      }
+    }
+
+    llvm::report_fatal_error(
+        "Neither int/float/vector value in Device Description: key " + name);
+  }
+
+  /// Get ID
+  DeviceID getID() const { return ID; }
+  /// Get device type
+  DeviceType getType() const { return type; }
+  /// Get device description
+  std::string getDescription() const { return description; }
+  /// Get all of device properties
+  const DevicePropertiesMapTy &getProperties() const {
+    return deviceProperties;
+  }
+  /// Get property value: returns the value of the property with the given
+  /// name if it exists, and std::nullopt otherwise.
+  std::optional<int64_t> getPropertyValueAsInt(llvm::StringRef name) const {
+    // check that property with the given name exists
+    auto iter = deviceProperties.find(std::string(name));
+    if (iter == deviceProperties.end()) {
+      return std::nullopt;
+    }
+    // TODO: we can do a tag check here.
+    return iter->second.data.iValue;
+  }
+  std::optional<double> getPropertyValueAsFloat(llvm::StringRef name) const {
+    // check that property with the given name exists
+    auto iter = deviceProperties.find(std::string(name));
+    if (iter == deviceProperties.end()) {
+      return std::nullopt;
+    }
+    // TODO: we can do a tag check here.
+    return iter->second.data.dValue;
+  }
+
+  /// Special functions
+  auto getAllDevicePropertyNames() const {
+    return llvm::map_range(
+        deviceProperties,
+        [](const DevicePropertiesMapTy::value_type &item) -> llvm::StringRef {
+          return item.first;
+        });
+  }
+
+  /// We use a list of key-value pairs to represent a system description in
+  /// JSON.
+  using DeviceDescJSONTy = std::map<std::string, std::string>;
+  static DeviceDesc
+  parseDeviceDescFromJSON(const DeviceDescJSONTy &device_desc);
+
+  // -----------------------------------------------------------------------
+  //          CPU specific methods
+  // -----------------------------------------------------------------------
+  static constexpr llvm::StringRef getCPUL1CacheSizeInBytesKeyName() {
+    return "L1_CACHE_SIZE_IN_BYTES";
+  }
+  static constexpr llvm::StringRef getConvAndMatMulBlockingFactorKeyName() {
+    return "CONV_AND_MATMUL_BLOCKING_FACTOR";
+  }
+  static constexpr llvm::StringRef getMatMulTileSizeInBytesKeyName() {
+    return "MATMUL_TILE_SIZE_IN_BYTES";
+  }
+  static constexpr llvm::StringRef getCanonicalizerMaxIterationsKeyName() {
+    return "CANONICALIZER_MAX_ITERS";
+  }
+  static constexpr llvm::StringRef getCanonicalizerMaxNumRewritesKeyName() {
+    return "CANONICALIZER_MAX_NUM_REWRITES";
+  }
+  static constexpr llvm::StringRef getMaxVectorWidthKeyName() {
+    return "MAX_VECTOR_WIDTH";
+  }
+
+  std::optional<int64_t> getL1CacheSizeInBytes() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getCPUL1CacheSizeInBytesKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setL1CacheSizeInBytes(int64_t value) {
+    // Temporarily use int override until we support size_t
+    this->setProperty(DeviceDesc::getCPUL1CacheSizeInBytesKeyName(), value);
+  }
+  std::optional<int64_t> getConvAndMatMulBlockingFactor() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getConvAndMatMulBlockingFactorKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setConvAndMatMulBlockingFactor(int64_t value) {
+    // Temporarily use int override until we support size_t
+    this->setProperty(DeviceDesc::getConvAndMatMulBlockingFactorKeyName(),
+                      value);
+  }
+  std::optional<int64_t> getMatMulTileSizeInBytes() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getMatMulTileSizeInBytesKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setMatMulTileSizeInBytes(int64_t value) {
+    // Temporarily use int override until we support size_t
+    this->setProperty(DeviceDesc::getMatMulTileSizeInBytesKeyName(), value);
+  }
+  std::optional<int64_t> getCanonicalizerMaxNumRewrites() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getCanonicalizerMaxNumRewritesKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setCanonicalizerMaxNumRewrites(int64_t value) {
+    this->setProperty(DeviceDesc::getCanonicalizerMaxNumRewritesKeyName(),
+                      value);
+  }
+  std::optional<int64_t> getCanonicalizerMaxIterations() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getCanonicalizerMaxIterationsKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setCanonicalizerMaxIterations(int64_t value) {
+    this->setProperty(DeviceDesc::getCanonicalizerMaxIterationsKeyName(),
+                      value);
+  }
+  std::optional<int64_t> getMaxVectorWidth() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getMaxVectorWidthKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setMaxVectorWidth(uint32_t value) {
+    this->setProperty(DeviceDesc::getMaxVectorWidthKeyName(),
+                      static_cast<int64_t>(value));
+  }
+
+private:
+  /// Unique device ID for every device
+  DeviceID ID;
+
+  /// Type of device
+  DeviceType type;
+
+  /// Some description of the device
+  std::string description;
+
+  /// Dictionary to store rest of the properties
+  DevicePropertiesMapTy deviceProperties;
+};
+
+class SystemDesc {
+public:
+  SystemDesc() = default;
+
+  /// Read and parse system description from JSON file
+  LogicalResult readSystemDescFromJSONFile(llvm::StringRef filename);
+  void writeSystemDescToJSONFile(llvm::StringRef filename);
+
+  /// Insert a new device description
+  SystemDesc &addDeviceDesc(const DeviceDesc &desc) {
+    auto inserted = deviceDescs.insert(std::make_pair(desc.getID(), desc));
+    if (!inserted.second && inserted.first->second != desc) {
+      llvm::report_fatal_error("Duplicate device description for ID:" +
+                               llvm::StringRef(std::to_string(desc.getID())));
+    }
+    return *this;
+  }
+  /// Get a device description
+  const DeviceDesc &getDeviceDesc(DeviceDesc::DeviceID deviceID) {
+    auto iter = deviceDescs.find(deviceID);
+    if (iter != deviceDescs.end()) {
+      return iter->second;
+    }
+    llvm::report_fatal_error("Device description with ID not found:" +
+                             llvm::StringRef(std::to_string(deviceID)));
+  }
+
+  /// Types
+  using DeviceDescsMapTy = std::map<DeviceDesc::DeviceID, DeviceDesc>;
+
+  // Generic functions: TODO
+  /// Get number of CPU devices in the system
+  static uint32_t getNumCPUDevices() { return 0; }
+  static uint32_t getNumGPUDevices() { return 0; }
+
+private:
+  SystemDesc(const SystemDesc &) = delete;
+  void operator=(const SystemDesc &) = delete;
+
+private:
+  /// Map to store all the device descriptions
+  DeviceDescsMapTy deviceDescs;
+};
+
+// An abstract class that represents the description of an abstract base
+// device.
+//
+// This class declares hooks for the device properties that a default device
+// descriptor can set. The default descriptor is used when the user does not
+// specify the properties of a device.
+class DefaultBaseDeviceDesc {
+public:
+  virtual ~DefaultBaseDeviceDesc() {}
+  virtual void registerDeviceDesc(MLIRContext *context) const = 0;
+
+  /// -----------------------------------------------------------------------
+  /// Device-agnostic parameters of system description
+  /// -----------------------------------------------------------------------
+  /// Set of common questions asked by various passes
+  // Blocking factor and tile size are typically used by tile/block passes.
+  virtual void setConvAndMatMulBlockingFactor(){};
+  virtual void setMatMulTileSize(){};
+
+  virtual void setCanonicalizerMaxIterations(){};
+  virtual void setCanonicalizerMaxNumRewrites(){};
+
+  /// -----------------------------------------------------------------------
+  /// CPU-specific parameters of system description
+  /// -----------------------------------------------------------------------
+  virtual void setL1CacheSizeInBytes(){};
+
+  /// -----------------------------------------------------------------------
+  /// GPU-specific parameters of system description
+  /// -----------------------------------------------------------------------
+  // Used by Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp#L52
+  virtual void setMaxVectorWidth(){};
+};
+
+// Class that represent device description for a typical CPU device
+class DefaultCPUDeviceDesc : public DefaultBaseDeviceDesc {
+public:
+  // We use a default ID of 0 because we currently expect a single CPU device
+  // (no heterogeneous setup yet).
+  DefaultCPUDeviceDesc()
+      : cpu_device_desc(DeviceDesc(/* id */ 0, DeviceDesc::CPU)) {
+    // Register all system properties
+    this->setL1CacheSizeInBytes();
+    this->setConvAndMatMulBlockingFactor();
+    this->setMatMulTileSize();
+    this->setCanonicalizerMaxNumRewrites();
+    this->setCanonicalizerMaxIterations();
+  }
+
+  ~DefaultCPUDeviceDesc() {}
+
+  void registerDeviceDesc(MLIRContext *context) const override {
+    context->getSystemDesc().addDeviceDesc(cpu_device_desc);
+  }
+
+  // -------------------------------------------------------------------------
+  //                    CPU-specific properties
+  // -------------------------------------------------------------------------
+
+  void setL1CacheSizeInBytes() override {
+    cpu_device_desc.setL1CacheSizeInBytes(8192);
+  }
+  void setConvAndMatMulBlockingFactor() override {
+    cpu_device_desc.setConvAndMatMulBlockingFactor(32);
+  }
+  void setMatMulTileSize() override {
+    cpu_device_desc.setMatMulTileSizeInBytes(32);
+  }
+  void setCanonicalizerMaxNumRewrites() override {
+    // taken from include/mlir/Transforms/Passes.td
+    cpu_device_desc.setCanonicalizerMaxNumRewrites(-1);
+  }
+  void setCanonicalizerMaxIterations() override {
+    // taken from include/mlir/Transforms/Passes.td
+    cpu_device_desc.setCanonicalizerMaxIterations(10);
+  }
+
+private:
+  DeviceDesc cpu_device_desc;
+};
+
+class DefaultGPUDeviceDesc : public DefaultBaseDeviceDesc {
+public:
+  // We use a default ID of 1 to avoid clashing with the default CPU device
+  // (ID 0); AMDGPUToROCDL looks up the GPU description by this ID.
+  DefaultGPUDeviceDesc()
+      : gpu_device_desc(DeviceDesc(/* id */ 1, DeviceDesc::GPU)) {
+    // GPU device supports default value for MaxVectorWidth so far.
+    this->setMaxVectorWidth();
+  }
+
+  ~DefaultGPUDeviceDesc() {}
+
+  void registerDeviceDesc(MLIRContext *context) const override {
+    context->getSystemDesc().addDeviceDesc(gpu_device_desc);
+  }
+
+  // -------------------------------------------------------------------------
+  //                    GPU-specific properties
+  // -------------------------------------------------------------------------
+
+  void setMaxVectorWidth() override {
+    // Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp#L52
+    gpu_device_desc.setMaxVectorWidth(128);
+  }
+
+private:
+  DeviceDesc gpu_device_desc;
+};
+
+} // namespace mlir
+#endif // MLIR_SUPPORT_SYSTEMDESC_H
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 6e90fad1618d21..5022567a5f256b 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -176,6 +176,11 @@ class MlirOptMainConfig {
   /// Reproducer file generation (no crash required).
   StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
 
+  /// System description file
+  StringRef getSystemDescriptionFileName() const {
+    return systemDescriptionFileFlag;
+  }
+
 protected:
   /// Allow operation with no registered dialects.
   /// This option is for convenience during testing only and discouraged in
@@ -234,6 +239,9 @@ class MlirOptMainConfig {
 
   /// The reproducer output filename (no crash required).
   std::string generateReproducerFileFlag = "";
+
+  /// The system description file
+  std::string systemDescriptionFileFlag = "";
 };
 
 /// This defines the function type used to setup the pass manager. This can be
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7e073bae75c0c9..3ffb4a9ec6dc95 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/SystemDesc.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include <optional>
@@ -29,6 +30,8 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+#define DEBUG_TYPE "amd-gpu-to-rocdl"
+
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
   Type llvmI32 = rewriter.getI32Type();
@@ -49,7 +52,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
       : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
 
   Chipset chipset;
-  static constexpr uint32_t maxVectorOpWidth = 128;
 
   LogicalResult
   matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
@@ -111,6 +113,16 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t elemBits = dataVector.getElementTypeBitWidth();
       uint32_t totalBits = elemBits * dataVector.getNumElements();
+      uint32_t maxVectorOpWidth = 128; // used if the key is absent
+      if (std::optional<int64_t> v = gpuOp.getContext()
+                                         ->getSystemDesc()
+                                         .getDeviceDesc(1 /* gpuID */)
+                                         .getMaxVectorWidth()) {
+        maxVectorOpWidth = static_cast<uint32_t>(*v);
+      }
+      LLVM_DEBUG(llvm::dbgs() << "[CostModel] GPU MaxVectorWidth:"
+                              << maxVectorOpWidth << "\n");
+
       if (totalBits > maxVectorOpWidth)
         return gpuOp.emitOpError(
             "Total width of loads or stores must be no more than " +
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index e1e6d14231d9f1..2eb8628ace08a2 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/SystemDesc.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallString.h"
@@ -198,6 +199,9 @@ class MLIRContextImpl {
   /// A mutex used when accessing operation information.
   llvm::sys::SmartRWMutex<true> operationInfoMutex;
 
+  /// A class to describe hardware properties of a system
+  SystemDesc system_desc;
+
   //===--------------------------------------------------------------------===//
   // Affine uniquing
   //===--------------------------------------------------------------------===//
@@ -702,6 +706,8 @@ bool MLIRContext::isOperationRegistered(StringRef name) {
   return RegisteredOperationName::lookup(name, this).has_value();
 }
 
+SystemDesc &MLIRContext::getSystemDesc() { return impl->system_desc; }
+
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
   assert(impl.multiThreadedExecutionContext == 0 &&
diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
index 488decd52ae647..3ea5ed6d86f2f5 100644
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   FileUtilities.cpp
+  SystemDesc.cpp
   InterfaceSupport.cpp
   StorageUniquer.cpp
   Timing.cpp
@@ -9,6 +10,7 @@ set(LLVM_OPTIONAL_SOURCES
 
 add_mlir_library(MLIRSupport
   FileUtilities.cpp
+  SystemDesc.cpp
   InterfaceSupport.cpp
   RawOstreamExtras.cpp
   StorageUniquer.cpp
diff --git a/mlir/lib/Support/SystemDesc.cpp b/mlir/lib/Support/SystemDesc.cpp
new file mode 100644
index 00000000000000..a482ae7133cfe8
--- /dev/null
+++ b/mlir/lib/Support/SystemDesc.cpp
@@ -0,0 +1,79 @@
+//===- SystemDesc.cpp - System description --------------------------------===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TODO
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/SystemDesc.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/MemoryBuffer.h"
+
+using namespace llvm;
+using namespace mlir;
+
+// ManagedStatic<SystemDesc> systemDesc;
+
+DeviceDesc DeviceDesc::parseDeviceDescFromJSON(
+    const DeviceDescJSONTy &device_desc_in_json) {
+  // ID and Type are mandatory fields.
+  auto iter = device_desc_in_json.find("ID");
+  if (iter == device_desc_in_json.end())
+    llvm::report_fatal_error("\"ID\" key missing in Device Description");
+  DeviceID id = DeviceDesc::strToDeviceID(iter->second);
+
+  iter = device_desc_in_json.find("Type");
+  if (iter == device_desc_in_json.end())
+    llvm::report_fatal_error("\"Type\" key missing in Device Description");
+  DeviceType type = DeviceDesc::strToDeviceType(iter->second);
+
+  // Now process optional fields: description and properties
+  DeviceDesc device_desc(id, type);
+  for (const auto &property : device_desc_in_json) {
+    // skip ID and Type as we have already processed those mandatory fields.
+    if (property.first != "ID" && property.first != "Type") {
+      if (property.first == "Description")
+        device_desc.setDescription(property.second);
+      else
+        device_desc.setProperty(property.first, property.second);
+    }
+  }
+  return device_desc;
+}
+
+LogicalResult SystemDesc::readSystemDescFromJSONFile(llvm::StringRef filename) {
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      openInputFile(filename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return failure();
+  }
+
+  // Parse the buffer as JSON; the top level must be an array of devices.
+  auto parsed = llvm::json::parse(file->getBuffer());
+  if (!parsed) {
+    llvm::errs() << llvm::toString(parsed.takeError()) << "\n";
+    return failure();
+  }
+  json::Path::Root root("system-description");
+  using SystemDescJSONTy = std::vector<mlir::DeviceDesc::DeviceDescJSONTy>;
+  SystemDescJSONTy system_desc_in_json;
+  if (!json::fromJSON(*parsed, system_desc_in_json, root)) {
+    llvm::errs() << llvm::toString(root.getError()) << "\n";
+    return failure();
+  }
+  for (const auto &device_desc_in_json : system_desc_in_json) {
+    auto device_desc = DeviceDesc::parseDeviceDescFromJSON(device_desc_in_json);
+    this->addDeviceDesc(device_desc);
+  }
+  return success();
+}
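`readSystemDescFromJSONFile` decodes the file into one key/value map per device and passes each map to `DeviceDesc::parseDeviceDescFromJSON`. That function treats `ID` and `Type` as mandatory, `Description` as a special field, and every other key as a device property. The standalone C++17 sketch below mirrors that classification without LLVM or MLIR. It assumes, hypothetically, that `DeviceDescJSONTy` is a string-to-string map. `SimpleDeviceDesc`, `parseDevice`, and the property key `L1_CACHE_SIZE_IN_BYTES` are illustrative names, not part of the PR. Unlike the PR, the sketch reports a missing mandatory key by returning `std::nullopt` instead of calling `report_fatal_error`, so a caller could return `failure()` instead.

```cpp
// Sketch only: classify the keys of one device entry without aborting.
// Assumes (hypothetically) that a device entry is a string-to-string map.
#include <iostream>
#include <map>
#include <optional>
#include <string>

using DeviceDescJSON = std::map<std::string, std::string>;

// Hypothetical stand-in for mlir::DeviceDesc.
struct SimpleDeviceDesc {
  std::string id;
  std::string type;
  std::string description;
  std::map<std::string, std::string> properties;
};

// Returns std::nullopt (after printing a diagnostic) when a mandatory key is
// missing, so that the caller can propagate a failure instead of crashing.
std::optional<SimpleDeviceDesc> parseDevice(const DeviceDescJSON &json) {
  auto id = json.find("ID");
  auto type = json.find("Type");
  if (id == json.end() || type == json.end()) {
    std::cerr << "device description is missing \"ID\" or \"Type\"\n";
    return std::nullopt;
  }
  SimpleDeviceDesc desc{id->second, type->second, /*description=*/"", {}};
  for (const auto &[key, value] : json) {
    if (key == "ID" || key == "Type")
      continue; // Mandatory fields were handled above.
    if (key == "Description")
      desc.description = value;
    else
      desc.properties[key] = value; // Every other key is a device property.
  }
  return desc;
}
```

For example, the entry `{"ID": "0", "Type": "CPU", "Description": "host", "L1_CACHE_SIZE_IN_BYTES": "32768"}` yields one property, and an entry that only has `"Type"` is rejected.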
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index b62557153b4167..2ece33c4ea56c2 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/SystemDesc.h"
 #include "mlir/Support/Timing.h"
 #include "mlir/Support/ToolUtilities.h"
 #include "mlir/Tools/ParseUtilities.h"
@@ -161,6 +162,12 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
             cl::location(generateReproducerFileFlag), cl::init(""),
             cl::value_desc("filename"));
 
+    static cl::opt<std::string, /*ExternalStorage=*/true> systemDescriptionFile(
+        "mlir-system-description-file",
+        cl::desc("Path to a JSON file describing the devices in the system"),
+        cl::location(systemDescriptionFileFlag), cl::init(""),
+        cl::value_desc("filename"));
+
     /// Set the callback to load a pass plugin.
     passPlugins.setCallback([&](const std::string &pluginPath) {
       auto plugin = PassPlugin::load(pluginPath);
@@ -381,6 +388,21 @@ performActions(raw_ostream &os,
 
   context->enableMultithreading(wasThreadingEnabled);
 
+  if (!config.getSystemDescriptionFileName().empty()) {
+    // File I/O and parse errors are reported and abort the run; the default
+    // values below are used only when no description file is given.
+    if (failed(context->getSystemDesc().readSystemDescFromJSONFile(
+            config.getSystemDescriptionFileName()))) {
+      return failure();
+    }
+  } else {
+    DefaultCPUDeviceDesc defaultCPUDeviceDesc;
+    defaultCPUDeviceDesc.registerDeviceDesc(context);
+
+    DefaultGPUDeviceDesc defaultGPUDeviceDesc;
+    defaultGPUDeviceDesc.registerDeviceDesc(context);
+  }
+
   // Prepare the pass manager, applying command-line and reproducer options.
   PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
   pm.enableVerifier(config.shouldVerifyPasses());
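The start-up policy in `performActions` above treats an explicit description file as authoritative and registers the built-in CPU and GPU defaults only when no file is given. It can be sketched without MLIR as follows. `SystemDescription`, `registerDefaults`, and `initSystemDescription` are hypothetical names. The `loadFile` callback stands in for `SystemDesc::readSystemDescFromJSONFile`. The device IDs (0 for the CPU, 1 for the GPU) are assumptions; the PR itself only fixes the CPU at ID 0, in the Canonicalizer change.

```cpp
// Sketch only: choose between a user-supplied system description and
// built-in defaults. All names and device IDs here are hypothetical.
#include <functional>
#include <map>
#include <string>

using Properties = std::map<std::string, std::string>;
using SystemDescription = std::map<int, Properties>; // device ID -> properties

// Illustrative defaults; the real ones live in DefaultCPUDeviceDesc and
// DefaultGPUDeviceDesc, whose contents are not shown in this diff.
void registerDefaults(SystemDescription &system) {
  system[0]["Type"] = "CPU"; // Assumed CPU device ID.
  system[1]["Type"] = "GPU"; // Assumed GPU device ID.
}

// `loadFile` returns false on I/O or parse errors. A failed load is
// propagated to the caller; defaults are NOT used as a silent fallback.
bool initSystemDescription(
    SystemDescription &system, const std::string &fileName,
    const std::function<bool(const std::string &, SystemDescription &)>
        &loadFile) {
  if (!fileName.empty())
    return loadFile(fileName, system);
  registerDefaults(system);
  return true;
}
```

Keeping the fallback out of the error path means that a typo in the description file surfaces as an error. It does not silently change tiling or canonicalization behavior.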
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee55..e3abcb71998c3e 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Transforms/Passes.h"
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/SystemDesc.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -23,6 +24,8 @@ namespace mlir {
 
 using namespace mlir;
 
+#define DEBUG_TYPE "canonicalizer"
+
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
@@ -45,8 +48,21 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     // Set the config from possible pass options set in the meantime.
     config.useTopDownTraversal = topDownProcessingEnabled;
     config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
-    config.maxNumRewrites = maxNumRewrites;
+    // Prefer limits from the CPU device description (device ID 0) and fall
+    // back to the pass options when the cost model does not provide them.
+    DeviceDesc::DeviceID cpuID = 0;
+    config.maxIterations = context->getSystemDesc()
+                               .getDeviceDesc(cpuID)
+                               .getCanonicalizerMaxIterations()
+                               .value_or(maxIterations);
+    config.maxNumRewrites = context->getSystemDesc()
+                                .getDeviceDesc(cpuID)
+                                .getCanonicalizerMaxNumRewrites()
+                                .value_or(maxNumRewrites);
+    LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxIterations: "
+                            << config.maxIterations << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxNumRewrites: "
+                            << config.maxNumRewrites << "\n");
 
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
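The Canonicalizer change gives precedence to the device cost model and uses the pass option only when the cost model has no value. `std::optional::value_or` expresses this directly. The sketch below assumes, hypothetically, that device properties are stored as strings, as a JSON key/value map suggests. It shows one way a getter such as `getCanonicalizerMaxIterations` could convert such a string, rejecting malformed values. The property keys and the helpers `getIntProperty` and `resolveLimit` are illustrative and do not come from the PR.

```cpp
// Sketch only: resolve an integer limit from string-valued device properties,
// falling back to a pass-option value. Keys and helpers are hypothetical.
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <stdexcept>
#include <string>

using Properties = std::map<std::string, std::string>;

// Returns the property as an integer, or std::nullopt if it is absent, not a
// number, out of range, or has trailing characters (e.g. "12x").
std::optional<std::int64_t> getIntProperty(const Properties &props,
                                           const std::string &key) {
  auto it = props.find(key);
  if (it == props.end())
    return std::nullopt;
  try {
    std::size_t consumed = 0;
    long long value = std::stoll(it->second, &consumed);
    if (consumed != it->second.size())
      return std::nullopt;
    return static_cast<std::int64_t>(value);
  } catch (const std::exception &) { // std::invalid_argument, out_of_range
    return std::nullopt;
  }
}

// Cost-model value first, pass option second.
std::int64_t resolveLimit(const Properties &cpuProps, const std::string &key,
                          std::int64_t passOptionValue) {
  return getIntProperty(cpuProps, key).value_or(passOptionValue);
}
```

Treating a malformed value like a missing one keeps the pass usable. A stricter design could report it as an error when the description file is read.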


