[Mlir-commits] [mlir] 1be0949 - [mlir][sparse] improved tensor type lowering

Aart Bik llvmlistbot at llvm.org
Thu Sep 1 09:24:33 PDT 2022


Author: Aart Bik
Date: 2022-09-01T09:24:20-07:00
New Revision: 1be09496bfd5fb764b1b2b3e62ca1c16e3180223

URL: https://github.com/llvm/llvm-project/commit/1be09496bfd5fb764b1b2b3e62ca1c16e3180223
DIFF: https://github.com/llvm/llvm-project/commit/1be09496bfd5fb764b1b2b3e62ca1c16e3180223.diff

LOG: [mlir][sparse] improved tensor type lowering

Also includes a first codegen example (although full support needs tuple access)

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133080

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/test/Dialect/SparseTensor/codegen.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index b905b442f0975..d82ebea74d205 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -33,6 +33,16 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
+/// Maps stored dimension \p d back to its logical (original) dimension.
+static unsigned reorder(const SparseTensorEncodingAttr &enc, unsigned d) {
+  auto perm = enc.getDimOrdering();
+  // Without a dimOrdering, stored and logical dimensions coincide.
+  if (!perm)
+    return d;
+  assert(perm.isPermutation());
+  return perm.getDimPosition(d);
+}
+
 /// Maps a sparse tensor type to the appropriate compounded buffers.
 static Optional<Type> convertSparseTensorType(Type type) {
   auto enc = getSparseTensorEncoding(type);
@@ -47,12 +57,14 @@ static Optional<Type> convertSparseTensorType(Type type) {
   Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
   Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
   Type eltType = rType.getElementType();
+  ArrayRef<int64_t> shape = rType.getShape();
   //
   // Sparse tensor storage for rank-dimensional tensor is organized as a
   // single compound type with the following fields:
   //
   // struct {
-  //   memref<rank x index> dimSize  ; size in each dimension
+  //   ; if dynamic shape:
+  //     memref<rank x index> dimSize    ; size in each dimension
   //   ; per-dimension d:
   //   ;  if dense:
   //        <nothing>
@@ -61,23 +73,31 @@ static Optional<Type> convertSparseTensorType(Type type) {
   //        memref<? x ptr>  pointers-d  ; pointers for sparse dim d
   //   ;  if singleton:
   //        memref<? x idx>  indices-d   ; indices for singleton dim d
-  //   memref<? x eltType> values    ; values
+  //   memref<? x eltType> values        ; values
   // };
   //
-  // TODO: fill in the ? when statically known
-  //
-  // TODO: emit dimSizes when not needed (e.g. all-dense)
-  //
+  int64_t linear = 1;
+  bool allDense = true;
   unsigned rank = rType.getShape().size();
   SmallVector<Type, 8> fields;
-  fields.push_back(MemRefType::get({rank}, indexType));
+  // The dimSizes array.
+  if (!rType.hasStaticShape())
+    fields.push_back(MemRefType::get({rank}, indexType));
+  // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
+    // Get the original dimension (ro) for the current stored dimension (r).
+    unsigned ro = reorder(enc, r);
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
     // tensor type.
     switch (enc.getDimLevelType()[r]) {
     case SparseTensorEncodingAttr::DimLevelType::Dense:
+      // Linearize the size of consecutive dense dimensions.
+      if (ShapedType::isDynamic(shape[ro]) || ShapedType::isDynamic(linear))
+        linear = ShapedType::kDynamicSize;
+      else
+        linear *= shape[ro];
       break;
     case SparseTensorEncodingAttr::DimLevelType::Compressed:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
@@ -85,16 +105,23 @@ static Optional<Type> convertSparseTensorType(Type type) {
     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+      allDense = false;
+      linear = 1;
       break;
     case SparseTensorEncodingAttr::DimLevelType::Singleton:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
+      allDense = false;
+      linear = 1;
       break;
     }
   }
-  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
+  // The values array.
+  int64_t nnz =
+      (rType.hasStaticShape() && allDense) ? linear : ShapedType::kDynamicSize;
+  fields.push_back(MemRefType::get({nnz}, eltType));
   // Sparse tensor storage (temporarily) lives in a tuple. This allows a
   // simple 1:1 type conversion during codegen. A subsequent pass uses
   // a 1:N type conversion to expand the tuple into its fields.
@@ -102,10 +129,10 @@ static Optional<Type> convertSparseTensorType(Type type) {
 }
 
 //===----------------------------------------------------------------------===//
-// Conversion rules.
+// Codegen rules.
 //===----------------------------------------------------------------------===//
 
-/// Sparse conversion rule for returns.
+/// Sparse codegen rule for returns.
 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -117,6 +144,36 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
   }
 };
 
+/// Sparse codegen rule for dimension accesses.
+class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Type type = op.getSource().getType();
+    // Only rewrite annotated DimOp with a constant, in-range index.
+    auto enc = getSparseTensorEncoding(type);
+    if (!enc)
+      return failure();
+    RankedTensorType rType = type.cast<RankedTensorType>();
+    Optional<int64_t> index = op.getConstantIndex();
+    if (!index || *index < 0 || *index >= rType.getRank())
+      return failure(); // Never index the shape out of bounds.
+    // Access into static shape can query original type directly.
+    // Note that this is typically already done by DimOp's folding.
+    if (rType.hasStaticShape()) {
+      rewriter.replaceOp(
+          op, constantIndex(rewriter, loc, rType.getShape()[*index]));
+      return success();
+    }
+    // Any other query can consult the dimSize array.
+    // TODO: this needs tuple access
+    return failure();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -136,5 +193,6 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter>(typeConverter, patterns.getContext());
+  patterns.add<SparseReturnConverter, SparseDimOpConverter>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b380c50f98bf4..b30d0d2b927f0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -158,8 +158,7 @@ struct SparseTensorCodegenPass
     ConversionTarget target(*ctx);
     // Everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    // All dynamic rules below accept new function, call, return, and various
-    // tensor and bufferization operations as legal output of the rewriting.
+    // Function, call, and return ops are legal once their types are legal.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
@@ -169,6 +168,10 @@ struct SparseTensorCodegenPass
     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
       return converter.isLegal(op.getOperandTypes());
     });
+    // Legal dialects may occur in generated code.
+    target.addLegalDialect<arith::ArithmeticDialect,
+                           bufferization::BufferizationDialect,
+                           memref::MemRefDialect, scf::SCFDialect>();
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                    converter);

diff  --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index f9a979e38d42b..66626163ff97c 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -6,7 +6,7 @@
   pointerBitWidth = 32
 }>
 
-#Dense = #sparse_tensor.encoding<{
+#Dense2D = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "dense" ],
   indexBitWidth = 64,
   pointerBitWidth = 32
@@ -30,6 +30,13 @@
   pointerBitWidth = 32
 }>
 
+#Dense3D = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense", "dense" ],
+  indexBitWidth = 64,
+  pointerBitWidth = 32,
+  dimOrdering = affine_map<(i,j,k) -> (k, i,j)>
+}>
+
 // CHECK-LABEL: func @sparse_nop(
 //  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
 //       CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
@@ -37,9 +44,9 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
-// CHECK-LABEL: func @sparse_dense(
+// CHECK-LABEL: func @sparse_dense_2d(
 //  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
-func.func @sparse_dense(%arg0: tensor<?x?xf64, #Dense>) {
+func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
   return
 }
 
@@ -60,3 +67,16 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
 func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
   return
 }
+
+//
+// No dimSizes, just a linearized memref<6000xf64>. Dim op is statically known.
+//
+// CHECK-LABEL: func @sparse_dense_3d(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<6000xf64>>) -> index
+//       CHECK: %[[C:.*]] = arith.constant 20 : index
+//       CHECK: return %[[C]] : index
+func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
+  %c = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D>
+  return %0 : index
+}


        


More information about the Mlir-commits mailing list