[Mlir-commits] [mlir] [mlir][sparse] external entry method wrapper for sparse tensors (PR #80326)
Peiming Liu
llvmlistbot at llvm.org
Thu Feb 1 11:31:45 PST 2024
================
@@ -0,0 +1,236 @@
+//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils/CodegenUtils.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+// Convert type range to new types range, with sparse tensors externalized.
+void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
+ SmallVectorImpl<Type> *extraTypes = nullptr) {
+ for (auto type : types) {
+ if (auto rtp = dyn_cast<RankedTensorType>(type)) {
+ const SparseTensorType stt(rtp);
+ if (stt.hasEncoding()) {
+ auto shape = {ShapedType::kDynamic};
+ // Convert the external representation of the values array.
+ auto vtp = RankedTensorType::get(shape, stt.getElementType());
+ convTypes.push_back(vtp);
+ if (extraTypes)
+ extraTypes->push_back(vtp);
+ // Convert the external representations of the pos/crd arrays.
+ for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
+ const auto lt = stt.getLvlType(lvl);
+ if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
+ auto ptp = RankedTensorType::get(shape, stt.getPosType());
+ auto ctp = RankedTensorType::get(shape, stt.getCrdType());
+ convTypes.push_back(ptp);
+ convTypes.push_back(ctp);
+ if (extraTypes) {
+ extraTypes->push_back(ptp);
+ extraTypes->push_back(ctp);
+ }
+ } else {
+ assert(isDenseLT(lt)); // TODO: handle other cases
+ }
+ }
+ continue;
+ }
+ }
+ // All other data passes through unmodified.
+ convTypes.push_back(type);
----------------
PeimingLiu wrote:
Nit: maybe put this check at the beginning instead, i.e., `if (!encoding) { push_back; continue; }`? That early-continue form might make the loop easier to follow.
https://github.com/llvm/llvm-project/pull/80326
More information about the Mlir-commits
mailing list