[Mlir-commits] [mlir] [mlir][sparse] external entry method wrapper for sparse tensors (PR #80326)

Peiming Liu llvmlistbot at llvm.org
Thu Feb 1 11:31:45 PST 2024


================
@@ -0,0 +1,236 @@
+//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils/CodegenUtils.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+// Convert a type range into a new type range, with sparse tensors externalized.
+void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
+               SmallVectorImpl<Type> *extraTypes = nullptr) {
+  for (auto type : types) {
+    if (auto rtp = dyn_cast<RankedTensorType>(type)) {
+      const SparseTensorType stt(rtp);
+      if (stt.hasEncoding()) {
+        auto shape = {ShapedType::kDynamic};
+        // Convert the external representation of the values array.
+        auto vtp = RankedTensorType::get(shape, stt.getElementType());
+        convTypes.push_back(vtp);
+        if (extraTypes)
+          extraTypes->push_back(vtp);
+        // Convert the external representations of the pos/crd arrays.
+        for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
+          const auto lt = stt.getLvlType(lvl);
+          if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
+            auto ptp = RankedTensorType::get(shape, stt.getPosType());
+            auto ctp = RankedTensorType::get(shape, stt.getCrdType());
+            convTypes.push_back(ptp);
+            convTypes.push_back(ctp);
+            if (extraTypes) {
+              extraTypes->push_back(ptp);
+              extraTypes->push_back(ctp);
+            }
+          } else {
+            assert(isDenseLT(lt)); // TODO: handle other cases
+          }
+        }
+        continue;
+      }
+    }
+    // All other data passes through unmodified.
+    convTypes.push_back(type);
----------------
PeimingLiu wrote:

Nit: maybe put this at the beginning, i.e., `if (!encoding) { push_back; continue; }`? Handling the pass-through case first with an early `continue` might make the loop easier to follow.

https://github.com/llvm/llvm-project/pull/80326


More information about the Mlir-commits mailing list