[Mlir-commits] [mlir] dd33481 - [mlir][sparse] add getPointerType/getIndexType to SparseTensorEncodingAttr.

Peiming Liu llvmlistbot at llvm.org
Thu Dec 1 14:01:57 PST 2022


Author: Peiming Liu
Date: 2022-12-01T22:01:52Z
New Revision: dd33481f48f264420862d1ee9eae83f2deab7078

URL: https://github.com/llvm/llvm-project/commit/dd33481f48f264420862d1ee9eae83f2deab7078
DIFF: https://github.com/llvm/llvm-project/commit/dd33481f48f264420862d1ee9eae83f2deab7078.diff

LOG: [mlir][sparse] add getPointerType/getIndexType to SparseTensorEncodingAttr.

Add new methods to SparseTensorEncodingAttr that construct the pointer/index storage types from the pointer/index bit widths (a bit width of 0 yields the builtin `index` type).

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D139141

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index e5272a907fc92..5e472d5998d42 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -158,6 +158,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     "unsigned":$indexBitWidth
   );
 
+  let extraClassDeclaration = [{
+    /// Pointer storage type: `i<pointerBitWidth>`, or `index` if that is 0.
+    Type getPointerType() const;
+
+    /// Index storage type: `i<indexBitWidth>`, or `index` if that is 0.
+    Type getIndexType() const;
+  }];
+
   let genVerifyDecl = 1;
   let hasCustomAssemblyFormat = 1;
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 652e0504c5ee1..599de1e5fee3d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -41,6 +41,18 @@ static bool acceptBitWidth(unsigned bitWidth) {
   }
 }
 
+/// Returns `iN` for a nonzero bit width N, and `index` for a width of 0.
+static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned width) {
+  return width ? Type(IntegerType::get(ctx, width)) : IndexType::get(ctx);
+}
+
+Type SparseTensorEncodingAttr::getPointerType() const {
+  return getIntegerOrIndexType(getContext(), getPointerBitWidth());
+}
+Type SparseTensorEncodingAttr::getIndexType() const {
+  return getIntegerOrIndexType(getContext(), getIndexBitWidth());
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 97f1f952e5bd5..cb4dafbc7b625 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -203,12 +203,10 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
     return llvm::None;
   // Construct the basic types.
   auto *context = type.getContext();
-  unsigned idxWidth = enc.getIndexBitWidth();
-  unsigned ptrWidth = enc.getPointerBitWidth();
   RankedTensorType rType = type.cast<RankedTensorType>();
   Type indexType = IndexType::get(context);
-  Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
-  Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
+  Type idxType = enc.getIndexType();
+  Type ptrType = enc.getPointerType();
   Type eltType = rType.getElementType();
   //
   // Sparse tensor storage scheme for rank-dimensional tensor is organized
@@ -268,21 +266,20 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       // Append linear x pointers, initialized to zero. Since each compressed
       // dimension initially already has a single zero entry, this maintains
       // the desired "linear + 1" length property at all times.
-      unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth();
-      Type indexType = builder.getIndexType();
-      Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+      Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
       Value ptrZero = constantZero(builder, loc, ptrType);
       createPushback(builder, loc, fields, field, ptrZero, linear);
       return;
     }
     if (isSingletonDim(rtp, r)) {
       return; // nothing to do
-    }         // Keep compounding the size, but nothing needs to be initialized
-      // at this level. We will eventually reach a compressed level or
-      // otherwise the values array for the from-here "all-dense" case.
-      assert(isDenseDim(rtp, r));
-      Value size = sizeAtStoredDim(builder, loc, rtp, fields, r);
-      linear = builder.create<arith::MulIOp>(loc, linear, size);
+    }
+    // Keep compounding the size, but nothing needs to be initialized
+    // at this level. We will eventually reach a compressed level or
+    // otherwise the values array for the from-here "all-dense" case.
+    assert(isDenseDim(rtp, r));
+    Value size = sizeAtStoredDim(builder, loc, rtp, fields, r);
+    linear = builder.create<arith::MulIOp>(loc, linear, size);
   }
   // Reached values array so prepare for an insertion.
   Value valZero = constantZero(builder, loc, rtp.getElementType());
@@ -315,13 +312,10 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
                               SmallVectorImpl<Value> &fields) {
   auto enc = getSparseTensorEncoding(type);
   assert(enc);
-  // Construct the basic types.
-  unsigned idxWidth = enc.getIndexBitWidth();
-  unsigned ptrWidth = enc.getPointerBitWidth();
   RankedTensorType rtp = type.cast<RankedTensorType>();
   Type indexType = builder.getIndexType();
-  Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType;
-  Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+  Type idxType = enc.getIndexType();
+  Type ptrType = enc.getPointerType();
   Type eltType = rtp.getElementType();
   auto shape = rtp.getShape();
   unsigned rank = shape.size();
@@ -622,9 +616,7 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
       // TODO: avoid cleanup and keep compressed scheme consistent at all times?
       //
       if (d > 0) {
-        unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth();
-        Type indexType = builder.getIndexType();
-        Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+        Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
         Value mz = constantIndex(builder, loc, getMemSizesIndex(field));
         Value hi = genLoad(builder, loc, fields[memSizesIdx], mz);
         Value zero = constantIndex(builder, loc, 0);


        


More information about the Mlir-commits mailing list