[Mlir-commits] [mlir] [mlir][gpu] Add metadata attributes for storing kernel metadata in GPU objects (PR #95292)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Tue Aug 27 02:28:09 PDT 2024
================
@@ -2165,6 +2166,113 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// GPU KernelAttr
+//===----------------------------------------------------------------------===//
+
+/// Builds a KernelAttr describing `kernel`. The symbol name, function type and
+/// argument attributes are read from the op; `metadata` is attached as given.
+/// `kernel` must be non-null (asserted).
+KernelAttr KernelAttr::get(FunctionOpInterface kernel,
+                           DictionaryAttr metadata) {
+  assert(kernel && "invalid kernel");
+  StringAttr kernelName = kernel.getNameAttr();
+  Type kernelType = kernel.getFunctionType();
+  ArrayAttr kernelArgAttrs = kernel.getAllArgAttrs();
+  return get(kernelName, kernelType, kernelArgAttrs, metadata);
+}
+
+/// Checked variant of `KernelAttr::get(FunctionOpInterface, DictionaryAttr)`.
+/// It reads the same fields from `kernel` and forwards them, together with
+/// `emitError`, to the generated `getChecked` builder. `kernel` must be
+/// non-null (asserted).
+KernelAttr KernelAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                                  FunctionOpInterface kernel,
+                                  DictionaryAttr metadata) {
+  assert(kernel && "invalid kernel");
+  StringAttr kernelName = kernel.getNameAttr();
+  Type kernelType = kernel.getFunctionType();
+  ArrayAttr kernelArgAttrs = kernel.getAllArgAttrs();
+  return getChecked(emitError, kernelName, kernelType, kernelArgAttrs,
+                    metadata);
+}
+
+/// Returns a KernelAttr identical to this one, except that its metadata
+/// dictionary also contains the entries in `attrs`. If `attrs` is empty,
+/// `*this` is returned unchanged.
+///
+/// An entry whose name is already present in the metadata, or that appears
+/// earlier in `attrs`, replaces the previous value. This avoids duplicate
+/// names, which a DictionaryAttr does not allow.
+KernelAttr KernelAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
+  if (attrs.empty())
+    return *this;
+  NamedAttrList attrList;
+  if (DictionaryAttr dict = getMetadata())
+    attrList.append(dict);
+  // `set` (rather than `append`) overwrites an existing key in place. Plain
+  // appending would produce duplicate names, which DictionaryAttr creation
+  // rejects.
+  for (NamedAttribute attr : attrs)
+    attrList.set(attr.getName(), attr.getValue());
+  return KernelAttr::get(getName(), getFunctionType(), getArgAttrs(),
+                         attrList.getDictionary(getContext()));
+}
+
+/// Verifies a KernelAttr's parameters:
+///  - `name` must not be empty;
+///  - if `argAttrs` is present, every element must be a DictionaryAttr.
+/// `functionType` and `metadata` are accepted but not inspected here.
+LogicalResult KernelAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                 StringAttr name, Type functionType,
+                                 ArrayAttr argAttrs, DictionaryAttr metadata) {
+  if (name.empty())
+    return emitError() << "the kernel name can't be empty";
+  auto isNotDictionary = [](Attribute attr) {
+    return !llvm::isa<DictionaryAttr>(attr);
+  };
+  // A missing (null) argAttrs array is valid.
+  const bool hasNonDictArgAttr =
+      argAttrs && llvm::any_of(argAttrs, isNotDictionary);
+  if (!hasNonDictArgAttr)
+    return success();
+  return emitError()
+         << "all attributes in the array must be a dictionary attribute";
+}
+
+//===----------------------------------------------------------------------===//
+// GPU KernelTableAttr
+//===----------------------------------------------------------------------===//
+
+/// Builds a KernelTableAttr whose kernel array is stored in sorted order.
+/// The ordering comes from KernelAttr's comparison operator, which is defined
+/// elsewhere (presumably by name). If `isSorted` is true, the caller promises
+/// `kernels` is already sorted; this is asserted in debug builds. Otherwise the
+/// array is checked and, if needed, a sorted copy is used.
+KernelTableAttr KernelTableAttr::get(MLIRContext *context,
+                                     ArrayRef<KernelAttr> kernels,
+                                     bool isSorted) {
+  // `llvm::is_sorted` runs at most once per call, with or without assertions:
+  // the assert short-circuits when `isSorted` is false, and the check below
+  // short-circuits when it is true.
+  assert((!isSorted || llvm::is_sorted(kernels)) &&
+         "expected a sorted kernel array");
+  const bool alreadySorted = isSorted || llvm::is_sorted(kernels);
+  if (!alreadySorted) {
+    // Sort a local copy; the caller's array is left untouched.
+    SmallVector<KernelAttr> sortedKernels(kernels);
+    llvm::array_pod_sort(sortedKernels.begin(), sortedKernels.end());
+    return Base::get(context, sortedKernels);
+  }
+  return Base::get(context, kernels);
+}
+
+/// Checked counterpart of `KernelTableAttr::get`. The kernel array is
+/// normalized the same way (a sorted copy is made unless the input is already
+/// sorted) and then built via `Base::getChecked`, which receives `emitError`.
+/// If `isSorted` is true, the caller promises `kernels` is already sorted;
+/// this is asserted in debug builds.
+KernelTableAttr
+KernelTableAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
+                            MLIRContext *context, ArrayRef<KernelAttr> kernels,
+                            bool isSorted) {
+  // `llvm::is_sorted` runs at most once per call, with or without assertions.
+  assert((!isSorted || llvm::is_sorted(kernels)) &&
+         "expected a sorted kernel array");
+  const bool alreadySorted = isSorted || llvm::is_sorted(kernels);
+  if (!alreadySorted) {
+    // Sort a local copy; the caller's array is left untouched.
+    SmallVector<KernelAttr> sortedKernels(kernels);
+    llvm::array_pod_sort(sortedKernels.begin(), sortedKernels.end());
+    return Base::getChecked(emitError, context, sortedKernels);
+  }
+  return Base::getChecked(emitError, context, kernels);
+}
+
+/// Verifies that no two kernels in the table share a name.
+/// Only neighbouring entries are compared, so this relies on equally named
+/// kernels being adjacent. The `get`/`getChecked` builders in this file sort
+/// the array before constructing the attribute.
+LogicalResult
+KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                        ArrayRef<KernelAttr> kernels) {
+  // A table with zero or one kernel cannot contain a duplicate name.
+  if (kernels.size() >= 2) {
+    auto haveSameName = [](KernelAttr lhs, KernelAttr rhs) {
+      return lhs.getName() == rhs.getName();
+    };
+    auto duplicate =
+        std::adjacent_find(kernels.begin(), kernels.end(), haveSameName);
+    if (duplicate != kernels.end())
+      return emitError() << "expected all kernels to be uniquely named";
+  }
+  return success();
+}
+
+KernelAttr KernelTableAttr::lookup(StringRef key) const {
+ std::pair<ArrayRef<KernelAttr>::iterator, bool> it =
+ impl::findAttrSorted(begin(), end(), key);
+ return it.second ? *it.first : KernelAttr();
----------------
ftynse wrote:
Also a nit: why doesn't the `impl` function just return the end iterator to indicate "not found"?
https://github.com/llvm/llvm-project/pull/95292
More information about the Mlir-commits
mailing list