[Mlir-commits] [mlir] [mlir][gpu] Add metadata attributes for storing kernel metadata in GPU objects (PR #95292)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Aug 27 02:28:09 PDT 2024


================
@@ -2165,6 +2166,113 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GPU KernelAttr
+//===----------------------------------------------------------------------===//
+
+/// Builds a `KernelAttr` from the function-like op `kernel`: the op's name
+/// attribute, function type and argument attributes are forwarded, together
+/// with `metadata`, to the field-wise `get` overload. `kernel` must be
+/// non-null (checked with an assertion).
+KernelAttr KernelAttr::get(FunctionOpInterface kernel,
+                           DictionaryAttr metadata) {
+  assert(kernel && "invalid kernel");
+  StringAttr kernelName = kernel.getNameAttr();
+  Type kernelType = kernel.getFunctionType();
+  ArrayAttr kernelArgAttrs = kernel.getAllArgAttrs();
+  return get(kernelName, kernelType, kernelArgAttrs, metadata);
+}
+
+/// Checked counterpart of `get(FunctionOpInterface, DictionaryAttr)`: forwards
+/// `emitError`, the name attribute, function type and argument attributes of
+/// `kernel`, and `metadata` to the field-wise `getChecked` overload. `kernel`
+/// must be non-null (checked with an assertion).
+KernelAttr KernelAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  FunctionOpInterface kernel,
+                                  DictionaryAttr metadata) {
+  assert(kernel && "invalid kernel");
+  StringAttr kernelName = kernel.getNameAttr();
+  Type kernelType = kernel.getFunctionType();
+  ArrayAttr kernelArgAttrs = kernel.getAllArgAttrs();
+  return getChecked(emitError, kernelName, kernelType, kernelArgAttrs,
+                    metadata);
+}
+
+/// Returns a `KernelAttr` with the same name, function type and argument
+/// attributes as this one, whose metadata is the current metadata (if any)
+/// followed by `attrs`. When `attrs` is empty, `*this` is returned unchanged.
+KernelAttr KernelAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
+  // Nothing to append: hand back the attribute as-is.
+  if (attrs.empty())
+    return *this;
+  // Start from the existing metadata entries, when there are any.
+  NamedAttrList merged;
+  DictionaryAttr existing = getMetadata();
+  if (existing)
+    merged.append(existing);
+  merged.append(attrs);
+  DictionaryAttr newMetadata = merged.getDictionary(getContext());
+  return KernelAttr::get(getName(), getFunctionType(), getArgAttrs(),
+                         newMetadata);
+}
+
+/// Verifies the `KernelAttr` invariants checked here: the kernel name must not
+/// be empty and, when `argAttrs` is present, each of its elements must be a
+/// `DictionaryAttr`. `functionType` and `metadata` are not inspected.
+LogicalResult KernelAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 StringAttr name, Type functionType,
+                                 ArrayAttr argAttrs, DictionaryAttr metadata) {
+  if (name.empty())
+    return emitError() << "the kernel name can't be empty";
+  // A missing argument-attribute array needs no further checking.
+  if (!argAttrs)
+    return success();
+  // Reject the array on the first element that is not a dictionary.
+  for (Attribute element : argAttrs) {
+    if (!llvm::isa<DictionaryAttr>(element))
+      return emitError()
+             << "all attributes in the array must be a dictionary attribute";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GPU KernelTableAttr
+//===----------------------------------------------------------------------===//
+
+/// Builds a `KernelTableAttr` from `kernels`. If `isSorted` is true the caller
+/// asserts the array is already sorted (checked only in assertion-enabled
+/// builds) and it is used directly; otherwise the array is used directly when
+/// it happens to be sorted, and a sorted copy is used when it is not.
+KernelTableAttr KernelTableAttr::get(MLIRContext *context,
+                                     ArrayRef<KernelAttr> kernels,
+                                     bool isSorted) {
+  // `is_sorted` runs at most once per call: the assertion only evaluates it
+  // when `isSorted` is true, and the check below only when `isSorted` is false.
+  assert((!isSorted || llvm::is_sorted(kernels)) &&
+         "expected a sorted kernel array");
+  if (!isSorted && !llvm::is_sorted(kernels)) {
+    // Unsorted input: sort a private copy and build the attribute from it.
+    SmallVector<KernelAttr> sortedKernels(kernels);
+    llvm::array_pod_sort(sortedKernels.begin(), sortedKernels.end());
+    return Base::get(context, sortedKernels);
+  }
+  // Already sorted: no copy is needed.
+  return Base::get(context, kernels);
+}
+
+/// Checked counterpart of `KernelTableAttr::get`: identical sorting behaviour,
+/// but construction goes through `Base::getChecked` with `emitError`.
+KernelTableAttr
+KernelTableAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                            MLIRContext *context, ArrayRef<KernelAttr> kernels,
+                            bool isSorted) {
+  // `is_sorted` runs at most once per call: the assertion only evaluates it
+  // when `isSorted` is true, and the check below only when `isSorted` is false.
+  assert((!isSorted || llvm::is_sorted(kernels)) &&
+         "expected a sorted kernel array");
+  if (!isSorted && !llvm::is_sorted(kernels)) {
+    // Unsorted input: sort a private copy and build the attribute from it.
+    SmallVector<KernelAttr> sortedKernels(kernels);
+    llvm::array_pod_sort(sortedKernels.begin(), sortedKernels.end());
+    return Base::getChecked(emitError, context, sortedKernels);
+  }
+  // Already sorted: no copy is needed.
+  return Base::getChecked(emitError, context, kernels);
+}
+
+/// Verifies that no two neighbouring entries of `kernels` share a name. Only
+/// adjacent entries are compared, so duplicates are detected when equal names
+/// sit next to each other in the array.
+LogicalResult
+KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                        ArrayRef<KernelAttr> kernels) {
+  // Compare each entry with its predecessor; arrays with fewer than two
+  // entries never enter the loop and are trivially valid.
+  for (size_t idx = 1, count = kernels.size(); idx < count; ++idx) {
+    if (kernels[idx - 1].getName() == kernels[idx].getName())
+      return emitError() << "expected all kernels to be uniquely named";
+  }
+  return success();
+}
+
+KernelAttr KernelTableAttr::lookup(StringRef key) const {
+  std::pair<ArrayRef<KernelAttr>::iterator, bool> it =
+      impl::findAttrSorted(begin(), end(), key);
+  return it.second ? *it.first : KernelAttr();
----------------
ftynse wrote:

Also nit: why doesn't the `impl` function just return the end iterator to indicate not found?

https://github.com/llvm/llvm-project/pull/95292


More information about the Mlir-commits mailing list