[Mlir-commits] [mlir] c194b49 - [mlir][sparse] add full dimension ordering support

Aart Bik llvmlistbot at llvm.org
Fri May 21 12:35:30 PDT 2021


Author: Aart Bik
Date: 2021-05-21T12:35:13-07:00
New Revision: c194b49c9c8dfe01804ecd0b90814d1e98382fc1

URL: https://github.com/llvm/llvm-project/commit/c194b49c9c8dfe01804ecd0b90814d1e98382fc1
DIFF: https://github.com/llvm/llvm-project/commit/c194b49c9c8dfe01804ecd0b90814d1e98382fc1.diff

LOG: [mlir][sparse] add full dimension ordering support

This revision completes the "dimension ordering" feature
of sparse tensor types, which lets the programmer define a
preferred order of dimension access (other than the default
left-to-right order). This enables, for example, selecting
column-major over row-major storage for sparse matrices,
and generalizes to any rank, as in:

dimOrdering = affine_map<(i,j,k,l,m,n,o,p) -> (p,o,j,k,i,l,m,n)>
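
For example, column-major (CSC) storage of a sparse matrix, exercised
by the new sparse_lower_col.mlir test below, is expressed directly in
the type:

  #CSC = #sparse_tensor.encoding<{
    dimLevelType = [ "dense", "compressed" ],
    dimOrdering = affine_map<(i,j) -> (j,i)>
  }>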

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D102856

Added: 
    mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir

Modified: 
    mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/ExecutionEngine/SparseUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/sparse_lower.mlir
    mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
index a263fc2ad1d35..ffb450b1a3756 100644
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -351,8 +351,8 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT double rtclock();
 //===----------------------------------------------------------------------===//
 // Small runtime support library for sparse tensors.
 //===----------------------------------------------------------------------===//
-extern "C" MLIR_CRUNNERUTILS_EXPORT void *openTensorC(char *filename,
-                                                      uint64_t *idata);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void *
+openTensorC(char *filename, uint64_t *idata, uint64_t *perm);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void
 readTensorItemC(void *tensor, uint64_t *idata, double *ddata);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void closeTensor(void *tensor);

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 5239a3d4aa7d4..ed1413d5f4bd5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -54,6 +54,20 @@ getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
   }
 }
 
+/// Returns integers of given width and values as a constant tensor.
+/// We cast the static shape into a dynamic shape to ensure that the
+/// method signature remains uniform across different tensor dimensions.
+static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
+                       Location loc, ArrayRef<APInt> values) {
+  Type etp = rewriter.getIntegerType(width);
+  unsigned sz = values.size();
+  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
+  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
+  auto elts =
+      rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
+  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
+}
+
 /// Returns function reference (first hit also inserts into module).
 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
                                  ValueRange operands) {
@@ -117,22 +131,29 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
       return failure();
     // User pointer.
     params.push_back(operands[0]);
-    // Sparsity annotations in tensor constant form. Note that we cast
-    // the static shape into a dynamic shape to ensure that the method
-    // signature remains uniform accross different tensor dimensions.
+    // Sparsity annotations in tensor constant form.
     SmallVector<APInt, 4> attrs;
     unsigned sz = enc.getDimLevelType().size();
     for (unsigned i = 0; i < sz; i++)
       attrs.push_back(
           APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
-    Type etp = rewriter.getIntegerType(8);
-    RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
-    RankedTensorType tt2 =
-        RankedTensorType::get({ShapedType::kDynamicSize}, etp);
-    auto elts =
-        rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
-    params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
-    // Seconary and primary types encoding.
+    params.push_back(getTensor(rewriter, 8, loc, attrs));
+    // Dimension order permutation array. This is the "identity"
+    // permutation by default, or otherwise the "reverse" permutation
+    // of a given ordering, so that indices can be mapped quickly
+    // to the right position.
+    SmallVector<APInt, 4> perm(sz);
+    AffineMap p = enc.getDimOrdering();
+    if (p) {
+      assert(p.isPermutation() && p.getNumResults() == sz);
+      for (unsigned i = 0; i < sz; i++)
+        perm[p.getDimPosition(i)] = APInt(64, i);
+    } else {
+      for (unsigned i = 0; i < sz; i++)
+        perm[i] = APInt(64, i);
+    }
+    params.push_back(getTensor(rewriter, 64, loc, perm));
+    // Secondary and primary types encoding.
     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
     unsigned primary;

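Note that the conversion emits the *inverse* of the dimension ordering:
entry i of the permutation array is the storage position of original
dimension i, so the runtime can map indices directly. For the 3-D
ordering used in the new sparse_new3d conversion test, the lowering
boils down to (condensed from the CHECK lines below, SSA names
simplified):

  #SparseTensor = #sparse_tensor.encoding<{
    dimLevelType = ["dense", "compressed", "compressed"],
    dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
  }>

  // Inverse of (i,j,k) -> (k,i,j): i lands in slot 1, j in slot 2, k in slot 0.
  %p = constant dense<[1, 2, 0]> : tensor<3xi64>
  %q = tensor.cast %p : tensor<3xi64> to tensor<?xi64>
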
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6fb90dcc645f8..ffa16eb2c0127 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -333,6 +333,18 @@ struct CodeGen {
 
 } // namespace
 
+// Helper method to apply dimension ordering permutation.
+static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
+  if (enc) {
+    auto order = enc.getDimOrdering();
+    if (order) {
+      assert(order.isPermutation());
+      return order.getDimPosition(d);
+    }
+  }
+  return d;
+}
+
 // Helper method to translate dim level type to internal representation.
 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
   if (enc) {
@@ -353,17 +365,17 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
   unsigned lhs = numTensors - 1;
   for (unsigned t = 0; t < numTensors; t++) {
     auto map = op.getIndexingMap(t);
-    unsigned rank = op.getShapedType(t).getRank();
+    if (!map.isProjectedPermutation())
+      return false;
     auto enc = getSparseTensorEncoding(op.getShapedType(t));
     if (enc) {
       annotated = true;
-      if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
-        return false; // TODO: handle permutations
       if (t == lhs)
         return false; // TODO: handle sparse outputs
     }
-    for (unsigned d = 0; d < rank; d++) {
-      unsigned idx = map.getDimPosition(d);
+    assert(map.getNumResults() == op.getShapedType(t).getRank());
+    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+      unsigned idx = map.getDimPosition(perm(enc, d));
       merger.setDim(t, idx, toDim(enc, d));
     }
   }
@@ -405,18 +417,18 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   unsigned numTensors = op.getNumShapedOperands();
   for (unsigned t = 0; t < numTensors; t++) {
     auto map = op.getIndexingMap(t);
+    auto enc = getSparseTensorEncoding(op.getShapedType(t));
     assert(map.getNumDims() == n);
     // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !getSparseTensorEncoding(op.getShapedType(t)))
+    if (sparseOnly && !enc)
       continue;
-    // At the moment, we take the index variables in the tensor access
-    // expression in the order in which they appear (conceptually a
-    // "row-major" layout of every tensor). So, a tensor access A_ijk
-    // forces the ordering i < j < k on the loop indices.
-    // TODO: support affine map to define alternative dimension orders.
-    for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
-      unsigned f = map.getDimPosition(d - 1);
-      unsigned t = map.getDimPosition(d);
+    // Each tensor expression and optional dimension ordering (row-major
+    // by default) puts an ordering constraint on the loop indices. For
+    // example, the tensor expression A_ijk forces the ordering i < j < k
+    // on the loop indices if no explicit dimension ordering is given.
+    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
+      unsigned f = map.getDimPosition(perm(enc, d - 1));
+      unsigned t = map.getDimPosition(perm(enc, d));
       adjM[f][t] = true;
     }
   }
@@ -441,15 +453,10 @@ static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                          Value val) {
   if (auto arg = val.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
-    if (arg.getOwner()->getParentOp() == op) {
-      // Any parameter of the generic op is considered a tensor,
-      // indexed by the implicit loop bounds.
-      auto map = op.getIndexingMap(argN);
-      if (map.isProjectedPermutation())
-        return merger.addExp(Kind::kTensor, argN);
-      // Cannot handle (yet).
-      return None;
-    }
+    // Any parameter of the generic op is considered a tensor,
+    // indexed by the implicit loop bounds.
+    if (arg.getOwner()->getParentOp() == op)
+      return merger.addExp(Kind::kTensor, argN);
     // Any parameter of a higher op is invariant.
     return merger.addExp(Kind::kInvariant, val);
   }
@@ -568,10 +575,10 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
     auto enc = getSparseTensorEncoding(tensorType);
     // Scan all dimensions of current tensor.
     args.clear();
-    for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
-      unsigned i = map.getDimPosition(d);
+    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+      unsigned idx = map.getDimPosition(perm(enc, d));
       // Handle sparse storage schemes.
-      if (merger.isDim(t, i, Dim::kSparse)) {
+      if (merger.isDim(t, idx, Dim::kSparse)) {
         auto dynShape = {ShapedType::kDynamicSize};
         auto ptrTp = MemRefType::get(
             dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
@@ -579,9 +586,9 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
             dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
         // Generate sparse primitives to obtains pointer and indices.
-        codegen.pointers[t][i] =
+        codegen.pointers[t][idx] =
             rewriter.create<ToPointersOp>(loc, ptrTp, tensor, dim);
-        codegen.indices[t][i] =
+        codegen.indices[t][idx] =
             rewriter.create<ToIndicesOp>(loc, indTp, tensor, dim);
       }
       // Find lower and upper bound in current dimension.
@@ -592,7 +599,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
       } else {
         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
       }
-      codegen.sizes[i] = codegen.highs[t][i] = up;
+      codegen.sizes[idx] = codegen.highs[t][idx] = up;
     }
     // Perform the required bufferization. All dense inputs materialize
     // from the input tensor. The dense output tensor needs special
@@ -705,8 +712,8 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
   unsigned tensor = merger.exp(exp).e0;
   auto map = op.getIndexingMap(tensor);
   auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
-  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
-    unsigned idx = map.getDimPosition(i);
+  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+    unsigned idx = map.getDimPosition(perm(enc, d));
     args.push_back(codegen.loops[idx]); // universal dense index
     if (enc) {
       args.clear();
@@ -737,8 +744,9 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
   // Actual store.
   SmallVector<Value, 4> args;
   auto map = op.getIndexingMap(tensor);
-  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
-    unsigned idx = map.getDimPosition(i);
+  assert(!getSparseTensorEncoding(op.getShapedType(tensor)));
+  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+    unsigned idx = map.getDimPosition(d);
     args.push_back(codegen.loops[idx]); // universal dense index
   }
   Value ptr = codegen.buffers[tensor];
@@ -888,8 +896,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
     bool atLevel = ldx == -1u;
     unsigned tensor = merger.exp(exp).e0;
     auto map = op.getIndexingMap(tensor);
-    for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
-      unsigned idx = map.getDimPosition(i);
+    auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
+    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+      unsigned idx = map.getDimPosition(perm(enc, d));
       if (!codegen.loops[idx])
         return; // still in play
       else if (idx == ldx)
@@ -1001,9 +1010,8 @@ static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
   for (unsigned t = 0; t < numTensors; t++) {
     if (!getSparseTensorEncoding(op.getShapedType(t))) {
       auto map = op.getIndexingMap(t);
-      unsigned r = map.getNumResults();
-      for (unsigned i = 0; i < r; i++) {
-        if (map.getDimPosition(i) == idx && i != r - 1)
+      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+        if (map.getDimPosition(d) == idx && d != rank - 1)
           return false;
       }
     }

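With perm() applied while building the iteration graph, the dimension
ordering now drives the loop order. For the CSC matvec in the new
sparse_lower_col.mlir test, the constraint from A becomes j < i, so the
generated loop nest iterates over columns first and scatters into x
(condensed from the CHECK-HIR lines below, SSA names simplified):

  scf.for %j = %c0 to %c64 step %c1 {
    %b_j = memref.load %b[%j] : memref<64xf64>
    %lo = memref.load %pointers[%j] : memref<?xindex>
    %j1 = addi %j, %c1 : index
    %hi = memref.load %pointers[%j1] : memref<?xindex>
    scf.for %jj = %lo to %hi step %c1 {
      %i = memref.load %indices[%jj] : memref<?xindex>
      %x_i = memref.load %x[%i] : memref<32xf64>
      %a = memref.load %values[%jj] : memref<?xf64>
      %m = mulf %a, %b_j : f64
      %s = addf %x_i, %m : f64
      memref.store %s, %x[%i] : memref<32xf64>
    }
  }
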
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index 1a654cf22c7a7..3c414f691c5bf 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -243,9 +243,11 @@ class SparseTensorStorage : public SparseTensorStorageBase {
 
 /// Templated reader.
 template <typename P, typename I, typename V>
-void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t size) {
+void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t *perm,
+                      uint64_t size) {
   uint64_t idata[64];
-  SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata));
+  SparseTensor *t =
+      static_cast<SparseTensor *>(openTensorC(filename, idata, perm));
   assert(size == t->getRank()); // sparsity array must match rank
   SparseTensorStorageBase *tensor =
       new SparseTensorStorage<P, I, V>(t, sparsity);
@@ -371,7 +373,7 @@ extern "C" {
 /// understood by other methods in the sparse runtime support library. An
 /// array parameter is used to pass the rank, the number of nonzero elements,
 /// and the dimension sizes (one per rank).
-void *openTensorC(char *filename, uint64_t *idata) {
+void *openTensorC(char *filename, uint64_t *idata, uint64_t *perm) {
   // Open the file.
   FILE *file = fopen(filename, "r");
   if (!file) {
@@ -393,16 +395,24 @@ void *openTensorC(char *filename, uint64_t *idata) {
   uint64_t nnz = idata[1];
   std::vector<uint64_t> indices(rank);
   for (uint64_t r = 0; r < rank; r++)
-    indices[r] = idata[2 + r];
+    if (perm)
+      indices[perm[r]] = idata[2 + r];
+    else
+      indices[r] = idata[2 + r];
   SparseTensor *tensor = new SparseTensor(indices, nnz);
   // Read all nonzero elements.
   for (uint64_t k = 0; k < nnz; k++) {
+    uint64_t idx = -1;
     for (uint64_t r = 0; r < rank; r++) {
-      if (fscanf(file, "%" PRIu64, &indices[r]) != 1) {
+      if (fscanf(file, "%" PRIu64, &idx) != 1) {
         fprintf(stderr, "Cannot find next index in %s\n", filename);
         exit(1);
       }
-      indices[r]--; // 0-based index
+      // Add 0-based index.
+      if (perm)
+        indices[perm[r]] = idx - 1;
+      else
+        indices[r] = idx - 1;
     }
     double value;
     if (fscanf(file, "%lg\n", &value) != 1) {
@@ -421,7 +431,7 @@ void *openTensorC(char *filename, uint64_t *idata) {
 void *openTensor(char *filename, uint64_t *ibase, uint64_t *idata,
                  uint64_t ioff, uint64_t isize, uint64_t istride) {
   assert(istride == 1);
-  return openTensorC(filename, idata + ioff);
+  return openTensorC(filename, idata + ioff, nullptr);
 }
 
 /// Yields the next element from the given opaque sparse tensor object.
@@ -477,7 +487,7 @@ char *getTensorFilename(uint64_t id) {
 
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v))                            \
-  return newSparseTensor<P, I, V>(filename, sparsity, asize)
+  return newSparseTensor<P, I, V>(filename, sparsity, perm, asize)
 
 #define IMPL1(RET, NAME, TYPE, LIB)                                            \
   RET NAME(void *tensor) {                                                     \
@@ -515,9 +525,12 @@ enum PrimaryTypeEnum : uint64_t {
 
 void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
                       uint64_t aoff, uint64_t asize, uint64_t astride,
-                      uint64_t ptrTp, uint64_t indTp, uint64_t valTp) {
-  assert(astride == 1);
+                      uint64_t *pbase, uint64_t *pdata, uint64_t poff,
+                      uint64_t psize, uint64_t pstride, uint64_t ptrTp,
+                      uint64_t indTp, uint64_t valTp) {
+  assert(astride == 1 && pstride == 1);
   uint8_t *sparsity = adata + aoff;
+  uint64_t *perm = pdata + poff;
 
   // The most common cases: 64-bit or 32-bit overhead, double/float values.
   CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);

diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 777d7ffbd0297..5d8c30845416a 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -20,6 +20,11 @@
   dimLevelType = ["dense", "compressed"]
 }>
 
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed", "compressed"],
+  dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
+}>
+
 // CHECK-LABEL: func @sparse_dim(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index
@@ -35,7 +40,9 @@ func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
 //       CHECK: %[[D:.*]] = constant dense<1> : tensor<1xi8>
 //       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi8> to tensor<?xi8>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, i64, i64, i64) -> !llvm.ptr<i8>
+//       CHECK: %[[P:.*]] = constant dense<0> : tensor<1xi64>
+//       CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<1xi64> to tensor<?xi64>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
@@ -46,13 +53,28 @@ func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
 //       CHECK: %[[D:.*]] = constant dense<[0, 1]> : tensor<2xi8>
 //       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi8> to tensor<?xi8>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, i64, i64, i64) -> !llvm.ptr<i8>
+//       CHECK: %[[P:.*]] = constant dense<[0, 1]> : tensor<2xi64>
+//       CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<2xi64> to tensor<?xi64>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
   return %0 : tensor<?x?xf32, #SparseMatrix>
 }
 
+// CHECK-LABEL: func @sparse_new3d(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//       CHECK: %[[D:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
+//       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<3xi8> to tensor<?xi8>
+//       CHECK: %[[P:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
+//       CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<3xi64> to tensor<?xi64>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?x?xf32, #SparseTensor>
+  return %0 : tensor<?x?x?xf32, #SparseTensor>
+}
+
 // CHECK-LABEL: func @sparse_pointers(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index

diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
index ce2395ff7112f..c105f7ebec32d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
@@ -21,60 +21,60 @@
 }
 
 // CHECK-HIR-LABEL:   func @matvec(
-// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-HIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-HIR-SAME:                 %[[VAL_2:.*]]: tensor<64xf64>) -> tensor<64xf64> {
-// CHECK-HIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-HIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
+// CHECK-HIR:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK-HIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-HIR:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
-// CHECK-HIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
-// CHECK-HIR:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
-// CHECK-HIR:           %[[VAL_12:.*]] = memref.alloc() : memref<64xf64>
-// CHECK-HIR:           linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<64xf64>, memref<64xf64>
-// CHECK-HIR:           scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-HIR:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
-// CHECK-HIR:             %[[VAL_15:.*]] = addi %[[VAL_13]], %[[VAL_5]] : index
-// CHECK-HIR:             %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
-// CHECK-HIR:             %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<64xf64>
-// CHECK-HIR:             %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_17]]) -> (f64) {
-// CHECK-HIR:               %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK-HIR:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xf64>
-// CHECK-HIR:               %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<64xf64>
-// CHECK-HIR:               %[[VAL_24:.*]] = mulf %[[VAL_22]], %[[VAL_23]] : f64
-// CHECK-HIR:               %[[VAL_25:.*]] = addf %[[VAL_20]], %[[VAL_24]] : f64
-// CHECK-HIR:               scf.yield %[[VAL_25]] : f64
+// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
+// CHECK-HIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
+// CHECK-HIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK-HIR:           %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
+// CHECK-HIR:           linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32xf64>, memref<32xf64>
+// CHECK-HIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK-HIR:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK-HIR:             %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index
+// CHECK-HIR:             %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK-HIR:             %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK-HIR:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) {
+// CHECK-HIR:               %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK-HIR:               %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK-HIR:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<64xf64>
+// CHECK-HIR:               %[[VAL_23:.*]] = mulf %[[VAL_21]], %[[VAL_22]] : f64
+// CHECK-HIR:               %[[VAL_24:.*]] = addf %[[VAL_19]], %[[VAL_23]] : f64
+// CHECK-HIR:               scf.yield %[[VAL_24]] : f64
 // CHECK-HIR:             }
-// CHECK-HIR:             store %[[VAL_26:.*]], %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<64xf64>
+// CHECK-HIR:             memref.store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-HIR:           }
-// CHECK-HIR:           %[[VAL_27:.*]] = memref.tensor_load %[[VAL_12]] : memref<64xf64>
-// CHECK-HIR:           return %[[VAL_27]] : tensor<64xf64>
+// CHECK-HIR:           %[[VAL_26:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
+// CHECK-HIR:           return %[[VAL_26]] : tensor<32xf64>
 // CHECK-HIR:         }
 
 // CHECK-MIR-LABEL:   func @matvec(
 // CHECK-MIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
 // CHECK-MIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-MIR-SAME:                 %[[VAL_2:.*]]: tensor<64xf64>) -> tensor<64xf64> {
-// CHECK-MIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-MIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
+// CHECK-MIR:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK-MIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-MIR:           %[[VAL_5:.*]] = constant 1 : index
 // CHECK-MIR:           %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-MIR:           %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-MIR:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 // CHECK-MIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
-// CHECK-MIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
-// CHECK-MIR:           %[[VAL_11:.*]] = memref.alloc() : memref<64xf64>
+// CHECK-MIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK-MIR:           %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
 // CHECK-MIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-MIR:             %[[VAL_13:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref<64xf64>
-// CHECK-MIR:             store %[[VAL_13]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<64xf64>
+// CHECK-MIR:             %[[VAL_13:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK-MIR:             memref.store %[[VAL_13]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-MIR:           }
 // CHECK-MIR:           scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-MIR:             %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
 // CHECK-MIR:             %[[VAL_16:.*]] = addi %[[VAL_14]], %[[VAL_5]] : index
 // CHECK-MIR:             %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK-MIR:             %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<64xf64>
+// CHECK-MIR:             %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64>
 // CHECK-MIR:             %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_5]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (f64) {
 // CHECK-MIR:               %[[VAL_22:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
 // CHECK-MIR:               %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xf64>
@@ -83,32 +83,32 @@
 // CHECK-MIR:               %[[VAL_26:.*]] = addf %[[VAL_21]], %[[VAL_25]] : f64
 // CHECK-MIR:               scf.yield %[[VAL_26]] : f64
 // CHECK-MIR:             }
-// CHECK-MIR:             store %[[VAL_27:.*]], %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<64xf64>
+// CHECK-MIR:             memref.store %[[VAL_27:.*]], %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64>
 // CHECK-MIR:           }
-// CHECK-MIR:           %[[VAL_28:.*]] = memref.tensor_load %[[VAL_11]] : memref<64xf64>
-// CHECK-MIR:           return %[[VAL_28]] : tensor<64xf64>
+// CHECK-MIR:           %[[VAL_28:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
+// CHECK-MIR:           return %[[VAL_28]] : tensor<32xf64>
 // CHECK-MIR:         }
 
 // CHECK-LIR-LABEL:   func @matvec(
 // CHECK-LIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
 // CHECK-LIR-SAME:                 %[[VAL_1:.*]]: memref<64xf64>,
-// CHECK-LIR-SAME:                 %[[VAL_2:.*]]: memref<64xf64>) -> memref<64xf64> {
-// CHECK-LIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-LIR-SAME:                 %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> {
+// CHECK-LIR:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK-LIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-LIR:           %[[VAL_5:.*]] = constant 1 : index
 // CHECK-LIR:           %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-LIR:           %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-LIR:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
-// CHECK-LIR:           %[[VAL_9:.*]] = memref.alloc() : memref<64xf64>
+// CHECK-LIR:           %[[VAL_9:.*]] = memref.alloc() : memref<32xf64>
 // CHECK-LIR:           scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK-LIR:             %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<64xf64>
-// CHECK-LIR:             store %[[VAL_11]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<64xf64>
+// CHECK-LIR:             %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<32xf64>
+// CHECK-LIR:             memref.store %[[VAL_11]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<32xf64>
 // CHECK-LIR:           }
 // CHECK-LIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-LIR:             %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
 // CHECK-LIR:             %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index
 // CHECK-LIR:             %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK-LIR:             %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64>
+// CHECK-LIR:             %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-LIR:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) {
 // CHECK-LIR:               %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
 // CHECK-LIR:               %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf64>
@@ -117,21 +117,21 @@
 // CHECK-LIR:               %[[VAL_24:.*]] = addf %[[VAL_19]], %[[VAL_23]] : f64
 // CHECK-LIR:               scf.yield %[[VAL_24]] : f64
 // CHECK-LIR:             }
-// CHECK-LIR:             store %[[VAL_25:.*]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64>
+// CHECK-LIR:             memref.store %[[VAL_25:.*]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64>
 // CHECK-LIR:           }
-// CHECK-LIR:           return %[[VAL_9]] : memref<64xf64>
+// CHECK-LIR:           return %[[VAL_9]] : memref<32xf64>
 // CHECK-LIR:         }
 
-func @matvec(%arga: tensor<64x64xf64, #CSR>,
+func @matvec(%arga: tensor<32x64xf64, #CSR>,
              %argb: tensor<64xf64>,
-	     %argx: tensor<64xf64>) -> tensor<64xf64> {
+             %argx: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait_matvec
-      ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>)
-      outs(%argx: tensor<64xf64>) {
+      ins(%arga, %argb : tensor<32x64xf64, #CSR>, tensor<64xf64>)
+      outs(%argx: tensor<32xf64>) {
     ^bb(%A: f64, %b: f64, %x: f64):
       %0 = mulf %A, %b : f64
       %1 = addf %x, %0 : f64
       linalg.yield %1 : f64
-  } -> tensor<64xf64>
-  return %0 : tensor<64xf64>
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
 }

diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
new file mode 100644
index 0000000000000..09d1a36f4ba4b
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir
@@ -0,0 +1,139 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR
+//
+// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion                 \
+// RUN: --convert-linalg-to-loops | FileCheck %s --check-prefix=CHECK-MIR
+//
+// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion                 \
+// RUN: --convert-linalg-to-loops --func-bufferize --tensor-constant-bufferize \
+// RUN: --tensor-bufferize --finalizing-bufferize |                            \
+// RUN: FileCheck %s --check-prefix=CHECK-LIR
+
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+#trait_matvec = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (j)>,    // b
+    affine_map<(i,j) -> (i)>     // x (out)
+  ],
+  iterator_types = ["parallel","reduction"],
+  doc = "x(i) += A(i,j) * b(j)"
+}
+
+// CHECK-HIR-LABEL:   func @matvec(
+// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-HIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
+// CHECK-HIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
+// CHECK-HIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-HIR:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK-HIR:           %[[VAL_5:.*]] = constant 1 : index
+// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
+// CHECK-HIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
+// CHECK-HIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK-HIR:           %[[VAL_11:.*]] = memref.alloc() : memref<32xf64>
+// CHECK-HIR:           linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32xf64>, memref<32xf64>
+// CHECK-HIR:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK-HIR:             %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64>
+// CHECK-HIR:             %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK-HIR:             %[[VAL_15:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index
+// CHECK-HIR:             %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK-HIR:             scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_5]] {
+// CHECK-HIR:               %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK-HIR:               %[[VAL_19:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<32xf64>
+// CHECK-HIR:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xf64>
+// CHECK-HIR:               %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_13]] : f64
+// CHECK-HIR:               %[[VAL_22:.*]] = addf %[[VAL_19]], %[[VAL_21]] : f64
+// CHECK-HIR:               memref.store %[[VAL_22]], %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<32xf64>
+// CHECK-HIR:             }
+// CHECK-HIR:           }
+// CHECK-HIR:           %[[VAL_23:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
+// CHECK-HIR:           return %[[VAL_23]] : tensor<32xf64>
+// CHECK-HIR:         }
+
+// CHECK-MIR-LABEL:   func @matvec(
+// CHECK-MIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-MIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
+// CHECK-MIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> {
+// CHECK-MIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-MIR:           %[[VAL_4:.*]] = constant 32 : index
+// CHECK-MIR:           %[[VAL_5:.*]] = constant 0 : index
+// CHECK-MIR:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK-MIR:           %[[VAL_7:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-MIR:           %[[VAL_8:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-MIR:           %[[VAL_9:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+// CHECK-MIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
+// CHECK-MIR:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK-MIR:           %[[VAL_12:.*]] = memref.alloc() : memref<32xf64>
+// CHECK-MIR:           scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
+// CHECK-MIR:             %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_13]]] : memref<32xf64>
+// CHECK-MIR:             memref.store %[[VAL_14]], %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<32xf64>
+// CHECK-MIR:           }
+// CHECK-MIR:           scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK-MIR:             %[[VAL_16:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<64xf64>
+// CHECK-MIR:             %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK-MIR:             %[[VAL_18:.*]] = addi %[[VAL_15]], %[[VAL_6]] : index
+// CHECK-MIR:             %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK-MIR:             scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_6]] {
+// CHECK-MIR:               %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK-MIR:               %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<32xf64>
+// CHECK-MIR:               %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xf64>
+// CHECK-MIR:               %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_16]] : f64
+// CHECK-MIR:               %[[VAL_25:.*]] = addf %[[VAL_22]], %[[VAL_24]] : f64
+// CHECK-MIR:               memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<32xf64>
+// CHECK-MIR:             }
+// CHECK-MIR:           }
+// CHECK-MIR:           %[[VAL_26:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf64>
+// CHECK-MIR:           return %[[VAL_26]] : tensor<32xf64>
+// CHECK-MIR:         }
+
+// CHECK-LIR-LABEL:   func @matvec(
+// CHECK-LIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
+// CHECK-LIR-SAME:                 %[[VAL_1:.*]]: memref<64xf64>,
+// CHECK-LIR-SAME:                 %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> {
+// CHECK-LIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-LIR:           %[[VAL_4:.*]] = constant 32 : index
+// CHECK-LIR:           %[[VAL_5:.*]] = constant 0 : index
+// CHECK-LIR:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK-LIR:           %[[VAL_7:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-LIR:           %[[VAL_8:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
+// CHECK-LIR:           %[[VAL_9:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+// CHECK-LIR:           %[[VAL_10:.*]] = memref.alloc() : memref<32xf64>
+// CHECK-LIR:           scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
+// CHECK-LIR:             %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK-LIR:             memref.store %[[VAL_12]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK-LIR:           }
+// CHECK-LIR:           scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK-LIR:             %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_13]]] : memref<64xf64>
+// CHECK-LIR:             %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK-LIR:             %[[VAL_16:.*]] = addi %[[VAL_13]], %[[VAL_6]] : index
+// CHECK-LIR:             %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK-LIR:             scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_6]] {
+// CHECK-LIR:               %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK-LIR:               %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK-LIR:               %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK-LIR:               %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_14]] : f64
+// CHECK-LIR:               %[[VAL_23:.*]] = addf %[[VAL_20]], %[[VAL_22]] : f64
+// CHECK-LIR:               memref.store %[[VAL_23]], %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK-LIR:             }
+// CHECK-LIR:           }
+// CHECK-LIR:           return %[[VAL_10]] : memref<32xf64>
+// CHECK-LIR:         }
+
+func @matvec(%arga: tensor<32x64xf64, #CSC>,
+             %argb: tensor<64xf64>,
+             %argx: tensor<32xf64>) -> tensor<32xf64> {
+  %0 = linalg.generic #trait_matvec
+      ins(%arga, %argb : tensor<32x64xf64, #CSC>, tensor<64xf64>)
+      outs(%argx: tensor<32xf64>) {
+    ^bb(%A: f64, %b: f64, %x: f64):
+      %0 = mulf %A, %b : f64
+      %1 = addf %x, %0 : f64
+      linalg.yield %1 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}

diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
index 810cef0c24787..537cb6e10b63b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir
@@ -21,22 +21,22 @@
 }
 
 // CHECK-HIR-LABEL:   func @matvec(
-// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-HIR-SAME:                 %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-HIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-HIR-SAME:                 %[[VAL_2:.*]]: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> {
-// CHECK-HIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-HIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-HIR:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK-HIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-HIR:           %[[VAL_5:.*]] = constant 1 : index
-// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
-// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK-HIR:           %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK-HIR:           %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xindex>
+// CHECK-HIR:           %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf64>
 // CHECK-HIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
-// CHECK-HIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
+// CHECK-HIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
 // CHECK-HIR:           scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-HIR:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
 // CHECK-HIR:             %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index
 // CHECK-HIR:             %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
-// CHECK-HIR:             %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64>
+// CHECK-HIR:             %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK-HIR:             %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
 // CHECK-HIR:               %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
 // CHECK-HIR:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xf64>
@@ -45,29 +45,29 @@
 // CHECK-HIR:               %[[VAL_23:.*]] = addf %[[VAL_18]], %[[VAL_22]] : f64
 // CHECK-HIR:               scf.yield %[[VAL_23]] : f64
 // CHECK-HIR:             }
-// CHECK-HIR:             memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64>
+// CHECK-HIR:             memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK-HIR:           }
-// CHECK-HIR:           %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<64xf64>
-// CHECK-HIR:           return %[[VAL_25]] : tensor<64xf64>
+// CHECK-HIR:           %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf64>
+// CHECK-HIR:           return %[[VAL_25]] : tensor<32xf64>
 // CHECK-HIR:         }
 
 // CHECK-MIR-LABEL:   func @matvec(
 // CHECK-MIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
 // CHECK-MIR-SAME:                 %[[VAL_1:.*]]: tensor<64xf64>,
-// CHECK-MIR-SAME:                 %[[VAL_2:.*]]: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> {
-// CHECK-MIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-MIR-SAME:                 %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK-MIR:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK-MIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-MIR:           %[[VAL_5:.*]] = constant 1 : index
 // CHECK-MIR:           %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-MIR:           %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
 // CHECK-MIR:           %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr<i8>) -> memref<?xf64>
 // CHECK-MIR:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64>
-// CHECK-MIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64>
+// CHECK-MIR:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
 // CHECK-MIR:           scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
 // CHECK-MIR:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
 // CHECK-MIR:             %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index
 // CHECK-MIR:             %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
-// CHECK-MIR:             %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64>
+// CHECK-MIR:             %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK-MIR:             %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
 // CHECK-MIR:               %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
 // CHECK-MIR:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xf64>
@@ -76,17 +76,17 @@
 // CHECK-MIR:               %[[VAL_23:.*]] = addf %[[VAL_18]], %[[VAL_22]] : f64
 // CHECK-MIR:               scf.yield %[[VAL_23]] : f64
 // CHECK-MIR:             }
-// CHECK-MIR:             memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64>
+// CHECK-MIR:             memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK-MIR:           }
-// CHECK-MIR:           %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<64xf64>
-// CHECK-MIR:           return %[[VAL_25]] : tensor<64xf64>
+// CHECK-MIR:           %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf64>
+// CHECK-MIR:           return %[[VAL_25]] : tensor<32xf64>
 // CHECK-MIR:         }
 
 // CHECK-LIR-LABEL:   func @matvec(
 // CHECK-LIR-SAME:                 %[[VAL_0:.*]]: !llvm.ptr<i8>,
 // CHECK-LIR-SAME:                 %[[VAL_1:.*]]: memref<64xf64>,
-// CHECK-LIR-SAME:                 %[[VAL_2:.*]]: memref<64xf64> {linalg.inplaceable = true}) -> memref<64xf64> {
-// CHECK-LIR:           %[[VAL_3:.*]] = constant 64 : index
+// CHECK-LIR-SAME:                 %[[VAL_2:.*]]: memref<32xf64> {linalg.inplaceable = true}) -> memref<32xf64> {
+// CHECK-LIR:           %[[VAL_3:.*]] = constant 32 : index
 // CHECK-LIR:           %[[VAL_4:.*]] = constant 0 : index
 // CHECK-LIR:           %[[VAL_5:.*]] = constant 1 : index
 // CHECK-LIR:           %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr<i8>, index) -> memref<?xindex>
@@ -96,7 +96,7 @@
 // CHECK-LIR:             %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<?xindex>
 // CHECK-LIR:             %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_5]] : index
 // CHECK-LIR:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK-LIR:             %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64>
+// CHECK-LIR:             %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64>
 // CHECK-LIR:             %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f64) {
 // CHECK-LIR:               %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
 // CHECK-LIR:               %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf64>
@@ -105,21 +105,21 @@
 // CHECK-LIR:               %[[VAL_21:.*]] = addf %[[VAL_16]], %[[VAL_20]] : f64
 // CHECK-LIR:               scf.yield %[[VAL_21]] : f64
 // CHECK-LIR:             }
-// CHECK-LIR:             memref.store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64>
+// CHECK-LIR:             memref.store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64>
 // CHECK-LIR:           }
-// CHECK-LIR:           return %[[VAL_2]] : memref<64xf64>
+// CHECK-LIR:           return %[[VAL_2]] : memref<32xf64>
 // CHECK-LIR:         }
 
-func @matvec(%arga: tensor<64x64xf64, #CSR>,
+func @matvec(%arga: tensor<32x64xf64, #CSR>,
              %argb: tensor<64xf64>,
-	     %argx: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> {
+	     %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
   %0 = linalg.generic #trait_matvec
-      ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>)
-      outs(%argx: tensor<64xf64>) {
+      ins(%arga, %argb : tensor<32x64xf64, #CSR>, tensor<64xf64>)
+      outs(%argx: tensor<32xf64>) {
     ^bb(%A: f64, %b: f64, %x: f64):
       %0 = mulf %A, %b : f64
       %1 = addf %x, %0 : f64
       linalg.yield %1 : f64
-  } -> tensor<64xf64>
-  return %0 : tensor<64xf64>
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
 }

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir
new file mode 100644
index 0000000000000..2b52170aa8b32
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt %s \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize  \
+// RUN:   --convert-vector-to-llvm --convert-std-to-llvm | \
+// RUN: TENSOR0="%mlir_integration_test_dir/data/test.tns" \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+!Filename = type !llvm.ptr<i8>
+
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed", "compressed", "compressed",
+                   "compressed", "compressed", "compressed", "compressed" ],
+  // Note that any dimOrdering permutation should give the same results
+  // since, even though it impacts the sparse storage scheme layout,
+  // it should not change the semantics.
+  dimOrdering = affine_map<(i,j,k,l,m,n,o,p) -> (p,o,j,k,i,l,m,n)>
+}>
+
+#trait_flatten = {
+  indexing_maps = [
+    affine_map<(i,j,k,l,m,n,o,p) -> (i,j,k,l,m,n,o,p)>, // A
+    affine_map<(i,j,k,l,m,n,o,p) -> (i,j)>              // X (out)
+  ],
+  iterator_types = [ "parallel",  "parallel",  "reduction", "reduction",
+                     "reduction", "reduction", "reduction", "reduction" ],
+  doc = "X(i,j) += A(i,j,k,l,m,n,o,p)"
+}
+
+//
+// Integration test that lowers a kernel annotated as sparse to
+// actual sparse code, initializes a matching sparse storage scheme
+// from file, and runs the resulting code with the JIT compiler.
+//
+module {
+  //
+  // A kernel that flattens a rank 8 tensor into a dense matrix.
+  //
+  func @kernel_flatten(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>,
+                       %argx: tensor<7x3xf64>) -> tensor<7x3xf64> {
+    %0 = linalg.generic #trait_flatten
+      ins(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>)
+      outs(%argx: tensor<7x3xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = addf %x, %a : f64
+        linalg.yield %0 : f64
+    } -> tensor<7x3xf64>
+    return %0 : tensor<7x3xf64>
+  }
+
+  func private @getTensorFilename(index) -> (!Filename)
+
+  //
+  // Main driver that reads tensor from file and calls the sparse kernel.
+  //
+  func @entry() {
+    %d0 = constant 0.0 : f64
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %c3 = constant 3 : index
+    %c7 = constant 7 : index
+
+    // Setup matrix memory that is initialized to zero.
+    %xdata = memref.alloc() : memref<7x3xf64>
+    scf.for %i = %c0 to %c7 step %c1 {
+      scf.for %j = %c0 to %c3 step %c1 {
+        memref.store %d0, %xdata[%i, %j] : memref<7x3xf64>
+      }
+    }
+    %x = memref.tensor_load %xdata : memref<7x3xf64>
+
+    // Read the sparse tensor from file, construct sparse storage.
+    %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
+    %a = sparse_tensor.new %fileName : !llvm.ptr<i8> to tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>
+
+    // Call the kernel.
+    %0 = call @kernel_flatten(%a, %x)
+      : (tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>, tensor<7x3xf64>) -> tensor<7x3xf64>
+
+    // Print the result for verification.
+    //
+    // CHECK: ( 6.25, 0, 0 )
+    // CHECK: ( 4.224, 6.21, 0 )
+    // CHECK: ( 0, 0, 15.455 )
+    // CHECK: ( 0, 0, 0 )
+    // CHECK: ( 0, 0, 0 )
+    // CHECK: ( 0, 0, 0 )
+    // CHECK: ( 7, 0, 0 )
+    //
+    %r = memref.buffer_cast %0 : memref<7x3xf64>
+    scf.for %i = %c0 to %c7 step %c1 {
+      %v = vector.transfer_read %r[%i, %c0], %d0: memref<7x3xf64>, vector<3xf64>
+      vector.print %v : vector<3xf64>
+    }
+
+    // Release the resources.
+    memref.dealloc %xdata : memref<7x3xf64>
+
+    return
+  }
+}
