[Mlir-commits] [mlir] ae7942e - [mlir][sparse] adding `SparseTensorType::get{Pointer,Index}Type` methods

wren romano llvmlistbot at llvm.org
Wed Feb 15 14:38:02 PST 2023


Author: wren romano
Date: 2023-02-15T14:37:55-08:00
New Revision: ae7942e2960e73bd4e568b8b15d1ace35303ae10

URL: https://github.com/llvm/llvm-project/commit/ae7942e2960e73bd4e568b8b15d1ace35303ae10
DIFF: https://github.com/llvm/llvm-project/commit/ae7942e2960e73bd4e568b8b15d1ace35303ae10.diff

LOG: [mlir][sparse] adding `SparseTensorType::get{Pointer,Index}Type` methods

Depends On D143800

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D143946

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index ad5501cfa24ca..a52a0dadd42b9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -144,6 +144,10 @@ DEPRECATED Level toStoredDim(RankedTensorType type, Dimension d);
 
 #undef DEPRECATED
 
+namespace detail {
+Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth);
+} // namespace detail
+
 } // namespace sparse_tensor
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4eeaa39e84236..c2adc2694e3ef 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -234,6 +234,16 @@ class SparseTensorType {
     return enc ? enc.getPointerBitWidth() : 0;
   }
 
+  /// Returns the index-overhead MLIR type, defaulting to `IndexType`.
+  Type getIndexType() const {
+    return detail::getIntegerOrIndexType(getContext(), getIndexBitWidth());
+  }
+
+  /// Returns the pointer-overhead MLIR type, defaulting to `IndexType`.
+  Type getPointerType() const {
+    return detail::getIntegerOrIndexType(getContext(), getPointerBitWidth());
+  }
+
 private:
   // These two must be const, to ensure coherence of the memoized fields.
   const RankedTensorType rtp;

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 32d998f7388c9..8ff474c4505d5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -114,18 +114,19 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
          << "expect positive value or ? for slice offset/size/stride";
 }
 
-static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth) {
+Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,
+                                                        unsigned bitwidth) {
   if (bitwidth)
     return IntegerType::get(ctx, bitwidth);
   return IndexType::get(ctx);
 }
 
 Type SparseTensorEncodingAttr::getPointerType() const {
-  return getIntegerOrIndexType(getContext(), getPointerBitWidth());
+  return detail::getIntegerOrIndexType(getContext(), getPointerBitWidth());
 }
 
 Type SparseTensorEncodingAttr::getIndexType() const {
-  return getIntegerOrIndexType(getContext(), getIndexBitWidth());
+  return detail::getIntegerOrIndexType(getContext(), getIndexBitWidth());
 }
 
 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index ceee541f2a7a8..335f743e2db3d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -160,7 +160,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       // Append linear x pointers, initialized to zero. Since each compressed
       // dimension initially already has a single zero entry, this maintains
       // the desired "linear + 1" length property at all times.
-      Type ptrType = stt.getEncoding().getPointerType();
+      Type ptrType = stt.getPointerType();
       Value ptrZero = constantZero(builder, loc, ptrType);
       createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l,
                      ptrZero, linear);
@@ -279,8 +279,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
   // to all zeros, sets the dimSizes to known values and gives all pointer
   // fields an initial zero entry, so that it is easier to maintain the
   // "linear + 1" length property.
-  Value ptrZero =
-      constantZero(builder, loc, stt.getEncoding().getPointerType());
+  Value ptrZero = constantZero(builder, loc, stt.getPointerType());
   for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) {
     // Fills dim sizes array.
     // FIXME: this method seems to set *level* sizes, but the name is confusing
@@ -546,7 +545,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
       // times?
       //
       if (l > 0) {
-        Type ptrType = stt.getEncoding().getPointerType();
+        Type ptrType = stt.getPointerType();
         Value ptrMemRef = desc.getPtrMemRef(l);
         Value hi = desc.getPtrMemSize(builder, loc, l);
         Value zero = constantIndex(builder, loc, 0);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
index 8440630c2eefa..be59ba83f0f4b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -179,14 +179,13 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
     llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                             DimLevelType)>
         callback) {
-  const auto enc = stt.getEncoding();
-  assert(enc);
+  assert(stt.hasEncoding());
   // Construct the basic types.
-  Type idxType = enc.getIndexType();
-  Type ptrType = enc.getPointerType();
+  Type idxType = stt.getIndexType();
+  Type ptrType = stt.getPointerType();
   Type eltType = stt.getElementType();
 
-  Type metaDataType = StorageSpecifierType::get(enc);
+  Type metaDataType = StorageSpecifierType::get(stt.getEncoding());
   // memref<? x ptr>  pointers
   Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
   // memref<? x idx>  indices
@@ -195,7 +194,7 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
   Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
 
   foreachFieldInSparseTensor(
-      enc,
+      stt.getEncoding(),
       [metaDataType, ptrMemType, idxMemType, valMemType,
        callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
                  Level lvl, DimLevelType dlt) -> bool {


        


More information about the Mlir-commits mailing list