[Mlir-commits] [mlir] 4f2ec7f - [mlir][sparse] finalize sparse output in the presence of reductions

Aart Bik llvmlistbot at llvm.org
Tue Dec 7 10:54:36 PST 2021


Author: Aart Bik
Date: 2021-12-07T10:54:29-08:00
New Revision: 4f2ec7f983b40c5796beff96b0bc846a9dacac25

URL: https://github.com/llvm/llvm-project/commit/4f2ec7f983b40c5796beff96b0bc846a9dacac25
DIFF: https://github.com/llvm/llvm-project/commit/4f2ec7f983b40c5796beff96b0bc846a9dacac25.diff

LOG: [mlir][sparse] finalize sparse output in the presence of reductions

This revision implements sparse outputs (from scratch) in all cases where
the loops can be reordered so that all but one of the parallel loops are
outermost. If the inner parallel loop appears inside one or more reduction
loops, then an access pattern expansion is required (a.k.a. workspaces in
TACO speak).

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D115091
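
The access pattern expansion described in the log can be sketched in plain
C++ for a sparse matrix-matrix product in CSR form. This is an illustrative
sketch only, not the code generated by this revision: the `Csr` struct and
the `spgemmWorkspace` function are hypothetical names introduced here. The
`values`, `filled`, `added`, and `count` variables mirror the results of the
new `expand` operation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical minimal CSR container, for illustration only.
struct Csr {
  uint64_t rows = 0, cols = 0;
  std::vector<uint64_t> ptr{0}; // row pointers, size rows + 1
  std::vector<uint64_t> idx;    // column indices
  std::vector<double> val;      // nonzero values
};

// Computes C = A * B with i (parallel) outermost, k (reduction) in the
// middle, and j (parallel) innermost, using one dense workspace per row.
Csr spgemmWorkspace(const Csr &a, const Csr &b) {
  assert(a.cols == b.rows);
  Csr c;
  c.rows = a.rows;
  c.cols = b.cols;
  // "expand": dense initialization happens once, outside the loop nest.
  std::vector<double> values(b.cols, 0.0);
  std::vector<bool> filled(b.cols, false);
  std::vector<uint64_t> added(b.cols);
  for (uint64_t i = 0; i < a.rows; ++i) {
    uint64_t count = 0;
    for (uint64_t p = a.ptr[i]; p < a.ptr[i + 1]; ++p) {
      uint64_t k = a.idx[p];
      for (uint64_t q = b.ptr[k]; q < b.ptr[k + 1]; ++q) {
        uint64_t j = b.idx[q];
        if (!filled[j]) { // first touch of column j in this row
          filled[j] = true;
          added[count++] = j;
        }
        values[j] += a.val[p] * b.val[q];
      }
    }
    // "compress": sort the touched columns, append them to the sparse
    // output, and reset only the touched workspace entries.
    std::sort(added.begin(), added.begin() + count);
    for (uint64_t n = 0; n < count; ++n) {
      uint64_t j = added[n];
      c.idx.push_back(j);
      c.val.push_back(values[j]);
      values[j] = 0.0;
      filled[j] = false;
    }
    c.ptr.push_back(c.idx.size());
  }
  return c;
}
```

The dense arrays are set up once; within the loop nest both the updates and
the reset touch only the entries recorded in `added`, so the work per row
stays proportional to the number of nonzeros produced (plus the sort).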

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4bdfa099953e8..28d54298faff6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -219,6 +219,81 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
                        " type($tensor) `,` type($indices) `,` type($value)";
 }
 
+def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
+    Arguments<(ins AnyTensor:$tensor)>,
+    Results<(outs AnyStridedMemRefOfRank<1>:$values,
+                  StridedMemRefRankOf<[I1],[1]>:$filled,
+                  StridedMemRefRankOf<[Index],[1]>:$added,
+                  Index:$count)> {
+  string summary = "Expands an access pattern for insertion";
+  string description = [{
+    Performs an access pattern expansion for the innermost dimensions of the
+    given tensor. This operation is useful to implement kernels in which a
+    sparse tensor appears as output. This technique is known under several
+    different names and using several alternative implementations,
+    for example, phase counter [Gustavson71], expanded or switch array
+    [Pissanetzky84], in phase scan [Duff90], access pattern expansion [Bik96],
+    and workspaces [Kjolstad2018].
+
+    The values and filled array have a size that suffices for a *dense* innermost
+    dimension (e.g. a full row for matrices). The added array and count are used
+    to store new indices when a false value is encountered in the filled array.
+    All arrays should be allocated before the loop (possibly even shared between
+    loops in a future optimization) so that their *dense* initialization can be
+    amortized over many iterations. Setting and resetting the dense arrays in
+    the loop nest itself is kept *sparse* by only iterating over set elements
+    through an indirection using the added array, so that the operations are
+    kept proportional to the number of nonzeros.
+
+    Note that this operation is "impure" in the sense that its behavior
+    is solely defined by side-effects and not SSA values. The semantics
+    may be refined over time as our sparse abstractions evolve.
+
+    Example:
+
+    ```mlir
+    %values, %filled, %added, %count = sparse_tensor.expand %0
+      : tensor<4x4xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+    ```
+  }];
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
+                       " `,` type($filled) `,` type($added) `,` type($count)";
+}
+
+def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
+    Arguments<(ins AnyTensor:$tensor,
+                   StridedMemRefRankOf<[Index],[1]>:$indices,
+                   AnyStridedMemRefOfRank<1>:$values,
+                   StridedMemRefRankOf<[I1],[1]>:$filled,
+                   StridedMemRefRankOf<[Index],[1]>:$added,
+                   Index:$count)> {
+  string summary = "Compresses an access pattern for insertion";
+  string description = [{
+    Finishes a single access pattern by moving the inserted elements
+    into the sparse storage scheme. The values and filled array are reset
+    in a *sparse* fashion by only iterating over set elements through an
+    indirection using the added array, so that the operations are kept
+    proportional to the number of nonzeros. See the 'expand' operation
+    for more details.
+
+    Note that this operation is "impure" in the sense that its behavior
+    is solely defined by side-effects and not SSA values. The semantics
+    may be refined over time as our sparse abstractions evolve.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.compress %0, %1, %values, %filled, %added, %2
+        : tensor<4x4xf64, #CSR>, memref<?xindex>, memref<?xf64>,
+	  memref<?xi1>, memref<?xindex>, index
+    ```
+  }];
+  let assemblyFormat = "$tensor `,` $indices `,` $values `,` $filled `,`"
+                        " $added `,` $count attr-dict `:` type($tensor) `,`"
+			" type($indices) `,` type($values) `,` type($filled) `,`"
+			" type($added) `,` type($count)";
+}
+
 def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
     Arguments<(ins AnyTensor:$tensor, UnitAttr:$hasInserts)>,
     Results<(outs AnyTensor:$result)> {

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9bc1e824760a2..4858ffd657eb8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -305,6 +305,18 @@ static LogicalResult verify(LexInsertOp op) {
   return success();
 }
 
+static LogicalResult verify(ExpandOp op) {
+  if (!getSparseTensorEncoding(op.tensor().getType()))
+    return op.emitError("expected a sparse tensor for expansion");
+  return success();
+}
+
+static LogicalResult verify(CompressOp op) {
+  if (!getSparseTensorEncoding(op.tensor().getType()))
+    return op.emitError("expected a sparse tensor for compression");
+  return success();
+}
+
 static LogicalResult verify(LoadOp op) {
   if (!getSparseTensorEncoding(op.tensor().getType()))
     return op.emitError("expected a sparse tensor to materialize");

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 20dda09ab4c51..c694371cc982e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -257,10 +257,17 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
 /// type, but returns it as type `memref<? x $tp>` (rather than as type
 /// `memref<$sz x $tp>`).
 static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
-                       unsigned sz, Type tp) {
+                       Value sz, Type tp) {
   auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
-  Value a = constantIndex(rewriter, loc, sz);
-  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
+  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
+}
+
+/// Generates an uninitialized temporary buffer of the given size and
+/// type, but returns it as type `memref<? x $tp>` (rather than as type
+/// `memref<$sz x $tp>`).
+static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
+                       unsigned sz, Type tp) {
+  return genAlloca(rewriter, loc, constantIndex(rewriter, loc, sz), tp);
 }
 
 /// Generates an uninitialized temporary buffer with room for one value
@@ -911,6 +918,78 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
   }
 };
 
+class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ShapedType srcType = op.tensor().getType().cast<ShapedType>();
+    Type eltType = srcType.getElementType();
+    Type boolType = rewriter.getIntegerType(1);
+    Type idxType = rewriter.getIndexType();
+    // All initialization should be done on entry of the loop nest.
+    rewriter.setInsertionPointAfter(op.tensor().getDefiningOp());
+    // Determine the size for access expansion.
+    auto enc = getSparseTensorEncoding(srcType);
+    Value src = adaptor.getOperands()[0];
+    Value sz = genDimSizeCall(rewriter, op, enc, src, srcType.getRank() - 1);
+    // Allocate temporary stack buffers for values, filled-switch, and indices.
+    Value values = genAlloca(rewriter, loc, sz, eltType);
+    Value filled = genAlloca(rewriter, loc, sz, boolType);
+    Value indices = genAlloca(rewriter, loc, sz, idxType);
+    Value zero = constantZero(rewriter, loc, idxType);
+    // Reset the values/filled-switch to all-zero/false. Note that this
+    // introduces an O(N) operation into the computation, but this reset
+    // operation is amortized over the innermost loops for the access
+    // pattern expansion.
+    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, eltType),
+                                    values);
+    rewriter.create<linalg::FillOp>(loc, constantZero(rewriter, loc, boolType),
+                                    filled);
+    // Replace expansion op with these buffers and initial index.
+    assert(op.getNumResults() == 4);
+    rewriter.replaceOp(op, {values, filled, indices, zero});
+    return success();
+  }
+};
+
+class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Note that this method call resets the values/filled-switch back to
+    // all-zero/false by only iterating over the set elements, so the
+    // complexity remains proportional to the sparsity of the expanded
+    // access pattern.
+    Type srcType = op.tensor().getType();
+    Type eltType = srcType.cast<ShapedType>().getElementType();
+    StringRef name;
+    if (eltType.isF64())
+      name = "expInsertF64";
+    else if (eltType.isF32())
+      name = "expInsertF32";
+    else if (eltType.isInteger(64))
+      name = "expInsertI64";
+    else if (eltType.isInteger(32))
+      name = "expInsertI32";
+    else if (eltType.isInteger(16))
+      name = "expInsertI16";
+    else if (eltType.isInteger(8))
+      name = "expInsertI8";
+    else
+      return failure();
+    TypeRange noTp;
+    auto fn =
+        getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
+    rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -926,6 +1005,7 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                SparseTensorInitConverter, SparseTensorConvertConverter,
                SparseTensorReleaseConverter, SparseTensorToPointersConverter,
                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorLoadConverter, SparseTensorLexInsertConverter>(
+               SparseTensorLoadConverter, SparseTensorLexInsertConverter,
+               SparseTensorExpandConverter, SparseTensorCompressConverter>(
       typeConverter, patterns.getContext());
 }

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6a5804172c74c..3255c4ed0ede0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -54,7 +54,8 @@ struct CodeGen {
         pidxs(numTensors, std::vector<Value>(numLoops)),
         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
         redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(),
-        curVecLength(1), curVecMask() {}
+        expValues(), expFilled(), expAdded(), expCount(), curVecLength(1),
+        curVecMask() {}
   /// Sparsification options.
   SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
@@ -81,10 +82,15 @@ struct CodeGen {
   Reduction redKind;
   // Sparse tensor as output. Implemented either through direct injective
   // insertion in lexicographic index order (where indices are updated
-  // in the temporary array `lexIdx`) or TODO: access pattern expansion
+  // in the temporary array `lexIdx`) or through access pattern expansion
+  // in the innermost loop nest (`expValues` through `expCount`).
   OpOperand *sparseOut;
   unsigned outerParNest;
   Value lexIdx;
+  Value expValues;
+  Value expFilled;
+  Value expAdded;
+  Value expCount;
   // Current vector length and mask.
   unsigned curVecLength;
   Value curVecMask;
@@ -334,8 +340,8 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
     }
     // Determine admissable dynamic insertion situations:
     // (1) fully injective, since there are no reductions,
-    // (2) admissable 1-d expansion in innermost dimension. TODO: accept
-    if (nest == op.getRank(lhs)) {
+    // (2) admissable 1-d expansion in innermost dimension.
+    if (nest >= op.getRank(lhs) - 1) {
       *sparseOut = lhs;
       outerParNest = nest;
       return true;
@@ -680,6 +686,16 @@ static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
   }
 }
 
+/// Generates index for load/store on sparse tensor.
+static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
+  auto map = op.getTiedIndexingMap(t);
+  auto enc = getSparseTensorEncoding(t->get().getType());
+  AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1));
+  assert(a.getKind() == AffineExprKind::DimId);
+  unsigned idx = a.cast<AffineDimExpr>().getPosition();
+  return codegen.loops[idx];
+}
+
 /// Generates subscript for load/store on a dense or sparse tensor.
 static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                           linalg::GenericOp op, OpOperand *t,
@@ -705,6 +721,62 @@ static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
   return codegen.buffers[tensor];
 }
 
+/// Generates insertion code to implement dynamic tensor load.
+static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter,
+                              linalg::GenericOp op, OpOperand *t) {
+  Location loc = op.getLoc();
+  // Direct lexicographic index order, tensor loads as zero.
+  if (!codegen.expValues) {
+    Type tp = getElementTypeOrSelf(t->get().getType());
+    return rewriter.create<arith::ConstantOp>(loc, tp,
+                                              rewriter.getZeroAttr(tp));
+  }
+  // Load from expanded access pattern.
+  Value index = genIndex(codegen, op, t);
+  return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index);
+}
+
+/// Generates insertion code to implement dynamic tensor store.
+static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter,
+                              linalg::GenericOp op, OpOperand *t, Value rhs) {
+  Location loc = op.getLoc();
+  // Direct insertion in lexicographic index order.
+  if (!codegen.expValues) {
+    rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
+    return;
+  }
+  // Generates insertion code along expanded access pattern.
+  //   if (!expFilled[i]) then
+  //     expFilled[i] = true
+  //     expAdded[inserts++] = i
+  //   endif
+  //   values[i] = rhs
+  Value index = genIndex(codegen, op, t);
+  Value fval = rewriter.create<arith::ConstantIntOp>(loc, 0, 1); // false
+  Value tval = rewriter.create<arith::ConstantIntOp>(loc, 1, 1); // true
+  // If statement.
+  Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index);
+  Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                              filled, fval);
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(),
+                                              cond, /*else=*/true);
+  // True branch.
+  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
+  rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
+  rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded,
+                                   codegen.expCount);
+  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one);
+  rewriter.create<scf::YieldOp>(loc, add);
+  // False branch.
+  rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
+  rewriter.create<scf::YieldOp>(loc, codegen.expCount);
+  rewriter.setInsertionPointAfter(ifOp);
+  // Value assignment.
+  codegen.expCount = ifOp.getResult(0);
+  rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
+}
+
 /// Generates a load on a dense or sparse tensor.
 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
@@ -716,13 +788,10 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
       return genVectorInvariantValue(codegen, rewriter, val);
     return val;
   }
-  // Insertion (a sparse tensor output "loads" as zero).
+  // Load during insertion.
   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
-  if (t == codegen.sparseOut) {
-    Type tp = getElementTypeOrSelf(t->get().getType());
-    return rewriter.create<arith::ConstantOp>(op.getLoc(), tp,
-                                              rewriter.getZeroAttr(tp));
-  }
+  if (t == codegen.sparseOut)
+    return genInsertionLoad(codegen, rewriter, op, t);
   // Actual load.
   SmallVector<Value, 4> args;
   Value ptr = genSubscript(codegen, rewriter, op, t, args);
@@ -744,10 +813,10 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
     updateReduc(merger, codegen, rhs);
     return;
   }
-  // Insertion.
+  // Store during insertion.
   OpOperand *t = op.getOutputOperand(0);
   if (t == codegen.sparseOut) {
-    rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
+    genInsertionStore(codegen, rewriter, op, t, rhs);
     return;
   }
   // Actual store.
@@ -916,6 +985,42 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
   }
 }
 
+/// Generates an expanded access pattern in innermost dimension.
+static void genExpansion(Merger &merger, CodeGen &codegen,
+                         PatternRewriter &rewriter, linalg::GenericOp op,
+                         unsigned at, bool atStart) {
+  OpOperand *lhs = codegen.sparseOut;
+  if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
+      at != codegen.outerParNest)
+    return; // not needed at this level
+  // Generate start or end of an expanded access pattern.
+  Value tensor = lhs->get();
+  Location loc = op.getLoc();
+  if (atStart) {
+    auto dynShape = {ShapedType::kDynamicSize};
+    Type etp = tensor.getType().cast<ShapedType>().getElementType();
+    Type t1 = MemRefType::get(dynShape, etp);
+    Type t2 = MemRefType::get(dynShape, genIntType(rewriter, 1));
+    Type t3 = MemRefType::get(dynShape, genIntType(rewriter, 0));
+    Type t4 = rewriter.getIndexType();
+    auto res =
+        rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
+    assert(res.getNumResults() == 4);
+    assert(!codegen.expValues);
+    codegen.expValues = res.getResult(0);
+    codegen.expFilled = res.getResult(1);
+    codegen.expAdded = res.getResult(2);
+    codegen.expCount = res.getResult(3);
+  } else {
+    assert(codegen.expValues);
+    rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues,
+                                codegen.expFilled, codegen.expAdded,
+                                codegen.expCount);
+    codegen.expValues = codegen.expFilled = codegen.expAdded =
+        codegen.expCount = Value();
+  }
+}
+
 /// Generates initialization code for the subsequent loop sequence at
 /// current index level. Returns true if the loop sequence needs to
 /// maintain the universal index.
@@ -1069,9 +1174,13 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
     }
     operands.push_back(codegen.redVal);
   }
+  if (codegen.expValues)
+    operands.push_back(codegen.expCount);
   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
   if (codegen.redVal)
     updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
+  if (codegen.expValues)
+    codegen.expCount = forOp.getRegionIterArgs().back();
   // Assign induction variable to sparse or dense index.
   Value iv = forOp.getInductionVar();
   if (isSparse)
@@ -1106,6 +1215,10 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
     types.push_back(codegen.redVal.getType());
     operands.push_back(codegen.redVal);
   }
+  if (codegen.expValues) {
+    types.push_back(indexType);
+    operands.push_back(codegen.expCount);
+  }
   if (needsUniv) {
     types.push_back(indexType);
     operands.push_back(codegen.loops[idx]);
@@ -1135,6 +1248,8 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
   }
   if (codegen.redVal)
     updateReduc(merger, codegen, after->getArgument(o++));
+  if (codegen.expValues)
+    codegen.expCount = after->getArgument(o++);
   if (needsUniv)
     codegen.loops[idx] = after->getArgument(o++);
   assert(o == operands.size());
@@ -1215,8 +1330,9 @@ static void genLocals(Merger &merger, CodeGen &codegen,
     }
   }
 
-  // Move the insertion indices in lexicographic index order.
-  if (codegen.sparseOut) {
+  // Move the insertion indices in lexicographic index order. During access
+  // pattern expansion, we can skip setting the innermost dimension.
+  if (codegen.sparseOut && !codegen.expValues) {
     Value pos = rewriter.create<arith::ConstantIndexOp>(loc, at);
     rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
                                      pos);
@@ -1231,11 +1347,21 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
                               scf::WhileOp whileOp) {
   Location loc = op.getLoc();
   // Finalize each else branch of all if statements.
-  if (codegen.redVal) {
+  if (codegen.redVal || codegen.expValues) {
     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
                rewriter.getInsertionBlock()->getParentOp())) {
-      rewriter.create<scf::YieldOp>(loc, codegen.redVal);
-      updateReduc(merger, codegen, ifOp.getResult(0));
+      unsigned y = 0;
+      SmallVector<Value, 4> yields;
+      if (codegen.redVal) {
+        yields.push_back(codegen.redVal);
+        updateReduc(merger, codegen, ifOp.getResult(y++));
+      }
+      if (codegen.expValues) {
+        yields.push_back(codegen.expCount);
+        codegen.expCount = ifOp->getResult(y++);
+      }
+      assert(y == yields.size());
+      rewriter.create<scf::YieldOp>(loc, yields);
       rewriter.setInsertionPointAfter(ifOp);
     }
   }
@@ -1266,6 +1392,10 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
     operands.push_back(codegen.redVal);
     updateReduc(merger, codegen, whileOp->getResult(o++));
   }
+  if (codegen.expValues) {
+    operands.push_back(codegen.expCount);
+    codegen.expCount = whileOp->getResult(o++);
+  }
   if (needsUniv) {
     operands.push_back(
         rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
@@ -1287,6 +1417,10 @@ static void genForInduction(Merger &merger, CodeGen &codegen,
     operands.push_back(codegen.redVal);
     updateReduc(merger, codegen, loop->getResult(o++));
   }
+  if (codegen.expValues) {
+    operands.push_back(codegen.expCount);
+    codegen.expCount = loop->getResult(o++);
+  }
   assert(o == operands.size());
   if (o > 0)
     rewriter.create<scf::YieldOp>(loc, operands);
@@ -1318,6 +1452,8 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
   }
   if (codegen.redVal)
     types.push_back(codegen.redVal.getType());
+  if (codegen.expValues)
+    types.push_back(rewriter.getIndexType());
   scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   return ifOp;
@@ -1325,11 +1461,19 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
 
 /// Generates end of true branch of if-statement within a while-loop.
 static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
-                  linalg::GenericOp op, scf::IfOp ifOp, Value ifInput) {
+                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
+                  Value redInput, Value cntInput) {
+  SmallVector<Value, 4> operands;
   if (codegen.redVal) {
-    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
-    updateReduc(merger, codegen, ifInput);
+    operands.push_back(codegen.redVal);
+    updateReduc(merger, codegen, redInput);
+  }
+  if (codegen.expValues) {
+    operands.push_back(codegen.expCount);
+    codegen.expCount = cntInput;
   }
+  if (!operands.empty())
+    rewriter.create<scf::YieldOp>(op.getLoc(), operands);
   rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
 }
 
@@ -1348,6 +1492,8 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen,
   assert(!codegen.loops[idx]);
   // Emit invariants at this loop sequence level.
   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
+  // Emit access pattern expansion for sparse tensor output.
+  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
   // Emit further initialization at this loop sequence level.
   unsigned l0 = merger.set(lts)[0];
   bool needsUniv =
@@ -1399,7 +1545,7 @@ static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 /// Ends a loop sequence at given level.
 static void endLoopSeq(Merger &merger, CodeGen &codegen,
                        PatternRewriter &rewriter, linalg::GenericOp op,
-                       unsigned exp, unsigned idx, unsigned ldx) {
+                       unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
   assert(codegen.curVecLength == 1);
   codegen.loops[idx] = Value();
   // Bring a pending reduction back from SIMD form when sequence ends.
@@ -1409,6 +1555,8 @@ static void endLoopSeq(Merger &merger, CodeGen &codegen,
                   genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
   // Unmark bookkeeping of invariants and loop index.
   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
+  // Finalize access pattern expansion for sparse tensor output.
+  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -1443,7 +1591,8 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
-    Value ifInput = codegen.redVal;
+    Value redInput = codegen.redVal;
+    Value cntInput = codegen.expCount;
     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
     for (unsigned j = 0; j < lsize; j++) {
       unsigned lj = merger.set(lts)[j];
@@ -1454,7 +1603,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
           scf::IfOp ifOp =
               genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
-          endIf(merger, codegen, rewriter, op, ifOp, ifInput);
+          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
         } else {
           genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
         }
@@ -1467,7 +1616,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   }
 
   // End a loop sequence.
-  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
+  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
 }
 
 /// Converts the result computed by the sparse kernel into the required form.

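The Sparsification.cpp changes above thread the insertion count through
every `scf.if`, `scf.for`, and `scf.while` as an extra yielded value, since
SSA form has no mutable counter. A rough C++ analogue of the store emitted
by `genInsertionStore` is sketched below; `Buffers` and `insertionStore`
are hypothetical names used only for illustration, and the returned value
plays the role of the `scf.if` result that is assigned back to `expCount`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the buffers produced by the expansion.
struct Buffers {
  std::vector<double> values;
  std::vector<bool> filled;
  std::vector<uint64_t> added;
};

// Mirrors the shape of the generated store: the "if" yields either
// count + 1 (first touch of index i) or the unchanged count, and the
// caller rebinds its count to the returned value.
uint64_t insertionStore(Buffers &buf, uint64_t count, uint64_t i, double rhs) {
  assert(i < buf.values.size());
  uint64_t next = count;
  if (!buf.filled[i]) {
    buf.filled[i] = true;
    buf.added[count] = i;
    next = count + 1;
  }
  // rhs is expected to already include any previously loaded value.
  buf.values[i] = rhs;
  return next;
}
```

A caller inside a loop would write `count = insertionStore(buf, count, i, rhs);`,
which corresponds to carrying the count as a loop iteration argument.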
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 045bca25c3a5f..193e305e7b4d4 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -163,10 +163,10 @@ struct SparseTensorCOO {
 /// function overloading to implement "partial" method specialization.
 class SparseTensorStorageBase {
 public:
-  // Dimension size query.
+  /// Dimension size query.
   virtual uint64_t getDimSize(uint64_t) = 0;
 
-  // Overhead storage.
+  /// Overhead storage.
   virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
   virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
   virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
@@ -176,7 +176,7 @@ class SparseTensorStorageBase {
   virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
   virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
 
-  // Primary storage.
+  /// Primary storage.
   virtual void getValues(std::vector<double> **) { fatal("valf64"); }
   virtual void getValues(std::vector<float> **) { fatal("valf32"); }
   virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
@@ -184,13 +184,35 @@ class SparseTensorStorageBase {
   virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
   virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }
 
-  // Element-wise insertion in lexicographic index order.
+  /// Element-wise insertion in lexicographic index order.
   virtual void lexInsert(uint64_t *, double) { fatal("insf64"); }
   virtual void lexInsert(uint64_t *, float) { fatal("insf32"); }
   virtual void lexInsert(uint64_t *, int64_t) { fatal("insi64"); }
   virtual void lexInsert(uint64_t *, int32_t) { fatal("insi32"); }
   virtual void lexInsert(uint64_t *, int16_t) { fatal("ins16"); }
   virtual void lexInsert(uint64_t *, int8_t) { fatal("insi8"); }
+
+  /// Expanded insertion.
+  virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
+    fatal("expf64");
+  }
+  virtual void expInsert(uint64_t *, float *, bool *, uint64_t *, uint64_t) {
+    fatal("expf32");
+  }
+  virtual void expInsert(uint64_t *, int64_t *, bool *, uint64_t *, uint64_t) {
+    fatal("expi64");
+  }
+  virtual void expInsert(uint64_t *, int32_t *, bool *, uint64_t *, uint64_t) {
+    fatal("expi32");
+  }
+  virtual void expInsert(uint64_t *, int16_t *, bool *, uint64_t *, uint64_t) {
+    fatal("expi16");
+  }
+  virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
+    fatal("expi8");
+  }
+
+  /// Finishes insertion.
   virtual void endInsert() = 0;
 
   virtual ~SparseTensorStorageBase() {}
@@ -289,6 +311,35 @@ class SparseTensorStorage : public SparseTensorStorageBase {
     insPath(cursor, diff, top, val);
   }
 
+  /// Partially specialize expanded insertions based on template types.
+  /// Note that this method resets the values/filled-switch array back
+  /// to all-zero/false while only iterating over the nonzero elements.
+  void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
+                 uint64_t count) override {
+    if (count == 0)
+      return;
+    // Sort.
+    std::sort(added, added + count);
+    // Restore insertion path for first insert.
+    uint64_t rank = getRank();
+    uint64_t index = added[0];
+    cursor[rank - 1] = index;
+    lexInsert(cursor, values[index]);
+    assert(filled[index]);
+    values[index] = 0;
+    filled[index] = false;
+    // Subsequent insertions are quick.
+    for (uint64_t i = 1; i < count; i++) {
+      assert(index < added[i] && "non-lexicographic insertion");
+      index = added[i];
+      cursor[rank - 1] = index;
+      insPath(cursor, rank - 1, added[i - 1] + 1, values[index]);
+      assert(filled[index]);
+      values[index] = 0.0;
+      filled[index] = false;
+    }
+  }
+
   /// Finalizes lexicographic insertions.
   void endInsert() override {
     if (values.empty())
@@ -683,8 +734,7 @@ typedef uint64_t index_t;
 
 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
-    assert(ref);                                                               \
-    assert(tensor);                                                            \
+    assert(ref &&tensor);                                                      \
     std::vector<TYPE> *v;                                                      \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
     ref->basePtr = ref->data = v->data();                                      \
@@ -696,8 +746,7 @@ typedef uint64_t index_t;
 #define IMPL_GETOVERHEAD(NAME, TYPE, LIB)                                      \
   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                            index_t d) {                                        \
-    assert(ref);                                                               \
-    assert(tensor);                                                            \
+    assert(ref &&tensor);                                                      \
     std::vector<TYPE> *v;                                                      \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
     ref->basePtr = ref->data = v->data();                                      \
@@ -710,9 +759,7 @@ typedef uint64_t index_t;
   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
                             StridedMemRefType<index_t, 1> *iref,               \
                             StridedMemRefType<index_t, 1> *pref) {             \
-    assert(tensor);                                                            \
-    assert(iref);                                                              \
-    assert(pref);                                                              \
+    assert(tensor &&iref &&pref);                                              \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
     assert(iref->sizes[0] == pref->sizes[0]);                                  \
     const index_t *indx = iref->data + iref->offset;                           \
@@ -726,10 +773,11 @@ typedef uint64_t index_t;
   }
 
 #define IMPL_GETNEXT(NAME, V)                                                  \
-  bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<uint64_t, 1> *iref, \
+  bool _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *iref,  \
                            StridedMemRefType<V, 0> *vref) {                    \
+    assert(tensor &&iref &&vref);                                              \
     assert(iref->strides[0] == 1);                                             \
-    uint64_t *indx = iref->data + iref->offset;                                \
+    index_t *indx = iref->data + iref->offset;                                 \
     V *value = vref->data + vref->offset;                                      \
     const uint64_t isize = iref->sizes[0];                                     \
     auto iter = static_cast<SparseTensorCOO<V> *>(tensor);                     \
@@ -747,12 +795,32 @@ typedef uint64_t index_t;
 #define IMPL_LEXINSERT(NAME, V)                                                \
   void _mlir_ciface_##NAME(void *tensor, StridedMemRefType<index_t, 1> *cref,  \
                            V val) {                                            \
+    assert(tensor &&cref);                                                     \
     assert(cref->strides[0] == 1);                                             \
-    uint64_t *cursor = cref->data + cref->offset;                              \
+    index_t *cursor = cref->data + cref->offset;                               \
     assert(cursor);                                                            \
     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
   }
 
+#define IMPL_EXPINSERT(NAME, V)                                                \
+  void _mlir_ciface_##NAME(                                                    \
+      void *tensor, StridedMemRefType<index_t, 1> *cref,                       \
+      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref,         \
+      StridedMemRefType<index_t, 1> *aref, index_t count) {                    \
+    assert(tensor &&cref &&vref &&fref &&aref);                                \
+    assert(cref->strides[0] == 1);                                             \
+    assert(vref->strides[0] == 1);                                             \
+    assert(fref->strides[0] == 1);                                             \
+    assert(aref->strides[0] == 1);                                             \
+    assert(vref->sizes[0] == fref->sizes[0]);                                  \
+    index_t *cursor = cref->data + cref->offset;                               \
+    V *values = vref->data + vref->offset;                                     \
+    bool *filled = fref->data + fref->offset;                                  \
+    index_t *added = aref->data + aref->offset;                                \
+    static_cast<SparseTensorStorageBase *>(tensor)->expInsert(                 \
+        cursor, values, filled, added, count);                                 \
+  }
+
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
 ///
@@ -912,12 +980,21 @@ IMPL_LEXINSERT(lexInsertI32, int32_t)
 IMPL_LEXINSERT(lexInsertI16, int16_t)
 IMPL_LEXINSERT(lexInsertI8, int8_t)
 
+/// Helper to insert using expansion, one per value type.
+IMPL_EXPINSERT(expInsertF64, double)
+IMPL_EXPINSERT(expInsertF32, float)
+IMPL_EXPINSERT(expInsertI64, int64_t)
+IMPL_EXPINSERT(expInsertI32, int32_t)
+IMPL_EXPINSERT(expInsertI16, int16_t)
+IMPL_EXPINSERT(expInsertI8, int8_t)
+
 #undef CASE
 #undef IMPL_SPARSEVALUES
 #undef IMPL_GETOVERHEAD
 #undef IMPL_ADDELT
 #undef IMPL_GETNEXT
-#undef IMPL_INSERTLEX
+#undef IMPL_LEXINSERT
+#undef IMPL_EXPINSERT
 
 //===----------------------------------------------------------------------===//
 //

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 45f806b6f83af..89ee0d5b7c816 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -441,3 +441,30 @@ func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
   return
 }
 
+// CHECK-LABEL: func @sparse_expansion()
+//    %[[S:.*]] = call @sparseDimSize
+//    %[[V:.*]] = memref.alloca(%[[S]]) : memref<?xf64>
+//    %[[F:.*]] = memref.alloca(%[[S]]) : memref<?xi1>
+//    %[[A:.*]] = memref.alloca(%[[S]]) : memref<?xindex>
+//    linalg.fill(%{{.*}}, %[[V]]) : f64, memref<?xf64>
+//    linalg.fill(%{{.*}}, %[[F]]) : i1, memref<?xi1>
+//       CHECK: return
+func @sparse_expansion() {
+  %c = arith.constant 8 : index
+  %0 = sparse_tensor.init [%c, %c] : tensor<8x8xf64, #SparseMatrix>
+  %values, %filled, %added, %count = sparse_tensor.expand %0
+    : tensor<8x8xf64, #SparseMatrix> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return
+}
+
+// CHECK-LABEL: func @sparse_compression(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
+//       CHECK: call @expInsertF64(%[[A]],
+//       CHECK: return
+func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
+                         %arg1: memref<?xindex>, %arg2: memref<?xf64>, %arg3: memref<?xi1>,
+                         %arg4: memref<?xindex>, %arg5: index) {
+  sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
+    : tensor<8x8xf64, #SparseMatrix>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return
+}

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 55c7cc490d741..06d662127174c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -160,6 +160,25 @@ func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref<?xindex>,
 
 // -----
 
+func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
+  // expected-error @+1 {{expected a sparse tensor for expansion}}
+  %values, %filled, %added, %count = sparse_tensor.expand %arg0
+    : tensor<128xf64> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return
+}
+
+// -----
+
+func @sparse_unannotated_compression(%arg0: tensor<128xf64>, %arg1: memref<?xindex>,
+                                     %arg2: memref<?xf64>, %arg3: memref<?xi1>,
+				     %arg4: memref<?xindex>, %arg5: index) {
+  // expected-error @+1 {{expected a sparse tensor for compression}}
+  sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
+    : tensor<128xf64>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+}
+
+// -----
+
 func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
   // expected-error @+1 {{unexpected type in convert}}
   %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index ad6b90b7918d4..853befc1cdef4 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -149,3 +149,33 @@ func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: memref<?xindex
   sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64, #SparseVector>, memref<?xindex>, f64
   return
 }
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_expansion(
+//  CHECK-SAME: %[[A:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>)
+//       CHECK: sparse_tensor.expand %[[A]]
+//       CHECK: return
+func @sparse_expansion(%arg0: tensor<8x8xf64, #SparseMatrix>) {
+  %values, %filled, %added, %count = sparse_tensor.expand %arg0
+    : tensor<8x8xf64, #SparseMatrix> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_compression(
+//  CHECK-SAME: %[[A:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+//       CHECK: sparse_tensor.compress %[[A]]
+//       CHECK: return
+func @sparse_compression(%arg0: tensor<8x8xf64, #SparseMatrix>,
+                         %arg1: memref<?xindex>, %arg2: memref<?xf64>, %arg3: memref<?xi1>,
+                         %arg4: memref<?xindex>, %arg5: index) {
+  sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
+    : tensor<8x8xf64, #SparseMatrix>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
new file mode 100644
index 0000000000000..8bf99f50da5b7
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
@@ -0,0 +1,274 @@
+// RUN: mlir-opt %s \
+// RUN:   --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --linalg-bufferize --convert-linalg-to-loops \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm \
+// RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+//
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>
+}>
+
+#DCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>
+}>
+
+module {
+  //
+  // Computes C = A x B with all matrices dense.
+  //
+  func @matmul1(%A: tensor<4x8xf64>,
+                %B: tensor<8x4xf64>) -> tensor<4x4xf64> {
+    %C = arith.constant dense<0.0> : tensor<4x4xf64>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x8xf64>, tensor<8x4xf64>)
+         outs(%C: tensor<4x4xf64>) -> tensor<4x4xf64>
+    return %D: tensor<4x4xf64>
+  }
+
+  //
+  // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
+  //
+  func @matmul2(%A: tensor<4x8xf64, #CSR>,
+                %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
+    %c4 = arith.constant 4 : index
+    %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #CSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
+         outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    return %D: tensor<4x4xf64, #CSR>
+  }
+
+  //
+  // Computes C = A x B with all matrices sparse (SpMSpM) in DCSR.
+  //
+  func @matmul3(%A: tensor<4x8xf64, #DCSR>,
+                %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+    %c4 = arith.constant 4 : index
+    %C = sparse_tensor.init [%c4, %c4] : tensor<4x4xf64, #DCSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x8xf64, #DCSR>, tensor<8x4xf64, #DCSR>)
+         outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+    return %D: tensor<4x4xf64, #DCSR>
+  }
+
+  //
+  // Main driver.
+  //
+  func @entry() {
+    %c0 = arith.constant 0 : index
+    %d1 = arith.constant -1.0 : f64
+
+    // Initialize various matrices, dense for stress testing,
+    // and sparse to verify correct nonzero structure.
+    %da = arith.constant dense<[
+        [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1 ],
+        [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2 ],
+        [ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3 ],
+        [ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4 ]
+    ]> : tensor<4x8xf64>
+    %db = arith.constant dense<[
+        [ 10.1, 11.1, 12.1, 13.1 ],
+        [ 10.2, 11.2, 12.2, 13.2 ],
+        [ 10.3, 11.3, 12.3, 13.3 ],
+        [ 10.4, 11.4, 12.4, 13.4 ],
+        [ 10.5, 11.5, 12.5, 13.5 ],
+        [ 10.6, 11.6, 12.6, 13.6 ],
+        [ 10.7, 11.7, 12.7, 13.7 ],
+        [ 10.8, 11.8, 12.8, 13.8 ]
+    ]> : tensor<8x4xf64>
+    %sa = arith.constant dense<[
+        [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
+    ]> : tensor<4x8xf64>
+    %sb = arith.constant dense<[
+        [ 0.0, 0.0, 0.0, 1.0 ],
+        [ 0.0, 0.0, 2.0, 0.0 ],
+        [ 0.0, 3.0, 0.0, 0.0 ],
+        [ 4.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 5.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 6.0, 0.0 ],
+        [ 0.0, 0.0, 7.0, 8.0 ]
+    ]> : tensor<8x4xf64>
+
+    // Convert all these matrices to sparse format.
+    %a1 = sparse_tensor.convert %da : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
+    %a2 = sparse_tensor.convert %da : tensor<4x8xf64> to tensor<4x8xf64, #DCSR>
+    %a3 = sparse_tensor.convert %sa : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
+    %a4 = sparse_tensor.convert %sa : tensor<4x8xf64> to tensor<4x8xf64, #DCSR>
+    %b1 = sparse_tensor.convert %db : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
+    %b2 = sparse_tensor.convert %db : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
+    %b3 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
+    %b4 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
+
+    // Call kernels with dense.
+    %0 = call @matmul1(%da, %db)
+       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
+    %1 = call @matmul2(%a1, %b1)
+       : (tensor<4x8xf64, #CSR>,
+          tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    %2 = call @matmul3(%a2, %b2)
+       : (tensor<4x8xf64, #DCSR>,
+          tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+
+    // Call kernels with one sparse.
+    %3 = call @matmul1(%sa, %db)
+       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
+    %4 = call @matmul2(%a3, %b1)
+       : (tensor<4x8xf64, #CSR>,
+          tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    %5 = call @matmul3(%a4, %b2)
+       : (tensor<4x8xf64, #DCSR>,
+          tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+
+    // Call kernels with sparse.
+    %6 = call @matmul1(%sa, %sb)
+       : (tensor<4x8xf64>, tensor<8x4xf64>) -> tensor<4x4xf64>
+    %7 = call @matmul2(%a3, %b3)
+       : (tensor<4x8xf64, #CSR>,
+          tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    %8 = call @matmul3(%a4, %b4)
+       : (tensor<4x8xf64, #DCSR>,
+          tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+
+    //
+    // CHECK:    ( ( 388.76, 425.56, 462.36, 499.16 ),
+    // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
+    // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
+    // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
+    //
+    %m0 = bufferization.to_memref %0 : memref<4x4xf64>
+    %v0 = vector.transfer_read %m0[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v0 : vector<4x4xf64>
+
+    //
+    // CHECK:    ( ( 388.76, 425.56, 462.36, 499.16 ),
+    // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
+    // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
+    // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
+    //
+    %c1 = sparse_tensor.convert %1 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+    %m1 = bufferization.to_memref %c1 : memref<4x4xf64>
+    %v1 = vector.transfer_read %m1[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v1 : vector<4x4xf64>
+
+    //
+    // CHECK:    ( ( 388.76, 425.56, 462.36, 499.16 ),
+    // CHECK-SAME: ( 397.12, 434.72, 472.32, 509.92 ),
+    // CHECK-SAME: ( 405.48, 443.88, 482.28, 520.68 ),
+    // CHECK-SAME: ( 413.84, 453.04, 492.24, 531.44 ) )
+    //
+    %c2 = sparse_tensor.convert %2 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
+    %m2 = bufferization.to_memref %c2 : memref<4x4xf64>
+    %v2 = vector.transfer_read %m2[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v2 : vector<4x4xf64>
+
+    //
+    // CHECK:    ( ( 86.08, 94.28, 102.48, 110.68 ),
+    // CHECK-SAME: ( 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
+    // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
+    //
+    %m3 = bufferization.to_memref %3 : memref<4x4xf64>
+    %v3 = vector.transfer_read %m3[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v3 : vector<4x4xf64>
+
+    //
+    // CHECK:    ( ( 86.08, 94.28, 102.48, 110.68 ),
+    // CHECK-SAME: ( 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
+    // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
+    //
+    %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+    %m4 = bufferization.to_memref %c4 : memref<4x4xf64>
+    %v4 = vector.transfer_read %m4[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v4 : vector<4x4xf64>
+
+    //
+    // CHECK:    ( ( 86.08, 94.28, 102.48, 110.68 ),
+    // CHECK-SAME: ( 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 23.46, 25.76, 28.06, 30.36 ),
+    // CHECK-SAME: ( 10.8, 11.8, 12.8, 13.8 ) )
+    //
+    %c5 = sparse_tensor.convert %5 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
+    %m5 = bufferization.to_memref %c5 : memref<4x4xf64>
+    %v5 = vector.transfer_read %m5[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v5 : vector<4x4xf64>
+
+    //
+    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+    //
+    %m6 = bufferization.to_memref %6 : memref<4x4xf64>
+    %v6 = vector.transfer_read %m6[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v6 : vector<4x4xf64>
+
+    //
+    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+    //
+    %c7 = sparse_tensor.convert %7 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
+    %m7 = bufferization.to_memref %c7 : memref<4x4xf64>
+    %v7 = vector.transfer_read %m7[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v7 : vector<4x4xf64>
+
+    //
+    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+    //
+    %c8 = sparse_tensor.convert %8 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
+    %m8 = bufferization.to_memref %c8 : memref<4x4xf64>
+    %v8 = vector.transfer_read %m8[%c0, %c0], %d1 : memref<4x4xf64>, vector<4x4xf64>
+    vector.print %v8 : vector<4x4xf64>
+
+    //
+    // Sanity check on nonzeros.
+    //
+    // CHECK: ( 30.5, 4.2, 4.6, 7, 8, -1, -1, -1 )
+    // CHECK: ( 30.5, 4.2, 4.6, 7, 8, -1, -1, -1 )
+    //
+    %val7 = sparse_tensor.values %7 : tensor<4x4xf64, #CSR> to memref<?xf64>
+    %val8 = sparse_tensor.values %8 : tensor<4x4xf64, #DCSR> to memref<?xf64>
+    %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<8xf64>
+    %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<8xf64>
+    vector.print %nz7 : vector<8xf64>
+    vector.print %nz8 : vector<8xf64>
+
+    // Release the resources.
+    sparse_tensor.release %a1 : tensor<4x8xf64, #CSR>
+    sparse_tensor.release %a2 : tensor<4x8xf64, #DCSR>
+    sparse_tensor.release %a3 : tensor<4x8xf64, #CSR>
+    sparse_tensor.release %a4 : tensor<4x8xf64, #DCSR>
+    sparse_tensor.release %b1 : tensor<8x4xf64, #CSR>
+    sparse_tensor.release %b2 : tensor<8x4xf64, #DCSR>
+    sparse_tensor.release %b3 : tensor<8x4xf64, #CSR>
+    sparse_tensor.release %b4 : tensor<8x4xf64, #DCSR>
+    sparse_tensor.release %1 : tensor<4x4xf64, #CSR>
+    sparse_tensor.release %2 : tensor<4x4xf64, #DCSR>
+    sparse_tensor.release %4 : tensor<4x4xf64, #CSR>
+    sparse_tensor.release %5 : tensor<4x4xf64, #DCSR>
+    sparse_tensor.release %7 : tensor<4x4xf64, #CSR>
+    sparse_tensor.release %8 : tensor<4x4xf64, #DCSR>
+    memref.dealloc %m0 : memref<4x4xf64>
+    memref.dealloc %m1 : memref<4x4xf64>
+    memref.dealloc %m2 : memref<4x4xf64>
+    memref.dealloc %m3 : memref<4x4xf64>
+    memref.dealloc %m4 : memref<4x4xf64>
+    memref.dealloc %m5 : memref<4x4xf64>
+    memref.dealloc %m6 : memref<4x4xf64>
+    memref.dealloc %m7 : memref<4x4xf64>
+    memref.dealloc %m8 : memref<4x4xf64>
+
+    return
+  }
+}
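
Editorial note: the expanded insertion added to SparseTensorUtils.cpp in this revision can be illustrated outside the runtime library. The following is a minimal sketch, not code from the patch. It assumes a single output row with f64 values. The names Workspace, accumulate, and compress are hypothetical; only the roles of the values, filled, and added arrays and the sort-then-reset behavior follow expInsert in the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for the expanded access pattern of one output row.
// The three members mirror the values/filled/added arguments of expInsert;
// the count argument corresponds to added.size() here.
struct Workspace {
  std::vector<double> values;   // dense values, all zero between rows
  std::vector<bool> filled;     // dense switch array, all false between rows
  std::vector<uint64_t> added;  // column indices touched, in arrival order

  explicit Workspace(uint64_t size) : values(size, 0.0), filled(size, false) {}

  // Adds v to column j, recording j the first time it is touched.
  void accumulate(uint64_t j, double v) {
    if (!filled[j]) {
      filled[j] = true;
      added.push_back(j);
    }
    values[j] += v;
  }
};

// Turns the workspace into a sorted list of (column, value) pairs and
// resets it to all-zero/false by visiting only the touched entries,
// as the comment on expInsert in the diff describes.
std::vector<std::pair<uint64_t, double>> compress(Workspace &w) {
  std::sort(w.added.begin(), w.added.end());
  std::vector<std::pair<uint64_t, double>> row;
  for (uint64_t j : w.added) {
    assert(w.filled[j] && "added index must be marked as filled");
    row.emplace_back(j, w.values[j]);
    w.values[j] = 0.0;
    w.filled[j] = false;
  }
  w.added.clear();
  return row;
}
```

In this sketch, accumulating into columns 2 and then 1 yields a row ordered by column index after compress, and the same Workspace can be reused for the next row without a full clear. The real runtime instead forwards each sorted entry to lexInsert or insPath on the tensor storage rather than returning a vector.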


        

