[Mlir-commits] [mlir] [mlir] Target Description and Cost Model in MLIR (PR #85141)

Niranjan Hasabnis llvmlistbot at llvm.org
Fri Mar 15 15:41:10 PDT 2024


https://github.com/nhasabni updated https://github.com/llvm/llvm-project/pull/85141

>From c190790d7d57145dd69434ca4ea86d4e66aa9de3 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Wed, 13 Mar 2024 15:07:06 -0700
Subject: [PATCH 1/2] [mlir] Target Description and Cost Model in MLIR

This pull request (PR) demonstrates one example of a target description and cost model in MLIR. See [RFC](https://discourse.llvm.org/t/rfc-target-description-and-cost-model-in-mlir/76990) for the context. **It is not a complete work by any means, and the main purpose of this PR is to initiate a discussion around this idea.**

At a high level, the PR develops a mechanism to read the system description/cost model from a config file specified on the command line with `--mlir-system-description-file` (the PR uses JSON as an example format). If the user does not specify a description file, the cost model is populated with default values, which could be known properties of the devices (e.g., L1 cache size) or heuristics used in various compiler passes (e.g., tile size for matrix multiplication). The system description consists of one device description for each device in the system. The PR also demonstrates modifications to a couple of passes (the canonicalizer and the `AMDGPUToROCDL` conversion) so that they read information from the device-specific cost models.
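
The fallback flow described above can be sketched without any MLIR or LLVM dependencies. This is a minimal, hypothetical illustration, not code from this PR. `MiniDeviceDesc`, `MiniSystemDesc`, `registerDefaults` and `queryOr` are names invented for the sketch, properties are limited to integers, and the default values are a subset of those registered by `DefaultCPUDeviceDesc` (ID 0) and `DefaultGPUDeviceDesc` (ID 1) in the patch. One deliberate difference: the patch's `SystemDesc::getDeviceDesc` aborts on an unknown device ID, whereas this sketch treats a missing device like a missing property and returns the pass's own default.

```cpp
// Hypothetical, stdlib-only sketch (not the PR's API) of how a system
// description is populated with defaults and queried with a fallback.
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>

struct MiniDeviceDesc {
  uint32_t id;
  std::string type; // "CPU", "GPU" or "SPECIAL", as in the patch
  std::map<std::string, int64_t> properties;
};

class MiniSystemDesc {
public:
  // Returns false (instead of aborting) if the device ID is already taken.
  bool addDevice(const MiniDeviceDesc &desc) {
    return devices.emplace(desc.id, desc).second;
  }

  // Returns the property if both the device and the key exist.
  std::optional<int64_t> getProperty(uint32_t id,
                                     const std::string &key) const {
    auto dev = devices.find(id);
    if (dev == devices.end())
      return std::nullopt;
    auto prop = dev->second.properties.find(key);
    if (prop == dev->second.properties.end())
      return std::nullopt;
    return prop->second;
  }

  bool empty() const { return devices.empty(); }

private:
  std::map<uint32_t, MiniDeviceDesc> devices;
};

// Used when no description file is given. The values are a subset of those
// set by DefaultCPUDeviceDesc (ID 0) and DefaultGPUDeviceDesc (ID 1).
void registerDefaults(MiniSystemDesc &sys) {
  sys.addDevice({0, "CPU",
                 {{"L1_CACHE_SIZE_IN_BYTES", 8192},
                  {"CONV_AND_MATMUL_BLOCKING_FACTOR", 32},
                  {"CANONICALIZER_MAX_ITERS", 10}}});
  sys.addDevice({1, "GPU", {{"MAX_VECTOR_WIDTH", 128}}});
}

// Pass-side query: prefer the device's value, otherwise fall back to the
// pass's own built-in default.
int64_t queryOr(const MiniSystemDesc &sys, uint32_t id,
                const std::string &key, int64_t fallback) {
  if (std::optional<int64_t> v = sys.getProperty(id, key))
    return *v;
  return fallback;
}
```

For instance, with the defaults registered, `queryOr(sys, 1, "MAX_VECTOR_WIDTH", 64)` returns 128. If device 1 instead comes from a description file with a different `MAX_VECTOR_WIDTH`, that value wins. The pass's fallback only applies when no source provides the key, which mirrors how the `AMDGPUToROCDL` change keeps 128 as the default for `maxVectorOpWidth`.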

In terms of implementation, the `SystemDesc` class implements methods to parse a system description file (building a set of device descriptions) and to add/get device descriptions. The `DeviceDesc` class implements methods to set/get the fields of a device description (defined as a set of key-value pairs). It also provides a set of APIs to get the values of commonly used device properties (such as vector width) or second-order information that MLIR passes derive from those properties (e.g., tile size). **Currently, these APIs are added in an ad-hoc manner and a better design is possible.** The PR also provides an extensible mechanism (in the form of `DefaultBaseDeviceDesc`) to specify default values of various device properties and to derive second-order information from them.
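
To make the key-value design of `DeviceDesc` concrete, here is a minimal, hypothetical sketch of a property store. Like `DeviceDesc::setProperty(name, const std::string &)`, it parses a string first as an `int64_t` and then as a `double`. It simplifies the PR's implementation in several ways, and these choices are assumptions:

* `PropertyStore` and `parseScalar` are invented names.
* It uses `std::strtoll`/`std::strtod` instead of `llvm::json::parse`.
* It handles only scalars (no arrays).
* It throws `std::invalid_argument` where the patch calls `llvm::report_fatal_error`.
* It stores values in a `std::variant` rather than a struct that carries all four alternatives at once.

Its typed getters also perform the tag check that the patch leaves as a TODO. Reading a `double` property through `getInt` therefore yields `std::nullopt` rather than the default-initialized `iValue` (0).

```cpp
// Hypothetical sketch (not the PR's code): a DeviceDesc-like key-value store
// for scalar properties. Names and error handling are assumptions.
#include <cassert>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <optional>
#include <stdexcept>
#include <string>
#include <variant>

// One alternative per supported scalar type; arrays are omitted here.
using PropertyValue = std::variant<int64_t, double>;

// Try int64_t first, then double, mirroring the order used by the patch.
// Returns std::nullopt if the whole string is neither.
std::optional<PropertyValue> parseScalar(const std::string &text) {
  if (text.empty())
    return std::nullopt;
  const char *begin = text.c_str();
  char *end = nullptr;

  errno = 0;
  long long asInt = std::strtoll(begin, &end, 10);
  if (errno == 0 && *end == '\0')
    return PropertyValue(static_cast<int64_t>(asInt));

  errno = 0;
  double asDouble = std::strtod(begin, &end);
  if (errno == 0 && *end == '\0')
    return PropertyValue(asDouble);

  return std::nullopt;
}

class PropertyStore {
public:
  // Re-setting a property to the same value is accepted; a conflicting value
  // is an error (the patch reports a fatal error in that case).
  void set(const std::string &name, const PropertyValue &value) {
    auto [it, inserted] = props.emplace(name, value);
    if (!inserted && it->second != value)
      throw std::invalid_argument("duplicate device property: " + name);
  }

  // Counterpart of setProperty(name, const std::string &) for scalars.
  void setFromString(const std::string &name, const std::string &text) {
    std::optional<PropertyValue> value = parseScalar(text);
    if (!value)
      throw std::invalid_argument("not an int/float value for key " + name);
    set(name, *value);
  }

  // Typed getters: std::nullopt if the key is missing or the stored value
  // has a different type.
  std::optional<int64_t> getInt(const std::string &name) const {
    return lookup<int64_t>(name);
  }
  std::optional<double> getDouble(const std::string &name) const {
    return lookup<double>(name);
  }

private:
  template <typename T>
  std::optional<T> lookup(const std::string &name) const {
    auto it = props.find(name);
    if (it == props.end())
      return std::nullopt;
    if (const T *v = std::get_if<T>(&it->second))
      return *v;
    return std::nullopt;
  }

  std::map<std::string, PropertyValue> props;
};
```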
---
 mlir/include/mlir/IR/MLIRContext.h            |   4 +
 mlir/include/mlir/Support/SystemDesc.h        | 514 ++++++++++++++++++
 .../include/mlir/Tools/mlir-opt/MlirOptMain.h |   8 +
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  14 +-
 mlir/lib/IR/MLIRContext.cpp                   |   6 +
 mlir/lib/Support/CMakeLists.txt               |   2 +
 mlir/lib/Support/SystemDesc.cpp               |  79 +++
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       |  22 +
 mlir/lib/Transforms/Canonicalizer.cpp         |  25 +-
 9 files changed, 671 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Support/SystemDesc.h
 create mode 100644 mlir/lib/Support/SystemDesc.cpp

diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 11e5329f43e681..6007f19ed503a3 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -34,6 +34,7 @@ class MLIRContextImpl;
 class RegisteredOperationName;
 class StorageUniquer;
 class IRUnit;
+class SystemDesc;
 
 /// MLIRContext is the top-level object for a collection of MLIR operations. It
 /// holds immortal uniqued objects like types, and the tables used to unique
@@ -240,6 +241,9 @@ class MLIRContext {
   /// (attributes, operations, types, etc.).
   llvm::hash_code getRegistryHash();
 
+  /// Get context-specific system description
+  SystemDesc &getSystemDesc();
+
   //===--------------------------------------------------------------------===//
   // Action API
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Support/SystemDesc.h b/mlir/include/mlir/Support/SystemDesc.h
new file mode 100644
index 00000000000000..0784a5183d2061
--- /dev/null
+++ b/mlir/include/mlir/Support/SystemDesc.h
@@ -0,0 +1,514 @@
+//===- SystemDesc.h - class to represent hardware configuration --*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Hardware configuration provides commonly used hardware information to
+// different users, such as optimization passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_SYSTEMDESC_H
+#define MLIR_SUPPORT_SYSTEMDESC_H
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/JSON.h"
+
+/// A system description file contains a list of device descriptions, each
+/// describing one device (e.g., CPU, GPU, ASIC) in the system. All values
+/// are JSON strings; numeric values are parsed from those strings. Example:
+/// [
+///  {
+///    "ID": "0",
+///    "Type": "CPU",
+///    "Description": "Intel Xeon 8480",
+///    "L1_CACHE_SIZE_IN_BYTES": "8192",
+///    ...
+///  },
+///  {
+///
+///  },
+///  ...
+/// ]
+namespace mlir {
+
+/// Describes the individual device from the system description
+class DeviceDesc {
+public:
+  /// Some typedefs
+  using DeviceID = uint32_t;
+  using DevicePropertyName = std::string;
+  struct DevicePropertyValue {
+    enum Tag { INT, DOUBLE, INT_VECTOR, DOUBLE_VECTOR } tag;
+    struct Data {
+      int64_t iValue;
+      double dValue;
+      std::vector<int64_t> ivValue;
+      std::vector<double> dvValue;
+
+      Data() : iValue(0), dValue(0.0), ivValue({0}), dvValue({0.0}) {}
+      ~Data() {}
+    } data;
+
+    DevicePropertyValue() = default;
+    DevicePropertyValue(const mlir::DeviceDesc::DevicePropertyValue &rhs) {
+      this->tag = rhs.tag;
+      if (this->tag == INT)
+        this->data.iValue = rhs.data.iValue;
+      else if (this->tag == DOUBLE)
+        this->data.dValue = rhs.data.dValue;
+      else if (this->tag == INT_VECTOR)
+        this->data.ivValue = rhs.data.ivValue;
+      else
+        this->data.dvValue = rhs.data.dvValue;
+    }
+    bool operator==(const mlir::DeviceDesc::DevicePropertyValue &rhs) const {
+      return tag == rhs.tag &&
+             ((tag == INT && data.iValue == rhs.data.iValue) ||
+              (tag == DOUBLE && data.dValue == rhs.data.dValue) ||
+              (tag == INT_VECTOR && data.ivValue == rhs.data.ivValue) ||
+              (tag == DOUBLE_VECTOR && data.dvValue == rhs.data.dvValue));
+    }
+    bool operator!=(const mlir::DeviceDesc::DevicePropertyValue &rhs) const {
+      return !(*this == rhs);
+    }
+  };
+  using DevicePropertiesMapTy =
+      std::map<DevicePropertyName, DevicePropertyValue>;
+
+  typedef enum { CPU, GPU, SPECIAL } DeviceType;
+
+  /// Basic constructor
+  DeviceDesc() = delete;
+  DeviceDesc(DeviceID id, DeviceType type) : ID(id), type(type) {}
+  bool operator==(const mlir::DeviceDesc &rhs) const {
+    return ID == rhs.getID() && type == rhs.getType() &&
+           deviceProperties == rhs.getProperties();
+  }
+  bool operator!=(const mlir::DeviceDesc &rhs) const { return !(*this == rhs); }
+
+  /// Type converters
+  static DeviceID strToDeviceID(const std::string &id_str) {
+    llvm::Expected<int64_t> id = llvm::json::parse<int64_t>(id_str);
+    if (!id)
+      llvm::report_fatal_error("Value of \"ID\" is not int");
+    return static_cast<DeviceID>(id.get());
+  }
+  static DeviceType strToDeviceType(const std::string &type_str) {
+    if (type_str == "CPU")
+      return DeviceType::CPU;
+    else if (type_str == "GPU")
+      return DeviceType::GPU;
+    else if (type_str == "SPECIAL")
+      return DeviceType::SPECIAL;
+    llvm::report_fatal_error("Value of \"Type\" is not CPU, GPU, or SPECIAL");
+  }
+
+  /// Set description
+  DeviceDesc &setDescription(std::string desc) {
+    description = desc;
+    return *this;
+  }
+  /// Set property
+  DeviceDesc &setProperty(llvm::StringRef name, int64_t iv) {
+    DevicePropertyValue value;
+    value.tag = DevicePropertyValue::Tag::INT;
+    value.data.iValue = iv;
+    auto inserted =
+        deviceProperties.insert(std::make_pair(std::string(name), value));
+    if (!inserted.second && inserted.first->second != value) {
+      llvm::report_fatal_error("Duplicate device property name found:" + name);
+    }
+    return *this;
+  }
+  DeviceDesc &setProperty(llvm::StringRef name, double dv) {
+    DevicePropertyValue value;
+    value.tag = DevicePropertyValue::Tag::DOUBLE;
+    value.data.dValue = dv;
+    auto inserted =
+        deviceProperties.insert(std::make_pair(std::string(name), value));
+    if (!inserted.second && inserted.first->second != value) {
+      llvm::report_fatal_error("Duplicate device property name found:" + name);
+    }
+    return *this;
+  }
+  DeviceDesc &setProperty(llvm::StringRef name,
+                          const std::vector<int64_t> &ivv) {
+    DevicePropertyValue value;
+    value.tag = DevicePropertyValue::Tag::INT_VECTOR;
+    value.data.ivValue = ivv;
+    auto inserted =
+        deviceProperties.insert(std::make_pair(std::string(name), value));
+    if (!inserted.second && inserted.first->second != value) {
+      llvm::report_fatal_error("Duplicate device property name found:" + name);
+    }
+    return *this;
+  }
+  DeviceDesc &setProperty(llvm::StringRef name,
+                          const std::vector<double> &idv) {
+    DevicePropertyValue value;
+    value.tag = DevicePropertyValue::Tag::DOUBLE_VECTOR;
+    value.data.dvValue = idv;
+    auto inserted =
+        deviceProperties.insert(std::make_pair(std::string(name), value));
+    if (!inserted.second && inserted.first->second != value) {
+      llvm::report_fatal_error("Duplicate device property name found:" + name);
+    }
+    return *this;
+  }
+  // Convenience overload: parse an int, float, or array value from a string.
+  DeviceDesc &setProperty(llvm::StringRef name, const std::string &json_value) {
+    if (json_value.length() > 0 && json_value[0] == '[') {
+      // Parse as an array. Try int64_t elements first and fall back to
+      // double elements if that fails.
+      llvm::Expected<std::vector<int64_t>> ivv =
+          llvm::json::parse<std::vector<int64_t>>(json_value);
+      if (ivv)
+        return setProperty(name, *ivv);
+      // An llvm::Expected holding an error must have the error consumed
+      // before it is destroyed, or builds with ABI-breaking checks abort.
+      llvm::consumeError(ivv.takeError());
+
+      llvm::Expected<std::vector<double>> idv =
+          llvm::json::parse<std::vector<double>>(json_value);
+      if (idv)
+        return setProperty(name, *idv);
+      llvm::consumeError(idv.takeError());
+    } else {
+      // Parse as a scalar.
+      // int64_t because llvm::json has int64_t support (not int)
+      llvm::Expected<int64_t> iv = llvm::json::parse<int64_t>(json_value);
+      if (iv)
+        return setProperty(name, *iv);
+      llvm::consumeError(iv.takeError());
+
+      // Int type failed, try float now.
+      // double because llvm::json has double support (not float)
+      llvm::Expected<double> dv = llvm::json::parse<double>(json_value);
+      if (dv)
+        return setProperty(name, *dv);
+      llvm::consumeError(dv.takeError());
+    }
+
+    llvm::report_fatal_error(
+        "Neither int/float/vector value in Device Description: key " + name);
+  }
+
+  /// Get ID
+  DeviceID getID() const { return ID; }
+  /// Get device type
+  DeviceType getType() const { return type; }
+  /// Get device description
+  std::string getDescription() const { return description; }
+  /// Get all of device properties
+  const DevicePropertiesMapTy &getProperties() const {
+    return deviceProperties;
+  }
+  /// Get property value: returns the value of the property with the given
+  /// name if it exists, and std::nullopt otherwise.
+  std::optional<int64_t> getPropertyValueAsInt(llvm::StringRef name) const {
+    // check that property with the given name exists
+    auto iter = deviceProperties.find(std::string(name));
+    if (iter == deviceProperties.end()) {
+      return std::nullopt;
+    }
+    // TODO: we can do a tag check here.
+    return iter->second.data.iValue;
+  }
+  std::optional<double> getPropertyValueAsFloat(llvm::StringRef name) const {
+    // check that property with the given name exists
+    auto iter = deviceProperties.find(std::string(name));
+    if (iter == deviceProperties.end()) {
+      return std::nullopt;
+    }
+    // TODO: we can do a tag check here.
+    return iter->second.data.dValue;
+  }
+
+  /// Special functions
+  auto getAllDevicePropertyNames() const {
+    return llvm::map_range(
+        deviceProperties,
+        [](const DevicePropertiesMapTy::value_type &item) -> llvm::StringRef {
+          return item.first;
+        });
+  }
+
+  /// We use a map of key-value pairs, where every value is a JSON string, to
+  /// represent a device description in JSON.
+  using DeviceDescJSONTy = std::map<std::string, std::string>;
+  static DeviceDesc
+  parseDeviceDescFromJSON(const DeviceDescJSONTy &device_desc);
+
+  // -----------------------------------------------------------------------
+  //          CPU specific methods
+  // -----------------------------------------------------------------------
+  static constexpr llvm::StringRef getCPUL1CacheSizeInBytesKeyName() {
+    return "L1_CACHE_SIZE_IN_BYTES";
+  }
+  static constexpr llvm::StringRef getConvAndMatMulBlockingFactorKeyName() {
+    return "CONV_AND_MATMUL_BLOCKING_FACTOR";
+  }
+  static constexpr llvm::StringRef getMatMulTileSizeInBytesKeyName() {
+    return "MATMUL_TILE_SIZE_IN_BYTES";
+  }
+  static constexpr llvm::StringRef getCanonicalizerMaxIterationsKeyName() {
+    return "CANONICALIZER_MAX_ITERS";
+  }
+  static constexpr llvm::StringRef getCanonicalizerMaxNumRewritesKeyName() {
+    return "CANONICALIZER_MAX_NUM_REWRITES";
+  }
+  static constexpr llvm::StringRef getMaxVectorWidthKeyName() {
+    return "MAX_VECTOR_WIDTH";
+  }
+
+  std::optional<int64_t> getL1CacheSizeInBytes() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getCPUL1CacheSizeInBytesKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setL1CacheSizeInBytes(int64_t value) {
+    // Temporarily use int override until we support size_t
+    this->setProperty(DeviceDesc::getCPUL1CacheSizeInBytesKeyName(), value);
+  }
+  std::optional<int64_t> getConvAndMatMulBlockingFactor() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getConvAndMatMulBlockingFactorKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setConvAndMatMulBlockingFactor(int64_t value) {
+    // Temporarily use int override until we support size_t
+    this->setProperty(DeviceDesc::getConvAndMatMulBlockingFactorKeyName(),
+                      value);
+  }
+  std::optional<int64_t> getMatMulTileSizeInBytes() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getMatMulTileSizeInBytesKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setMatMulTileSizeInBytes(int64_t value) {
+    // Temporarily use int override until we support size_t
+    this->setProperty(DeviceDesc::getMatMulTileSizeInBytesKeyName(), value);
+  }
+  std::optional<int64_t> getCanonicalizerMaxNumRewrites() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getCanonicalizerMaxNumRewritesKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setCanonicalizerMaxNumRewrites(int64_t value) {
+    this->setProperty(DeviceDesc::getCanonicalizerMaxNumRewritesKeyName(),
+                      value);
+  }
+  std::optional<int64_t> getCanonicalizerMaxIterations() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getCanonicalizerMaxIterationsKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setCanonicalizerMaxIterations(int64_t value) {
+    this->setProperty(DeviceDesc::getCanonicalizerMaxIterationsKeyName(),
+                      value);
+  }
+  std::optional<int64_t> getMaxVectorWidth() const {
+    if (std::optional<int64_t> v = this->getPropertyValueAsInt(
+            DeviceDesc::getMaxVectorWidthKeyName())) {
+      return v;
+    }
+    return std::nullopt;
+  }
+  void setMaxVectorWidth(uint32_t value) {
+    this->setProperty(DeviceDesc::getMaxVectorWidthKeyName(),
+                      static_cast<int64_t>(value));
+  }
+
+private:
+  /// Unique device ID for every device
+  DeviceID ID;
+
+  /// Type of device
+  DeviceType type;
+
+  /// Some description of the device
+  std::string description;
+
+  /// Dictionary to store rest of the properties
+  DevicePropertiesMapTy deviceProperties;
+};
+
+class SystemDesc {
+public:
+  SystemDesc() = default;
+
+  /// Read and parse system description from JSON file
+  LogicalResult readSystemDescFromJSONFile(llvm::StringRef filename);
+  void writeSystemDescToJSONFile(llvm::StringRef filename);
+
+  /// Insert a new device description
+  SystemDesc &addDeviceDesc(const DeviceDesc &desc) {
+    auto inserted = deviceDescs.insert(std::make_pair(desc.getID(), desc));
+    if (!inserted.second && inserted.first->second != desc) {
+      llvm::report_fatal_error("Duplicate device description for ID:" +
+                               llvm::StringRef(std::to_string(desc.getID())));
+    }
+    return *this;
+  }
+  /// Get a device description
+  const DeviceDesc &getDeviceDesc(DeviceDesc::DeviceID deviceID) {
+    auto iter = deviceDescs.find(deviceID);
+    if (iter != deviceDescs.end()) {
+      return iter->second;
+    }
+    llvm::report_fatal_error("Device description with ID not found:" +
+                             llvm::StringRef(std::to_string(deviceID)));
+  }
+
+  /// Types
+  using DeviceDescsMapTy = std::map<DeviceDesc::DeviceID, DeviceDesc>;
+
+  // Generic functions: TODO
+  /// Get number of CPU devices in the system
+  static uint32_t getNumCPUDevices() { return 0; }
+  static uint32_t getNumGPUDevices() { return 0; }
+
+private:
+  SystemDesc(const SystemDesc &) = delete;
+  void operator=(const SystemDesc &) = delete;
+
+private:
+  /// Map to store all the device descriptions
+  DeviceDescsMapTy deviceDescs;
+};
+
+// An abstract class that represents the device description of an abstract
+// base device.
+//
+// This class declares the set of device properties for which a default
+// device description can provide values. The defaults are used when the user
+// does not specify a system description file.
+class DefaultBaseDeviceDesc {
+public:
+  virtual ~DefaultBaseDeviceDesc() {}
+  virtual void registerDeviceDesc(MLIRContext *context) const = 0;
+
+  /// -----------------------------------------------------------------------
+  /// Device-agnostic parameters of system description
+  /// -----------------------------------------------------------------------
+  /// Set of common questions asked by various passes
+  // Blocking factor and tile size are typically used by tile/block passes.
+  virtual void setConvAndMatMulBlockingFactor() {}
+  virtual void setMatMulTileSize() {}
+
+  virtual void setCanonicalizerMaxIterations() {}
+  virtual void setCanonicalizerMaxNumRewrites() {}
+
+  /// -----------------------------------------------------------------------
+  /// CPU-specific parameters of system description
+  /// -----------------------------------------------------------------------
+  virtual void setL1CacheSizeInBytes() {}
+
+  /// -----------------------------------------------------------------------
+  /// GPU-specific parameters of system description
+  /// -----------------------------------------------------------------------
+  // Used by Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp#L52
+  virtual void setMaxVectorWidth() {}
+};
+
+// Class that represents the device description of a typical CPU device
+class DefaultCPUDeviceDesc : public DefaultBaseDeviceDesc {
+public:
+  // We use a default ID of 0 for the CPU because only one CPU device is
+  // expected so far (the default GPU description uses ID 1).
+  DefaultCPUDeviceDesc()
+      : cpu_device_desc(DeviceDesc(/* id */ 0, DeviceDesc::CPU)) {
+    // Register all system properties
+    this->setL1CacheSizeInBytes();
+    this->setConvAndMatMulBlockingFactor();
+    this->setMatMulTileSize();
+    this->setCanonicalizerMaxNumRewrites();
+    this->setCanonicalizerMaxIterations();
+  }
+
+  ~DefaultCPUDeviceDesc() {}
+
+  void registerDeviceDesc(MLIRContext *context) const override {
+    context->getSystemDesc().addDeviceDesc(cpu_device_desc);
+  }
+
+  // -------------------------------------------------------------------------
+  //                    CPU-specific properties
+  // -------------------------------------------------------------------------
+
+  void setL1CacheSizeInBytes() override {
+    cpu_device_desc.setL1CacheSizeInBytes(8192);
+  }
+  void setConvAndMatMulBlockingFactor() override {
+    cpu_device_desc.setConvAndMatMulBlockingFactor(32);
+  }
+  void setMatMulTileSize() override {
+    cpu_device_desc.setMatMulTileSizeInBytes(32);
+  }
+  void setCanonicalizerMaxNumRewrites() override {
+    // taken from include/mlir/Transforms/Passes.td
+    cpu_device_desc.setCanonicalizerMaxNumRewrites(-1);
+  }
+  void setCanonicalizerMaxIterations() override {
+    // taken from include/mlir/Transforms/Passes.td
+    cpu_device_desc.setCanonicalizerMaxIterations(10);
+  }
+
+private:
+  DeviceDesc cpu_device_desc;
+};
+
+class DefaultGPUDeviceDesc : public DefaultBaseDeviceDesc {
+public:
+  // We use a default ID of 1 for the GPU because only one GPU device is
+  // expected so far (the default CPU description uses ID 0).
+  DefaultGPUDeviceDesc()
+      : gpu_device_desc(DeviceDesc(/* id */ 1, DeviceDesc::GPU)) {
+    // Only MaxVectorWidth has a default value for GPU devices so far.
+    this->setMaxVectorWidth();
+  }
+
+  ~DefaultGPUDeviceDesc() {}
+
+  void registerDeviceDesc(MLIRContext *context) const override {
+    context->getSystemDesc().addDeviceDesc(gpu_device_desc);
+  }
+
+  // -------------------------------------------------------------------------
+  //                    GPU-specific properties
+  // -------------------------------------------------------------------------
+
+  void setMaxVectorWidth() override {
+    // Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp#L52
+    gpu_device_desc.setMaxVectorWidth(128);
+  }
+
+private:
+  DeviceDesc gpu_device_desc;
+};
+
+} // namespace mlir
+#endif // MLIR_SUPPORT_SYSTEMDESC_H
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 6e90fad1618d21..5022567a5f256b 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -176,6 +176,11 @@ class MlirOptMainConfig {
   /// Reproducer file generation (no crash required).
   StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
 
+  /// System description file
+  StringRef getSystemDescriptionFileName() const {
+    return systemDescriptionFileFlag;
+  }
+
 protected:
   /// Allow operation with no registered dialects.
   /// This option is for convenience during testing only and discouraged in
@@ -234,6 +239,9 @@ class MlirOptMainConfig {
 
   /// The reproducer output filename (no crash required).
   std::string generateReproducerFileFlag = "";
+
+  /// The system description file.
+  std::string systemDescriptionFileFlag = "";
 };
 
 /// This defines the function type used to setup the pass manager. This can be
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7e073bae75c0c9..3ffb4a9ec6dc95 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/SystemDesc.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include <optional>
@@ -29,6 +30,8 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+#define DEBUG_TYPE "amd-gpu-to-rocdl"
+
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
   Type llvmI32 = rewriter.getI32Type();
@@ -49,7 +52,6 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
       : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
 
   Chipset chipset;
-  static constexpr uint32_t maxVectorOpWidth = 128;
 
   LogicalResult
   matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
@@ -111,6 +113,16 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t elemBits = dataVector.getElementTypeBitWidth();
       uint32_t totalBits = elemBits * dataVector.getNumElements();
+      uint32_t maxVectorOpWidth = 128; // default value
+      if (std::optional<int64_t> v = gpuOp.getContext()
+                                         ->getSystemDesc()
+                                         .getDeviceDesc(1 /* gpuID */)
+                                         .getMaxVectorWidth()) {
+        maxVectorOpWidth = static_cast<uint32_t>(*v);
+      }
+      LLVM_DEBUG(llvm::dbgs() << "[CostModel] GPU MaxVectorWidth:"
+                              << maxVectorOpWidth << "\n");
+
       if (totalBits > maxVectorOpWidth)
         return gpuOp.emitOpError(
             "Total width of loads or stores must be no more than " +
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index e1e6d14231d9f1..2eb8628ace08a2 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/SystemDesc.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallString.h"
@@ -198,6 +199,9 @@ class MLIRContextImpl {
   /// A mutex used when accessing operation information.
   llvm::sys::SmartRWMutex<true> operationInfoMutex;
 
+  /// A class to describe hardware properties of a system
+  SystemDesc system_desc;
+
   //===--------------------------------------------------------------------===//
   // Affine uniquing
   //===--------------------------------------------------------------------===//
@@ -702,6 +706,8 @@ bool MLIRContext::isOperationRegistered(StringRef name) {
   return RegisteredOperationName::lookup(name, this).has_value();
 }
 
+SystemDesc &MLIRContext::getSystemDesc() { return impl->system_desc; }
+
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
   assert(impl.multiThreadedExecutionContext == 0 &&
diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
index 488decd52ae647..3ea5ed6d86f2f5 100644
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(LLVM_OPTIONAL_SOURCES
   FileUtilities.cpp
+  SystemDesc.cpp
   InterfaceSupport.cpp
   StorageUniquer.cpp
   Timing.cpp
@@ -9,6 +10,7 @@ set(LLVM_OPTIONAL_SOURCES
 
 add_mlir_library(MLIRSupport
   FileUtilities.cpp
+  SystemDesc.cpp
   InterfaceSupport.cpp
   RawOstreamExtras.cpp
   StorageUniquer.cpp
diff --git a/mlir/lib/Support/SystemDesc.cpp b/mlir/lib/Support/SystemDesc.cpp
new file mode 100644
index 00000000000000..a482ae7133cfe8
--- /dev/null
+++ b/mlir/lib/Support/SystemDesc.cpp
@@ -0,0 +1,79 @@
+//===- SystemDesc.cpp - Hardware configuration
+//----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements parsing of system descriptions from JSON files.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/SystemDesc.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/MemoryBuffer.h"
+
+using namespace llvm;
+using namespace mlir;
+
+// ManagedStatic<SystemDesc> systemDesc;
+
+DeviceDesc DeviceDesc::parseDeviceDescFromJSON(
+    const DeviceDescJSONTy &device_desc_in_json) {
+  // ID and Type are mandatory fields.
+  auto iter = device_desc_in_json.find("ID");
+  if (iter == device_desc_in_json.end())
+    llvm::report_fatal_error("\"ID\" key missing in Device Description");
+  DeviceID id = DeviceDesc::strToDeviceID(iter->second);
+
+  iter = device_desc_in_json.find("Type");
+  if (iter == device_desc_in_json.end())
+    llvm::report_fatal_error("\"Type\" key missing in Device Description");
+  DeviceType type = DeviceDesc::strToDeviceType(iter->second);
+
+  // Now process optional fields: description and properties
+  DeviceDesc device_desc(id, type);
+  for (const auto &property : device_desc_in_json) {
+    // skip ID and Type as we have already processed those mandatory fields.
+    if (property.first != "ID" && property.first != "Type") {
+      if (property.first == "Description")
+        device_desc.setDescription(property.second);
+      else
+        device_desc.setProperty(property.first, property.second);
+    }
+  }
+  return device_desc;
+}
+
+LogicalResult SystemDesc::readSystemDescFromJSONFile(llvm::StringRef filename) {
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      openInputFile(filename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return failure();
+  }
+
+  // Parse the file contents as JSON.
+  auto parsed = llvm::json::parse(file.get()->getBuffer());
+  if (!parsed) {
+    report_fatal_error(parsed.takeError());
+  }
+
+  json::Path::Root NullRoot;
+  using SystemDescJSONTy = std::vector<mlir::DeviceDesc::DeviceDescJSONTy>;
+  SystemDescJSONTy system_desc_in_json;
+  if (!json::fromJSON(*parsed, system_desc_in_json, NullRoot)) {
+    report_fatal_error("Invalid System Description in JSON");
+  }
+  for (const auto &device_desc_in_json : system_desc_in_json) {
+    auto device_desc = DeviceDesc::parseDeviceDescFromJSON(device_desc_in_json);
+    this->addDeviceDesc(device_desc);
+  }
+
+  return success();
+}
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index b62557153b4167..2ece33c4ea56c2 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/SystemDesc.h"
 #include "mlir/Support/Timing.h"
 #include "mlir/Support/ToolUtilities.h"
 #include "mlir/Tools/ParseUtilities.h"
@@ -161,6 +162,12 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
             cl::location(generateReproducerFileFlag), cl::init(""),
             cl::value_desc("filename"));
 
+    static cl::opt<std::string, /*ExternalStorage=*/true> systemDescriptionFile(
+        "mlir-system-description-file",
+        llvm::cl::desc("Name of the system description file"),
+        cl::location(systemDescriptionFileFlag), cl::init(""),
+        cl::value_desc("filename"));
+
     /// Set the callback to load a pass plugin.
     passPlugins.setCallback([&](const std::string &pluginPath) {
       auto plugin = PassPlugin::load(pluginPath);
@@ -381,6 +388,21 @@ performActions(raw_ostream &os,
 
   context->enableMultithreading(wasThreadingEnabled);
 
+  if (!config.getSystemDescriptionFileName().empty()) {
+    // readSystemDescFromJSONFile prints an error and returns failure if the
+    // file cannot be opened; malformed JSON is currently a fatal error.
+    if (failed(context->getSystemDesc().readSystemDescFromJSONFile(
+            config.getSystemDescriptionFileName()))) {
+      return failure();
+    }
+  } else {
+    DefaultCPUDeviceDesc default_cpu_device_desc;
+    default_cpu_device_desc.registerDeviceDesc(context);
+
+    DefaultGPUDeviceDesc default_gpu_device_desc;
+    default_gpu_device_desc.registerDeviceDesc(context);
+  }
+
   // Prepare the pass manager, applying command-line and reproducer options.
   PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
   pm.enableVerifier(config.shouldVerifyPasses());
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee55..e3abcb71998c3e 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Transforms/Passes.h"
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/SystemDesc.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -23,6 +24,8 @@ namespace mlir {
 
 using namespace mlir;
 
+#define DEBUG_TYPE "canonicalizer"
+
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
@@ -45,8 +48,26 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
     // Set the config from possible pass options set in the meantime.
     config.useTopDownTraversal = topDownProcessingEnabled;
     config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
-    config.maxNumRewrites = maxNumRewrites;
+    DeviceDesc::DeviceID cpuID = 0;
+    if (std::optional<int64_t> v = context->getSystemDesc()
+                                       .getDeviceDesc(cpuID)
+                                       .getCanonicalizerMaxIterations()) {
+      config.maxIterations = *v;
+    } else {
+      config.maxIterations = maxIterations;
+    }
+
+    if (std::optional<int64_t> v = context->getSystemDesc()
+                                       .getDeviceDesc(cpuID)
+                                       .getCanonicalizerMaxNumRewrites()) {
+      config.maxNumRewrites = *v;
+    } else {
+      config.maxNumRewrites = maxNumRewrites;
+    }
+    LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxIterations: "
+                            << config.maxIterations << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "[CostModel] Canonicalizer MaxNumRewrites: "
+                            << config.maxNumRewrites << "\n");
 
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
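The Canonicalizer hunk above applies a simple precedence rule: if the CPU device description (ID 0) defines a canonicalizer property, use it; otherwise keep the pass option. Below is a minimal standalone sketch of that rule. It assumes a plain `std::map` in place of `DeviceDesc`, and `getProperty`, `resolveOption`, and the key names are hypothetical. The real keys come from the `get...KeyName()` helpers, which these hunks do not show.

```cpp
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-in for one device description: property name -> value.
using DeviceProperties = std::map<std::string, int64_t>;

// Returns the property if the device description defines it.
std::optional<int64_t> getProperty(const DeviceProperties &props,
                                   const std::string &key) {
  auto it = props.find(key);
  if (it == props.end())
    return std::nullopt;
  return it->second;
}

// Same effect as each if/else block in the hunk: the cost-model value wins
// and the pass option is only a fallback.
int64_t resolveOption(const DeviceProperties &props, const std::string &key,
                      int64_t passOptionValue) {
  return getProperty(props, key).value_or(passOptionValue);
}
```

One consequence of this ordering, as far as these hunks show: when mlir-opt registers `DefaultCPUDeviceDesc`, which sets both canonicalizer values, a device value is always present. In that case a `max-iterations` or `max-num-rewrites` option passed to the canonicalizer would appear to have no effect.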

>From a93107d014e81131586d0d01e10319a5aeaf29c7 Mon Sep 17 00:00:00 2001
From: Niranjan Hasabnis <niranjan.hasabnis at intel.com>
Date: Fri, 15 Mar 2024 15:32:59 -0700
Subject: [PATCH 2/2] Addressing review comments: replacing DeviceDesc data
 field with NamedAttrList; moving parsers out of core class

This PR makes the following changes:

1. Replaces the tagged union in DeviceDesc with NamedAttrList
2. Moves the config file parsers out of the core SystemDesc and DeviceDesc
   classes
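Item 1 replaces the hand-rolled tagged union with typed attributes stored in a `NamedAttrList`. Each property becomes an i64 `IntegerAttr`, an f64 `FloatAttr`, or a `DenseElementsAttr` for arrays. The typed getters now use `dyn_cast`, so asking for the wrong kind yields `std::nullopt` instead of reading an unrelated union member. The sketch below models that behaviour without MLIR, using `std::variant` as a stand-in for attributes. `PropertyList` and its methods are hypothetical, and the sketch returns `false` on a duplicate name where the patch calls `report_fatal_error`.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Stand-ins for the attribute kinds used in the patch: i64 IntegerAttr,
// f64 FloatAttr, and DenseElementsAttr holding i64 or f64 arrays.
using PropertyValue =
    std::variant<int64_t, double, std::vector<int64_t>, std::vector<double>>;

// Hypothetical ordered list of named, typed values (loosely like
// NamedAttrList).
class PropertyList {
public:
  // Appends a new property; returns false if the name is already present.
  bool append(const std::string &name, PropertyValue value) {
    if (lookup(name))
      return false;
    entries.emplace_back(name, std::move(value));
    return true;
  }

  // Typed getters: std::nullopt if the name is missing or the stored value
  // has a different kind (the role dyn_cast plays in the patch).
  std::optional<int64_t> getInt(const std::string &name) const {
    if (const PropertyValue *v = lookup(name))
      if (const int64_t *i = std::get_if<int64_t>(v))
        return *i;
    return std::nullopt;
  }
  std::optional<double> getFloat(const std::string &name) const {
    if (const PropertyValue *v = lookup(name))
      if (const double *d = std::get_if<double>(v))
        return *d;
    return std::nullopt;
  }

private:
  const PropertyValue *lookup(const std::string &name) const {
    for (const auto &entry : entries)
      if (entry.first == name)
        return &entry.second;
    return nullptr;
  }

  std::vector<std::pair<std::string, PropertyValue>> entries;
};
```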

This does not yet address the review comment about keeping the system
description in the MLIRContext, as that is still under discussion.
---
 mlir/include/mlir/IR/MLIRContext.h      |   5 +-
 mlir/include/mlir/Support/SystemDesc.h  | 282 ++++++++++++------------
 mlir/lib/IR/MLIRContext.cpp             |   2 +
 mlir/lib/Support/SystemDesc.cpp         |  54 +++--
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp |  19 +-
 5 files changed, 189 insertions(+), 173 deletions(-)

diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 6007f19ed503a3..f94e223ba5b615 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -241,9 +241,12 @@ class MLIRContext {
   /// (attributes, operations, types, etc.).
   llvm::hash_code getRegistryHash();
 
-  /// Get context-specific system description
+  /// Get context-specific system descriptor
   SystemDesc &getSystemDesc();
 
+  /// Set context-specific system descriptor
+  void setSystemDesc(const SystemDesc& desc);
+
   //===--------------------------------------------------------------------===//
   // Action API
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Support/SystemDesc.h b/mlir/include/mlir/Support/SystemDesc.h
index 0784a5183d2061..926569eddd23f7 100644
--- a/mlir/include/mlir/Support/SystemDesc.h
+++ b/mlir/include/mlir/Support/SystemDesc.h
@@ -19,7 +19,11 @@
 #include <memory>
 #include <vector>
 
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/StringRef.h"
@@ -47,47 +51,8 @@ namespace mlir {
 /// Describes the individual device from the system description
 class DeviceDesc {
 public:
-  /// Some typedefs
   using DeviceID = uint32_t;
-  using DevicePropertyName = std::string;
-  struct DevicePropertyValue {
-    enum Tag { INT, DOUBLE, INT_VECTOR, DOUBLE_VECTOR } tag;
-    struct Data {
-      int64_t iValue;
-      double dValue;
-      std::vector<int64_t> ivValue;
-      std::vector<double> dvValue;
-
-      Data() : iValue(0), dValue(0.0), ivValue({0}), dvValue({0.0}) {}
-      ~Data() {}
-    } data;
-
-    DevicePropertyValue() = default;
-    DevicePropertyValue(const mlir::DeviceDesc::DevicePropertyValue &rhs) {
-      this->tag = rhs.tag;
-      if (this->tag == INT)
-        this->data.iValue = rhs.data.iValue;
-      else if (this->tag == DOUBLE)
-        this->data.dValue = rhs.data.dValue;
-      else if (this->tag == INT_VECTOR)
-        this->data.ivValue = rhs.data.ivValue;
-      else
-        this->data.dvValue = rhs.data.dvValue;
-    }
-    bool operator==(const mlir::DeviceDesc::DevicePropertyValue &rhs) const {
-      return tag == rhs.tag &&
-             ((tag == INT && data.iValue == rhs.data.iValue) ||
-              (tag == DOUBLE && data.dValue == rhs.data.dValue) ||
-              (tag == INT_VECTOR && data.ivValue == rhs.data.ivValue) ||
-              (tag == DOUBLE_VECTOR && data.dvValue == rhs.data.dvValue));
-    }
-    bool operator!=(const mlir::DeviceDesc::DevicePropertyValue &rhs) const {
-      return !(*this == rhs);
-    }
-  };
-  using DevicePropertiesMapTy =
-      std::map<DevicePropertyName, DevicePropertyValue>;
-
+  using DevicePropertiesMapTy = mlir::NamedAttrList;
   typedef enum { CPU, GPU, SPECIAL } DeviceType;
 
   /// Basic constructor
@@ -121,75 +86,78 @@ class DeviceDesc {
     description = desc;
     return *this;
   }
+
   /// Set property
-  DeviceDesc &setProperty(llvm::StringRef name, int64_t iv) {
-    DevicePropertyValue value;
-    value.tag = DevicePropertyValue::Tag::INT;
-    value.data.iValue = iv;
-    auto inserted =
-        deviceProperties.insert(std::make_pair(std::string(name), value));
-    if (!inserted.second && inserted.first->second != value) {
+  DeviceDesc &setProperty(MLIRContext *context, llvm::StringRef name, int64_t iv) {
+    std::optional<NamedAttribute> attr = deviceProperties.getNamed(name);
+    if (!attr.has_value()) {
+      IntegerType int64Ty = IntegerType::get(context, 64);
+      Attribute value = IntegerAttr::get(int64Ty, iv);
+      deviceProperties.append(name, value);
+    } else
       llvm::report_fatal_error("Duplicate device property name found:" + name);
-    }
     return *this;
   }
-  DeviceDesc &setProperty(llvm::StringRef name, double dv) {
-    DevicePropertyValue value;
-    value.tag = DevicePropertyValue::Tag::DOUBLE;
-    value.data.dValue = dv;
-    auto inserted =
-        deviceProperties.insert(std::make_pair(std::string(name), value));
-    if (!inserted.second && inserted.first->second != value) {
+
+  DeviceDesc &setProperty(MLIRContext *context, llvm::StringRef name, double dv) {
+    std::optional<NamedAttribute> attr = deviceProperties.getNamed(name);
+    if (!attr.has_value()) {
+      FloatType floatType = FloatType::getF64(context);
+      Attribute value = FloatAttr::get(floatType, dv);
+      deviceProperties.append(name, value);
+    } else
       llvm::report_fatal_error("Duplicate device property name found:" + name);
-    }
     return *this;
   }
-  DeviceDesc &setProperty(llvm::StringRef name,
+
+  DeviceDesc &setProperty(MLIRContext *context, llvm::StringRef name,
                           const std::vector<int64_t> &ivv) {
-    DevicePropertyValue value;
-    value.tag = DevicePropertyValue::Tag::INT_VECTOR;
-    value.data.ivValue = ivv;
-    auto inserted =
-        deviceProperties.insert(std::make_pair(std::string(name), value));
-    if (!inserted.second && inserted.first->second != value) {
+    std::optional<NamedAttribute> attr = deviceProperties.getNamed(name);
+    if (!attr.has_value()) {
+      IntegerType int64Ty = IntegerType::get(context, 64);
+      RankedTensorType shape = RankedTensorType::get({static_cast<int64_t>(ivv.size()), 1}, int64Ty);
+      DenseElementsAttr value = DenseElementsAttr::get(shape, llvm::ArrayRef(ivv));
+      deviceProperties.append(name, value);
+    } else
       llvm::report_fatal_error("Duplicate device property name found:" + name);
-    }
     return *this;
   }
-  DeviceDesc &setProperty(llvm::StringRef name,
+
+  DeviceDesc &setProperty(MLIRContext *context, llvm::StringRef name,
                           const std::vector<double> &idv) {
-    DevicePropertyValue value;
-    value.tag = DevicePropertyValue::Tag::DOUBLE_VECTOR;
-    value.data.dvValue = idv;
-    auto inserted =
-        deviceProperties.insert(std::make_pair(std::string(name), value));
-    if (!inserted.second && inserted.first->second != value) {
+    std::optional<NamedAttribute> attr = deviceProperties.getNamed(name);
+    if (!attr.has_value()) {
+      FloatType float64Ty = FloatType::getF64(context);
+      RankedTensorType shape = RankedTensorType::get({static_cast<int64_t>(idv.size()), 1}, float64Ty);
+      DenseElementsAttr value = DenseElementsAttr::get(shape, llvm::ArrayRef(idv));
+      deviceProperties.append(name, value);
+    } else
       llvm::report_fatal_error("Duplicate device property name found:" + name);
-    }
     return *this;
   }
+
   // We provide convenience interface to handle int/float value as string
-  DeviceDesc &setProperty(llvm::StringRef name, const std::string &json_value) {
+  DeviceDesc &setProperty(MLIRContext *context, llvm::StringRef name, const std::string &json_value) {
     if (json_value.length() > 0 && json_value[0] == '[') {
       // Parse as an array
       llvm::Expected<std::vector<int64_t>> ivv =
           llvm::json::parse<std::vector<int64_t>>(json_value);
       if (ivv) {
-        *this = this->setProperty(name, ivv.get());
+        *this = this->setProperty(context, name, ivv.get());
         return *this;
       }
 
       llvm::Expected<std::vector<double>> idv =
           llvm::json::parse<std::vector<double>>(json_value);
       if (idv) {
-        *this = this->setProperty(name, idv.get());
+        *this = this->setProperty(context, name, idv.get());
         return *this;
       }
     } else {
       // int64_t because llvm::json has int64_t support (not int)
       llvm::Expected<int64_t> iv = llvm::json::parse<int64_t>(json_value);
       if (iv) {
-        *this = this->setProperty(name, iv.get());
+        *this = this->setProperty(context, name, iv.get());
         return *this;
       }
 
@@ -197,7 +165,7 @@ class DeviceDesc {
       // double because llvm::json has double support (not float)
       llvm::Expected<double> dv = llvm::json::parse<double>(json_value);
       if (dv) {
-        *this = this->setProperty(name, dv.get());
+        *this = this->setProperty(context, name, dv.get());
         return *this;
       }
     }
@@ -217,41 +185,35 @@ class DeviceDesc {
     return deviceProperties;
   }
   /// Get property value: returns the value of the property with given name, if
-  /// it exists. Otherwise throws exception (TODO)
+  /// it exists. Otherwise returns std::nullopt.
   std::optional<int64_t> getPropertyValueAsInt(llvm::StringRef name) const {
     // check that property with the given name exists
-    auto iter = deviceProperties.find(std::string(name));
-    if (iter == deviceProperties.end()) {
-      return std::nullopt;
+    std::optional<NamedAttribute> attr = deviceProperties.getNamed(name);
+    if (attr) {
+      if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr->getValue()))
+        return intAttr.getInt();
     }
-    // TODO: we can do a tag check here.
-    return iter->second.data.iValue;
+    return std::nullopt;
   }
   std::optional<double> getPropertyValueAsFloat(llvm::StringRef name) const {
     // check that property with the given name exists
-    auto iter = deviceProperties.find(std::string(name));
-    if (iter == deviceProperties.end()) {
-      return std::nullopt;
+    std::optional<NamedAttribute> attr = deviceProperties.getNamed(name);
+    if (attr) {
+      if (FloatAttr floatAttr = dyn_cast<FloatAttr>(attr->getValue()))
+        return floatAttr.getValueAsDouble();
     }
-    // TODO: we can do a tag check here.
-    return iter->second.data.dValue;
+    return std::nullopt;
   }
 
   /// Special functions
   auto getAllDevicePropertyNames() const {
     return llvm::map_range(
-        deviceProperties,
-        [](const DevicePropertiesMapTy::value_type &item) -> llvm::StringRef {
-          return item.first;
+        deviceProperties.getAttrs(),
+        [](const NamedAttribute &named_attribute) -> llvm::StringRef {
+          return named_attribute.getName();
         });
   }
 
-  /// We use a list of key-value pairs to represent a system description in
-  /// JSON.
-  using DeviceDescJSONTy = std::map<std::string, std::string>;
-  static DeviceDesc
-  parseDeviceDescFromJSON(const DeviceDescJSONTy &device_desc);
-
   // -----------------------------------------------------------------------
   //          CPU specific methods
   // -----------------------------------------------------------------------
@@ -281,9 +243,9 @@ class DeviceDesc {
     }
     return std::nullopt;
   }
-  void setL1CacheSizeInBytes(int64_t value) {
+  void setL1CacheSizeInBytes(MLIRContext *context, int64_t value) {
     // Temporarily use int override until we support size_t
-    this->setProperty(DeviceDesc::getCPUL1CacheSizeInBytesKeyName(), value);
+    this->setProperty(context, DeviceDesc::getCPUL1CacheSizeInBytesKeyName(), value);
   }
   std::optional<int64_t> getConvAndMatMulBlockingFactor() const {
     if (std::optional<int64_t> v = this->getPropertyValueAsInt(
@@ -292,9 +254,9 @@ class DeviceDesc {
     }
     return std::nullopt;
   }
-  void setConvAndMatMulBlockingFactor(int64_t value) {
+  void setConvAndMatMulBlockingFactor(MLIRContext *context, int64_t value) {
     // Temporarily use int override until we support size_t
-    this->setProperty(DeviceDesc::getConvAndMatMulBlockingFactorKeyName(),
+    this->setProperty(context, DeviceDesc::getConvAndMatMulBlockingFactorKeyName(),
                       value);
   }
   std::optional<int64_t> getMatMulTileSizeInBytes() const {
@@ -304,9 +266,9 @@ class DeviceDesc {
     }
     return std::nullopt;
   }
-  void setMatMulTileSizeInBytes(int64_t value) {
+  void setMatMulTileSizeInBytes(MLIRContext *context, int64_t value) {
     // Temporarily use int override until we support size_t
-    this->setProperty(DeviceDesc::getMatMulTileSizeInBytesKeyName(), value);
+    this->setProperty(context, DeviceDesc::getMatMulTileSizeInBytesKeyName(), value);
   }
   std::optional<int64_t> getCanonicalizerMaxNumRewrites() const {
     if (std::optional<int64_t> v = this->getPropertyValueAsInt(
@@ -315,8 +277,8 @@ class DeviceDesc {
     }
     return std::nullopt;
   }
-  void setCanonicalizerMaxNumRewrites(int64_t value) {
-    this->setProperty(DeviceDesc::getCanonicalizerMaxNumRewritesKeyName(),
+  void setCanonicalizerMaxNumRewrites(MLIRContext *context, int64_t value) {
+    this->setProperty(context, DeviceDesc::getCanonicalizerMaxNumRewritesKeyName(),
                       value);
   }
   std::optional<int64_t> getCanonicalizerMaxIterations() const {
@@ -326,8 +288,8 @@ class DeviceDesc {
     }
     return std::nullopt;
   }
-  void setCanonicalizerMaxIterations(int64_t value) {
-    this->setProperty(DeviceDesc::getCanonicalizerMaxIterationsKeyName(),
+  void setCanonicalizerMaxIterations(MLIRContext *context, int64_t value) {
+    this->setProperty(context, DeviceDesc::getCanonicalizerMaxIterationsKeyName(),
                       value);
   }
   std::optional<int64_t> getMaxVectorWidth() const {
@@ -337,8 +299,8 @@ class DeviceDesc {
     }
     return std::nullopt;
   }
-  void setMaxVectorWidth(uint32_t value) {
-    this->setProperty(DeviceDesc::getMaxVectorWidthKeyName(),
+  void setMaxVectorWidth(MLIRContext *context, uint32_t value) {
+    this->setProperty(context, DeviceDesc::getMaxVectorWidthKeyName(),
                       static_cast<int64_t>(value));
   }
 
@@ -359,10 +321,12 @@ class DeviceDesc {
 class SystemDesc {
 public:
   SystemDesc() = default;
-
-  /// Read and parse system description from JSON file
-  LogicalResult readSystemDescFromJSONFile(llvm::StringRef filename);
-  void writeSystemDescToJSONFile(llvm::StringRef filename);
+  SystemDesc(const SystemDesc &desc) {
+    this->deviceDescs = desc.deviceDescs;
+  }
+  SystemDesc &operator=(const SystemDesc &rhs) {
+    this->deviceDescs = rhs.deviceDescs;
+    return *this;
+  }
 
   /// Insert a new device description
   SystemDesc &addDeviceDesc(const DeviceDesc &desc) {
@@ -373,6 +337,7 @@ class SystemDesc {
     }
     return *this;
   }
+
   /// Get a device description
   const DeviceDesc &getDeviceDesc(DeviceDesc::DeviceID deviceID) {
     auto iter = deviceDescs.find(deviceID);
@@ -391,10 +356,6 @@ class SystemDesc {
   static uint32_t getNumCPUDevices() { return 0; }
   static uint32_t getNumGPUDevices() { return 0; }
 
-private:
-  SystemDesc(const SystemDesc &) = delete;
-  void operator=(const SystemDesc &) = delete;
-
 private:
   /// Map to store all the device descriptions
   DeviceDescsMapTy deviceDescs;
@@ -416,22 +377,22 @@ class DefaultBaseDeviceDesc {
   /// -----------------------------------------------------------------------
   /// Set of common questions asked by various passes
   // Blocking factor and tile size are typically used by tile/block passes.
-  virtual void setConvAndMatMulBlockingFactor(){};
-  virtual void setMatMulTileSize(){};
+  virtual void setConvAndMatMulBlockingFactor(MLIRContext *context){};
+  virtual void setMatMulTileSize(MLIRContext *context){};
 
-  virtual void setCanonicalizerMaxIterations(){};
-  virtual void setCanonicalizerMaxNumRewrites(){};
+  virtual void setCanonicalizerMaxIterations(MLIRContext *context){};
+  virtual void setCanonicalizerMaxNumRewrites(MLIRContext *context){};
 
   /// -----------------------------------------------------------------------
   /// CPU-specific parameters of system description
   /// -----------------------------------------------------------------------
-  virtual void setL1CacheSizeInBytes(){};
+  virtual void setL1CacheSizeInBytes(MLIRContext *context){};
 
   /// -----------------------------------------------------------------------
   /// GPU-specific parameters of system description
   /// -----------------------------------------------------------------------
   // Used by Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp#L52
-  virtual void setMaxVectorWidth(){};
+  virtual void setMaxVectorWidth(MLIRContext *context){};
 };
 
 // Class that represent device description for a typical CPU device
@@ -439,14 +400,14 @@ class DefaultCPUDeviceDesc : public DefaultBaseDeviceDesc {
 public:
   // We use the default ID of 0 because we currently expect a single device,
   // not a heterogeneous setup.
-  DefaultCPUDeviceDesc()
+  DefaultCPUDeviceDesc(MLIRContext *context)
       : cpu_device_desc(DeviceDesc(/* id */ 0, DeviceDesc::CPU)) {
     // Register all system properties
-    this->setL1CacheSizeInBytes();
-    this->setConvAndMatMulBlockingFactor();
-    this->setMatMulTileSize();
-    this->setCanonicalizerMaxNumRewrites();
-    this->setCanonicalizerMaxIterations();
+    this->setL1CacheSizeInBytes(context);
+    this->setConvAndMatMulBlockingFactor(context);
+    this->setMatMulTileSize(context);
+    this->setCanonicalizerMaxNumRewrites(context);
+    this->setCanonicalizerMaxIterations(context);
   }
 
   ~DefaultCPUDeviceDesc() {}
@@ -459,22 +420,22 @@ class DefaultCPUDeviceDesc : public DefaultBaseDeviceDesc {
   //                    CPU-specific properties
   // -------------------------------------------------------------------------
 
-  void setL1CacheSizeInBytes() override {
-    cpu_device_desc.setL1CacheSizeInBytes(8192);
+  void setL1CacheSizeInBytes(MLIRContext *context) override {
+    cpu_device_desc.setL1CacheSizeInBytes(context, 8192);
   }
-  void setConvAndMatMulBlockingFactor() override {
-    cpu_device_desc.setConvAndMatMulBlockingFactor(32);
+  void setConvAndMatMulBlockingFactor(MLIRContext *context) override {
+    cpu_device_desc.setConvAndMatMulBlockingFactor(context, 32);
   }
-  void setMatMulTileSize() override {
-    cpu_device_desc.setMatMulTileSizeInBytes(32);
+  void setMatMulTileSize(MLIRContext *context) override {
+    cpu_device_desc.setMatMulTileSizeInBytes(context, 32);
   }
-  void setCanonicalizerMaxNumRewrites() override {
+  void setCanonicalizerMaxNumRewrites(MLIRContext *context) override {
     // taken from include/mlir/Transforms/Passes.td
-    cpu_device_desc.setCanonicalizerMaxNumRewrites(-1);
+    cpu_device_desc.setCanonicalizerMaxNumRewrites(context, -1);
   }
-  void setCanonicalizerMaxIterations() override {
+  void setCanonicalizerMaxIterations(MLIRContext *context) override {
     // taken from include/mlir/Transforms/Passes.td
-    cpu_device_desc.setCanonicalizerMaxIterations(10);
+    cpu_device_desc.setCanonicalizerMaxIterations(context, 10);
   }
 
 private:
@@ -485,10 +446,10 @@ class DefaultGPUDeviceDesc : public DefaultBaseDeviceDesc {
 public:
   // We use the default ID of 1 because the default CPU device uses ID 0.
   // Heterogeneous setups are not supported yet.
-  DefaultGPUDeviceDesc()
+  DefaultGPUDeviceDesc(MLIRContext *context)
       : gpu_device_desc(DeviceDesc(/* id */ 1, DeviceDesc::GPU)) {
     // So far, the default GPU device only sets MaxVectorWidth.
-    this->setMaxVectorWidth();
+    this->setMaxVectorWidth(context);
   }
 
   ~DefaultGPUDeviceDesc() {}
@@ -501,14 +462,49 @@ class DefaultGPUDeviceDesc : public DefaultBaseDeviceDesc {
   //                    GPU-specific properties
   // -------------------------------------------------------------------------
 
-  void setMaxVectorWidth() override {
+  void setMaxVectorWidth(MLIRContext *context) override {
     // Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp#L52
-    gpu_device_desc.setMaxVectorWidth(128);
+    gpu_device_desc.setMaxVectorWidth(context, 128);
   }
 
 private:
   DeviceDesc gpu_device_desc;
 };
 
+// ---------------------------------------------------------------------------
+//                     Config file readers
+// ---------------------------------------------------------------------------
+namespace impl {
+  class SystemDescJSONConfigParser {
+  public:
+    /// Build SystemDesc by parsing input config file in JSON format.
+    /// Returns a valid SystemDesc if parsing is successful; otherwise
+    /// returns std::nullopt.
+    static std::optional<SystemDesc> buildSystemDescFromConfigFile(
+      MLIRContext *context, llvm::StringRef filename);
+
+  private:
+    /// We represent a DeviceDesc in JSON as a set of key-value string pairs.
+    using DeviceDescJSONTy = std::map<std::string, std::string>;
+
+    /// A utility function to parse a device description entry in JSON
+    /// format. Returns a valid DeviceDesc if parsing is successful; otherwise
+    /// returns std::nullopt.
+    static std::optional<DeviceDesc> buildDeviceDescFromConfigFile(MLIRContext *context,
+      const DeviceDescJSONTy &device_desc_in_json);
+  };
+} // namespace impl
+
+class SystemDescConfigFileParser {
+public:
+  /// Build SystemDesc by parsing input config file. Returns valid SystemDesc
+  /// if parsing is successful; otherwise returns std::nullopt.
+  static std::optional<SystemDesc> buildSystemDescFromConfigFile(
+    MLIRContext *context, llvm::StringRef filename) {
+      // Once more formats are supported, the format can become an argument.
+      return impl::SystemDescJSONConfigParser::buildSystemDescFromConfigFile(context, filename);
+  }
+};
+
 } // namespace mlir
 #endif // MLIR_SUPPORT_SYSTEMDESC_H
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2eb8628ace08a2..f656ab69947bde 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -708,6 +708,8 @@ bool MLIRContext::isOperationRegistered(StringRef name) {
 
 SystemDesc &MLIRContext::getSystemDesc() { return impl->system_desc; }
 
+void MLIRContext::setSystemDesc(const SystemDesc& desc) { impl->system_desc = desc; }
+
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
   assert(impl.multiThreadedExecutionContext == 0 &&
diff --git a/mlir/lib/Support/SystemDesc.cpp b/mlir/lib/Support/SystemDesc.cpp
index a482ae7133cfe8..356ec13287ee31 100644
--- a/mlir/lib/Support/SystemDesc.cpp
+++ b/mlir/lib/Support/SystemDesc.cpp
@@ -20,20 +20,22 @@
 using namespace llvm;
 using namespace mlir;
 
-// ManagedStatic<SystemDesc> systemDesc;
-
-DeviceDesc DeviceDesc::parseDeviceDescFromJSON(
-    const DeviceDescJSONTy &device_desc_in_json) {
+std::optional<DeviceDesc> impl::SystemDescJSONConfigParser::buildDeviceDescFromConfigFile(
+    MLIRContext *context, const DeviceDescJSONTy &device_desc_in_json) {
   // ID and Type are mandatory fields.
   auto iter = device_desc_in_json.find("ID");
-  if (iter == device_desc_in_json.end())
-    llvm::report_fatal_error("\"ID\" key missing in Device Description");
-  DeviceID id = DeviceDesc::strToDeviceID(iter->second);
+  if (iter == device_desc_in_json.end()) {
+    llvm::errs() << "\"ID\" key missing in Device Description" << "\n";
+    return std::nullopt;
+  }
+  DeviceDesc::DeviceID id = DeviceDesc::strToDeviceID(iter->second);
 
   iter = device_desc_in_json.find("Type");
-  if (iter == device_desc_in_json.end())
-    llvm::report_fatal_error("\"Type\" key missing in Device Description");
-  DeviceType type = DeviceDesc::strToDeviceType(iter->second);
+  if (iter == device_desc_in_json.end()) {
+    llvm::errs() << "\"Type\" key missing in Device Description" << "\n";
+    return std::nullopt;
+  }
+  DeviceDesc::DeviceType type = DeviceDesc::strToDeviceType(iter->second);
 
   // Now process optional fields: description and properties
   DeviceDesc device_desc(id, type);
@@ -43,37 +45,47 @@ DeviceDesc DeviceDesc::parseDeviceDescFromJSON(
       if (property.first == "Description")
         device_desc.setDescription(property.second);
       else
-        device_desc.setProperty(property.first, property.second);
+        device_desc.setProperty(context, property.first, property.second);
     }
   }
-  return device_desc;
+  return std::optional<DeviceDesc>(device_desc);
 }
 
-LogicalResult SystemDesc::readSystemDescFromJSONFile(llvm::StringRef filename) {
+std::optional<SystemDesc> impl::SystemDescJSONConfigParser::buildSystemDescFromConfigFile(
+    MLIRContext *context, llvm::StringRef filename) {
   std::string errorMessage;
   std::unique_ptr<llvm::MemoryBuffer> file =
       openInputFile(filename, &errorMessage);
   if (!file) {
     llvm::errs() << errorMessage << "\n";
-    return failure();
+    return std::nullopt;
   }
 
   // Code to parse here
   auto parsed = llvm::json::parse(file.get()->getBuffer());
   if (!parsed) {
-    report_fatal_error(parsed.takeError());
+    llvm::errs() << llvm::toString(parsed.takeError()) << "\n";
+    return std::nullopt;
   }
 
   json::Path::Root NullRoot;
-  using SystemDescJSONTy = std::vector<mlir::DeviceDesc::DeviceDescJSONTy>;
+  // System description is a list of Device descriptions.
+  using SystemDescJSONTy = std::vector<DeviceDescJSONTy>;
   SystemDescJSONTy system_desc_in_json;
   if (!json::fromJSON(*parsed, system_desc_in_json, NullRoot)) {
-    report_fatal_error("Invalid System Description in JSON");
+    llvm::errs() << "Invalid System Description in JSON" << "\n";
+    return std::nullopt;
   }
+
+  SystemDesc system_desc;
   for (auto device_desc_in_json : system_desc_in_json) {
-    auto device_desc = DeviceDesc::parseDeviceDescFromJSON(device_desc_in_json);
-    this->addDeviceDesc(device_desc);
+    std::optional<DeviceDesc> device_desc =
+        buildDeviceDescFromConfigFile(context, device_desc_in_json);
+    if (device_desc)
+      system_desc.addDeviceDesc(*device_desc);
+    else
+      return std::nullopt;
   }
 
-  return success();
-}
+  return std::optional<SystemDesc>(system_desc);
+}
\ No newline at end of file
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 2ece33c4ea56c2..fd2db58e5be0fe 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -388,18 +388,21 @@ performActions(raw_ostream &os,
 
   context->enableMultithreading(wasThreadingEnabled);
 
+  bool setDefaultSystemDesc = true;
   if (!config.getSystemDescriptionFileName().empty()) {
-    // If there is an error in file IO or parse error, we should report
-    // the error and fallback to default values.
-    if (failed(context->getSystemDesc().readSystemDescFromJSONFile(
-            config.getSystemDescriptionFileName()))) {
-      return failure();
+    std::optional<SystemDesc> desc =
+      SystemDescConfigFileParser::buildSystemDescFromConfigFile(context,
+                                                                config.getSystemDescriptionFileName());
+    if (desc) {
+      context->setSystemDesc(*desc);
+      setDefaultSystemDesc = false;
     }
-  } else {
-    DefaultCPUDeviceDesc default_cpu_device_desc;
+  }
+  if (setDefaultSystemDesc) {
+    DefaultCPUDeviceDesc default_cpu_device_desc(context);
     default_cpu_device_desc.registerDeviceDesc(context);
 
-    DefaultGPUDeviceDesc default_gpu_device_desc;
+    DefaultGPUDeviceDesc default_gpu_device_desc(context);
     default_gpu_device_desc.registerDeviceDesc(context);
   }
 

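Patch 2 replaces the `report_fatal_error` calls in the JSON reader with a diagnostic on `llvm::errs()` plus `std::nullopt`. When that happens, mlir-opt now registers the default CPU and GPU descriptions instead of returning failure. The standalone sketch below models that flow under simplifying assumptions. `Device`, `DeviceEntry`, `buildDevice`, `defaultSystem`, `buildSystemOrDefaults`, and the property key names are all hypothetical. File reading and JSON decoding are left out, so each entry arrives as an already-decoded map of strings, as with `DeviceDescJSONTy`. Scalar values are tried as an integer first and then as a double, following the string overload of `setProperty`; arrays are omitted. Treating an unparsable value as an error is this sketch's own choice, since the hunks do not show what the patch does there. The defaults 8192 (L1 cache size) and 128 (max vector width) are the values set by `DefaultCPUDeviceDesc` and `DefaultGPUDeviceDesc`.

```cpp
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for DeviceDesc.
struct Device {
  uint32_t id = 0;
  std::string type;
  std::map<std::string, int64_t> intProps;
  std::map<std::string, double> floatProps;
};

// One device entry from the JSON file: keys and values are all strings.
using DeviceEntry = std::map<std::string, std::string>;

// Accepts the string only if all of it parses as an integer.
static std::optional<int64_t> parseInt(const std::string &s) {
  char *end = nullptr;
  errno = 0;
  long long v = std::strtoll(s.c_str(), &end, 10);
  if (s.empty() || errno != 0 || *end != '\0')
    return std::nullopt;
  return static_cast<int64_t>(v);
}

// Accepts the string only if all of it parses as a double.
static std::optional<double> parseDouble(const std::string &s) {
  char *end = nullptr;
  errno = 0;
  double v = std::strtod(s.c_str(), &end);
  if (s.empty() || errno != 0 || *end != '\0')
    return std::nullopt;
  return v;
}

// Reports problems on stderr and returns std::nullopt instead of aborting.
std::optional<Device> buildDevice(const DeviceEntry &entry) {
  auto id = entry.find("ID");
  auto type = entry.find("Type");
  if (id == entry.end() || type == entry.end()) {
    std::cerr << "\"ID\" or \"Type\" key missing in Device Description\n";
    return std::nullopt;
  }
  std::optional<int64_t> parsedID = parseInt(id->second);
  if (!parsedID || *parsedID < 0 || *parsedID > UINT32_MAX) {
    std::cerr << "Invalid device ID: " << id->second << "\n";
    return std::nullopt;
  }
  Device dev;
  dev.id = static_cast<uint32_t>(*parsedID);
  dev.type = type->second;
  for (const auto &[key, value] : entry) {
    if (key == "ID" || key == "Type" || key == "Description")
      continue;
    // Try an integer first, then a double, like the string setProperty().
    if (std::optional<int64_t> i = parseInt(value)) {
      dev.intProps[key] = *i;
    } else if (std::optional<double> d = parseDouble(value)) {
      dev.floatProps[key] = *d;
    } else {
      std::cerr << "Unsupported value for property \"" << key << "\"\n";
      return std::nullopt;
    }
  }
  return dev;
}

// Defaults mirroring a subset of DefaultCPUDeviceDesc / DefaultGPUDeviceDesc.
std::vector<Device> defaultSystem() {
  return {Device{0, "CPU", {{"L1_cache_size_in_bytes", 8192}}, {}},
          Device{1, "GPU", {{"max_vector_width", 128}}, {}}};
}

// Uses the parsed devices only if every entry is valid; otherwise falls
// back to the defaults, as mlir-opt does after this patch.
std::vector<Device>
buildSystemOrDefaults(const std::vector<DeviceEntry> &entries) {
  std::vector<Device> system;
  for (const DeviceEntry &entry : entries) {
    std::optional<Device> dev = buildDevice(entry);
    if (!dev)
      return defaultSystem();
    system.push_back(*dev);
  }
  return system;
}
```

As in the patch, a single malformed entry discards the whole description. One failed `buildDeviceDescFromConfigFile` makes the parser return `std::nullopt`, and mlir-opt then registers the defaults. A diagnostic is still printed first, so the fallback is not silent.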

