[Mlir-commits] [mlir] 14da25b - [mlir][sparse] scalarize reductions in for-loops during sparse codegen

Aart Bik llvmlistbot at llvm.org
Thu Dec 17 16:12:31 PST 2020


Author: Aart Bik
Date: 2020-12-17T16:12:21-08:00
New Revision: 14da25b4b2eedf8a16aae34edfefd7bcaa5ceae5

URL: https://github.com/llvm/llvm-project/commit/14da25b4b2eedf8a16aae34edfefd7bcaa5ceae5
DIFF: https://github.com/llvm/llvm-project/commit/14da25b4b2eedf8a16aae34edfefd7bcaa5ceae5.diff

LOG: [mlir][sparse] scalarize reductions in for-loops during sparse codegen

Reductions in innermost loops become harder for the backend to disambiguate
after bufferization into memrefs, resulting in less efficient load-update-store
cycles. By scalarizing innermost reductions, the backend is more likely to assign
a register to perform the reduction (also prepares vectorization). Even though
we could scalarize reductions for more outer loops and while-loops as well,
currently scalarization is only done for chains of innermost for-loops, where
it matters most, to avoid complicating codegen unnecessarily (viz. adding lots
of yield instructions).

This CL also refactors condition simplification into the merger class,
where it belongs, so that conditions are simplified only once per loop
nest and not repeatedly as was previously done. This CL also fixes a few
minor bugs, some layout issues, and comments.

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D93143

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/test/Dialect/Linalg/sparse_1d.mlir
    mlir/test/Dialect/Linalg/sparse_2d.mlir
    mlir/test/Dialect/Linalg/sparse_3d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index cfdb371e3234..fed2eedd41a4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -52,6 +52,7 @@ using namespace mlir;
 namespace {
 
 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
+enum class Dim { kSparse, kDense, kUndef };
 
 /// Tensor expression. Represents a MLIR expression in tensor index notation.
 /// For tensors, e0 denotes the tensor index. For invariants, the IR value is
@@ -81,8 +82,13 @@ struct LatPoint {
     bits.set(b);
   }
   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
-  /// Conjunction of tensor loop indices as bitvector.
+  /// Conjunction of tensor loop indices as bitvector. This represents
+  /// all indices involved in the tensor expression.
   llvm::BitVector bits;
+  /// Simplified conjunction of tensor loop indices as bitvector. This
+  /// represents a simplified condition under which this tensor expression
+  /// must execute. Pre-computed during codegen to avoid repeated eval.
+  llvm::BitVector simple;
   /// Index of the tensor expresssion.
   unsigned exp;
 };
@@ -93,8 +99,14 @@ struct LatPoint {
 /// independently from the basic algorithm if bottlenecks are identified.
 class Merger {
 public:
+  /// Constructs a merger for the given number of tensors and loops. The
+  /// user supplies the number of tensors involved in the kernel, with the
+  /// last tensor in this set denoting the output tensor. The merger adds an
+  /// additional synthetic tensor at the end of this set to represent all
+  /// invariant expressions in the kernel.
   Merger(unsigned t, unsigned l)
-      : numTensors(t), numLoops(l), isSparse(t, std::vector<bool>(l, false)) {}
+      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
+        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
@@ -132,8 +144,8 @@ class Merger {
     return p;
   }
 
-  /// Conjunctive merge of L1 and L2 is conjunction of cartesian product.
-  /// Returns the index of the new set.
+  /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
+  /// cartesian product. Returns the index of the new set.
   unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
     unsigned s = addSet();
     for (unsigned p0 : latSets[s0])
@@ -142,7 +154,7 @@ class Merger {
     return s;
   }
 
-  /// Disjunctive merge of L0 and L1 is (L0 /\_op L1, L0, L1).
+  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
   /// Returns the index of the new set.
   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
     unsigned s = takeConj(kind, s0, s1);
@@ -156,26 +168,27 @@ class Merger {
   /// Optimizes the iteration lattice points in the given set. This
   /// method should be called right before code generation to avoid
   /// generating redundant loops and conditions.
-  unsigned optimize(unsigned s0) {
+  unsigned optimizeSet(unsigned s0) {
     unsigned s = addSet();
     assert(latSets[s0].size() != 0);
     unsigned p0 = latSets[s0][0];
     for (unsigned p1 : latSets[s0]) {
       bool add = true;
+      llvm::BitVector simple = simplifyCond(s0, p1);
       if (p0 != p1) {
         // Is this a straightforward copy?
         unsigned e = latPoints[p1].exp;
-        if (exp(e).kind == Kind::kTensor && exp(e).e0 == numTensors - 1)
+        if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
           continue;
-        // Is any dense index exhausted?
+        // Only dense exhausted?
         llvm::BitVector tmp = latPoints[p1].bits;
         tmp ^= latPoints[p0].bits;
-        if (hasAnyOf(tmp, false))
+        if (!hasAnyDimOf(tmp, Dim::kSparse))
           continue;
-        // Is this a direct duplication of an earlier conjunction?
+        // Duplication of an earlier conjunction?
         for (unsigned p2 : latSets[s]) {
-          tmp = latPoints[p1].bits;
-          tmp ^= latPoints[p2].bits;
+          tmp = simple;
+          tmp ^= latPoints[p2].simple;
           if (tmp.count() == 0) {
             add = false;
             break;
@@ -183,13 +196,49 @@ class Merger {
         }
         assert(!add || latGT(p0, p1));
       }
-      if (add)
+      if (add) {
         latSets[s].push_back(p1);
+        latPoints[latSets[s].back()].simple = simple;
+      }
     }
     return s;
   }
 
-  // Returns true if Li > Lj.
+  /// Simplifies the conditions in a conjunction of a given lattice point
+  /// within the given set using just two basic rules:
+  /// (1) multiple dense conditions are reduced to single dense, and
+  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
+  llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
+    // First determine if this lattice point is a *singleton*, i.e.,
+    // the last point in a lattice, no other is less than this one.
+    bool isSingleton = true;
+    for (unsigned p1 : latSets[s]) {
+      if (p0 != p1 && latGT(p0, p1)) {
+        unsigned e = latPoints[p1].exp;
+        if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
+          continue;
+        llvm::BitVector tmp = latPoints[p1].bits;
+        tmp ^= latPoints[p0].bits;
+        if (hasAnyDimOf(tmp, Dim::kSparse)) {
+          isSingleton = false;
+          break;
+        }
+      }
+    }
+    // Now apply the two basic rules.
+    llvm::BitVector simple = latPoints[p0].bits;
+    bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
+    for (unsigned b = 0, be = simple.size(); b < be; b++) {
+      if (simple[b] && !isDim(b, Dim::kSparse)) {
+        if (reset)
+          simple.reset(b);
+        reset = true;
+      }
+    }
+    return simple;
+  }
+
+  /// Returns true if Li > Lj.
   bool latGT(unsigned i, unsigned j) const {
     const llvm::BitVector &bitsi = latPoints[i].bits;
     const llvm::BitVector &bitsj = latPoints[j].bits;
@@ -203,40 +252,41 @@ class Merger {
     return false;
   }
 
-  // Bit translation.
+  /// Bit translation.
   unsigned tensor(unsigned b) const { return b % numTensors; }
   unsigned index(unsigned b) const { return b / numTensors; }
 
-  // Returns true if bit corresponds to sparse access.
-  bool isSparseBit(unsigned b) const {
-    return isSparseAccess(tensor(b), index(b));
-  }
+  /// Returns true if bit corresponds to queried dim.
+  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
 
-  // Returns true if tensor access at given index is sparse.
-  bool isSparseAccess(unsigned t, unsigned i) const {
+  /// Returns true if tensor access at given index has queried dim.
+  bool isDim(unsigned t, unsigned i, Dim d) const {
     assert(t < numTensors && i < numLoops);
-    return isSparse[t][i];
+    return dims[t][i] == d;
   }
 
-  // Returns true if any set bit corresponds to sparse/dense access.
-  bool hasAnyOf(const llvm::BitVector &bits, bool sparse) const {
+  /// Returns true if any set bit corresponds to queried dim.
+  bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
     for (unsigned b = 0, be = bits.size(); b < be; b++)
-      if (bits[b] && isSparseBit(b) == sparse)
+      if (bits[b] && isDim(b, d))
         return true;
     return false;
   }
 
-  // Getters.
-  std::vector<std::vector<bool>> &sparse() { return isSparse; }
+  // Setter
+  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
+
+  /// Getters.
   TensorExp &exp(unsigned e) { return tensorExps[e]; }
   LatPoint &lat(unsigned l) { return latPoints[l]; }
   SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
 
 private:
+  const unsigned outTensor;
   const unsigned numTensors;
   const unsigned numLoops;
 
-  std::vector<std::vector<bool>> isSparse;
+  std::vector<std::vector<Dim>> dims;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
@@ -251,34 +301,39 @@ struct CodeGen {
         indices(numTensors, std::vector<Value>(numLoops)),
         highs(numTensors, std::vector<Value>(numLoops)),
         pidxs(numTensors, std::vector<Value>(numLoops)),
-        idxs(numTensors, std::vector<Value>(numLoops)) {}
-  // Sparsification options.
+        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal() {}
+  /// Sparsification options.
   linalg::SparsificationOptions options;
-  // Universal dense indices and upper bounds (by index). The loops array
-  // is updated with the value of the universal dense index in the current
-  // loop. The sizes array is set once with the inferred dimension sizes.
+  /// Universal dense indices and upper bounds (by index). The loops array
+  /// is updated with the value of the universal dense index in the current
+  /// loop. The sizes array is set once with the inferred dimension sizes.
   std::vector<Value> loops;
   std::vector<Value> sizes;
-  // Buffers for storing dense and sparse numerical values (by tensor).
-  // This array is set once during bufferization of all tensors.
+  /// Buffers for storing dense and sparse numerical values (by tensor).
+  /// This array is set once during bufferization of all tensors.
   std::vector<Value> buffers;
-  // Sparse storage schemes (1-D): pointers and indices (by tensor and index).
-  // This array is set once during bufferization of all sparse tensors.
+  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
+  /// This array is set once during bufferization of all sparse tensors.
   std::vector<std::vector<Value>> pointers;
   std::vector<std::vector<Value>> indices;
-  // Sparse iteration information (by tensor and index). These arrays
-  // are updated to remain current within the current loop.
+  /// Sparse iteration information (by tensor and index). These arrays
+  /// are updated to remain current within the current loop.
   std::vector<std::vector<Value>> highs;
   std::vector<std::vector<Value>> pidxs;
   std::vector<std::vector<Value>> idxs;
+  /// Current reduction, updated during code generation. When indices of a
+  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
+  // TODO: currently only done for (a chain of) innermost for-loops, where it
+  // is most effective; we could generalize to more outer and while-loops.
+  unsigned redExp;
+  Value redVal;
 };
 
 } // namespace
 
 /// Helper method to inspect sparse annotations in the linalg operation.
 /// Fills the per-dimension sparsity information for all tensors.
-static void findSparseAnnotations(linalg::GenericOp op,
-                                  std::vector<std::vector<bool>> &isSparse) {
+static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
   unsigned numTensors = op.getNumInputsAndOutputs();
   ArrayAttr sparseAttr = op.sparseAttr();
   for (unsigned t = 0; t < numTensors; t++) {
@@ -287,13 +342,15 @@ static void findSparseAnnotations(linalg::GenericOp op,
     // For each tensor, we accept a per-dimension Sparse or Dense annotation.
     // This is translated to the loop index that indexes that dimension.
     unsigned rank = op.getShapedType(t).getRank();
-    for (unsigned d = 0; d < rank; d++)
+    for (unsigned d = 0; d < rank; d++) {
+      unsigned idx = map.getDimPosition(d);
       if (isSparseDim(dimAttr[d])) {
-        unsigned idx = map.getDimPosition(d);
-        isSparse[t][idx] = true;
+        merger.setDim(t, idx, Dim::kSparse);
       } else {
         assert(isDenseDim(dimAttr[d]));
+        merger.setDim(t, idx, Dim::kDense);
       }
+    }
   }
 }
 
@@ -406,11 +463,11 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
   Kind kind = merger.exp(exp).kind;
   if (kind == Kind::kTensor || kind == Kind::kInvariant) {
     // Either the index is really used in the tensor expression, or it is
-    // set to the "non-existing dense index" in that dimension. Invariant
-    // expressions borrow the output tensor indices.
+    // set to the undefined index in that dimension. An invariant expression
+    // is set to a synthetic tensor with undefined indices only.
     unsigned s = merger.addSet();
     unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
-                                       : op.getNumInputsAndOutputs() - 1;
+                                       : op.getNumInputsAndOutputs();
     merger.set(s).push_back(merger.addLat(t, idx, exp));
     return s;
   }
@@ -468,7 +525,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
     for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
       unsigned i = map.getDimPosition(d);
       // Handle sparse storage schemes.
-      if (merger.isSparseAccess(t, i)) {
+      if (merger.isDim(t, i, Dim::kSparse)) {
         allDense = false;
         auto dynShape = {ShapedType::kDynamicSize};
         auto ptrTp = MemRefType::get(
@@ -514,10 +571,8 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                            unsigned exp) {
   // Test if the load was hoisted to a higher loop nest.
   Value val = merger.exp(exp).val;
-  if (val) {
-    merger.exp(exp).val = Value(); // reset
+  if (val)
     return val;
-  }
   // Actual load.
   SmallVector<Value, 4> args;
   unsigned tensor = merger.exp(exp).e0;
@@ -526,7 +581,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
     unsigned idx = map.getDimPosition(i);
     args.push_back(codegen.loops[idx]); // universal dense index
-    if (sparse || merger.isSparseAccess(tensor, idx)) {
+    if (sparse || merger.isDim(tensor, idx, Dim::kSparse)) {
       sparse = true;
       args.clear();
       args.push_back(codegen.pidxs[tensor][idx]); // position index
@@ -541,6 +596,13 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
 static void genTensorStore(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            unsigned tensor, Value rhs) {
+  // Test if this is a scalarized reduction.
+  unsigned lhs = op.getNumInputsAndOutputs() - 1;
+  if (lhs == tensor && codegen.redVal) {
+    codegen.redVal = rhs;
+    return;
+  }
+  // Actual store.
   SmallVector<Value, 4> args;
   auto map = op.getIndexingMap(tensor);
   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
@@ -594,27 +656,35 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
-                          unsigned exp) {
+                          unsigned exp, unsigned ldx, bool hoist) {
   if (merger.exp(exp).kind == Kind::kTensor) {
-    unsigned lhs = op.getNumInputsAndOutputs() - 1;
+    // Inspect tensor indices.
+    bool atLevel = ldx == -1u;
     unsigned tensor = merger.exp(exp).e0;
-    if (tensor == lhs)
-      return; // TODO: scalarize reduction as well (using scf.yield)
     auto map = op.getIndexingMap(tensor);
     for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
       unsigned idx = map.getDimPosition(i);
       if (!codegen.loops[idx])
         return; // still in play
+      else if (idx == ldx)
+        atLevel = true;
+    }
+    // All exhausted at this level (atLevel denotes exactly at this level).
+    unsigned lhs = op.getNumInputsAndOutputs() - 1;
+    if (lhs == tensor) {
+      codegen.redExp = hoist ? exp : -1u;
+    } else if (atLevel) {
+      merger.exp(exp).val =
+          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
     }
-    // All exhausted at this level.
-    merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp);
-
   } else if (merger.exp(exp).kind != Kind::kInvariant) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
-    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0);
-    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1);
+    unsigned e0 = merger.exp(exp).e0;
+    unsigned e1 = merger.exp(exp).e1;
+    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
+    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
   }
 }
 
@@ -633,7 +703,7 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     if (inits[b]) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
-      if (merger.isSparseBit(b)) {
+      if (merger.isDim(b, Dim::kSparse)) {
         // Initialize sparse index.
         unsigned pat = at;
         for (; pat != 0; pat--) {
@@ -672,7 +742,7 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
   // is marked "parallel" is a candidate. Whether it is actually converted to
   // a parallel operation depends on the requested strategy.
   auto iteratorTypes = op.iterator_types().getValue();
-  bool isSparse = merger.isSparseBit(fb);
+  bool isSparse = merger.isDim(fb, Dim::kSparse);
   bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]);
   switch (codegen.options.parallelizationStrategy) {
   case linalg::SparseParallelizationStrategy::kNone:
@@ -716,8 +786,22 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
     return parOp;
   }
 
-  // Emit a sequential loop.
-  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step);
+  // Emit a sequential loop, potentially with a scalarized reduction.
+  bool scalarRed = isInner && codegen.redExp != -1u;
+  SmallVector<Value, 4> operands;
+  if (scalarRed) {
+    Value load =
+        codegen.redVal
+            ? codegen.redVal // chained with previous for-loop
+            : genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
+    operands.push_back(load);
+  }
+  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
+  if (scalarRed) {
+    codegen.redVal = merger.exp(codegen.redExp).val =
+        forOp.getRegionIterArgs().front();
+  }
+  // Assign induction variable to sparse or dense index.
   if (isSparse)
     codegen.pidxs[tensor][idx] = forOp.getInductionVar();
   else
@@ -736,7 +820,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
   // Construct the while-loop with a parameter for each index.
   Type indexType = rewriter.getIndexType();
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
-    if (indices[b] && merger.isSparseBit(b)) {
+    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       types.push_back(indexType);
@@ -758,7 +842,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
   Value cond;
   unsigned o = 0;
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
-    if (indices[b] && merger.isSparseBit(b)) {
+    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = before->getArgument(o);
@@ -804,7 +888,7 @@ static void genLocals(Merger &merger, CodeGen &codegen,
   // Initialize sparse indices.
   Value min;
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
-    if (locals[b] && merger.isSparseBit(b)) {
+    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
@@ -831,11 +915,9 @@ static void genLocals(Merger &merger, CodeGen &codegen,
 
   // Initialize dense positions.
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
-    if (locals[b] && !merger.isSparseBit(b)) {
+    if (locals[b] && merger.isDim(b, Dim::kDense)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
-      if (!codegen.highs[tensor][idx])
-        continue; // unused dimension
       unsigned pat = at;
       for (; pat != 0; pat--)
         if (codegen.pidxs[tensor][topSort[pat - 1]])
@@ -858,8 +940,8 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
   unsigned o = 0;
   SmallVector<Value, 4> operands;
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  for (unsigned b = 0, be = induction.size(); b < be; b++)
-    if (induction[b] && merger.isSparseBit(b)) {
+  for (unsigned b = 0, be = induction.size(); b < be; b++) {
+    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = codegen.idxs[tensor][idx];
@@ -870,6 +952,7 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
       operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
       codegen.pidxs[tensor][idx] = results[o++];
     }
+  }
   if (needsUniv) {
     operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
     codegen.loops[idx] = results[o++];
@@ -879,19 +962,17 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
 }
 
 /// Generates a single if-statement within a while-loop.
-static void genIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
-                  linalg::GenericOp op, unsigned idx,
-                  llvm::BitVector &conditions, scf::IfOp &ifOp) {
+static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
+                       PatternRewriter &rewriter, linalg::GenericOp op,
+                       unsigned idx, llvm::BitVector &conditions) {
   Location loc = op.getLoc();
-  if (ifOp)
-    rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
   Value cond;
   for (unsigned b = 0, be = conditions.size(); b < be; b++) {
     if (conditions[b]) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value clause;
-      if (merger.isSparseBit(b)) {
+      if (merger.isDim(b, Dim::kSparse)) {
         Value op1 = codegen.idxs[tensor][idx];
         Value op2 = codegen.loops[idx];
         clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
@@ -901,25 +982,9 @@ static void genIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
       cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
     }
   }
-  ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
-}
-
-/// Optimize the loop indices of Li with two rules rules:
-/// (1) convert multiple dense to single dense, and
-/// (2) convert singleton sparse/dense to sparse/random access.
-static void optimizeIndices(Merger merger, unsigned lsize,
-                            llvm::BitVector &indices) {
-  if (merger.hasAnyOf(indices, false)) {
-    bool reset = lsize == 1 && merger.hasAnyOf(indices, true);
-    for (unsigned b = 0, be = indices.size(); b < be; b++) {
-      if (indices[b] && !merger.isSparseBit(b)) {
-        if (reset)
-          indices.reset(b);
-        reset = true;
-      }
-    }
-  }
+  return ifOp;
 }
 
 /// Recursively generates code while computing iteration lattices in order
@@ -940,43 +1005,51 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
   // Then emit initialization code for the loop sequence at this level.
   // We maintain the universal dense index if dense indices are still
   // in play for a non-singleton loop sequence.
+  // Location loc = op.getLoc();
   unsigned idx = topSort[at];
-  unsigned lts = merger.optimize(buildLattices(merger, op, exp, idx));
+  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
   unsigned lsize = merger.set(lts).size();
   assert(lsize != 0);
   unsigned l0 = merger.set(lts)[0];
-  LatPoint lat0 = merger.lat(l0);
-  genInvariants(merger, codegen, rewriter, op, exp);
-  bool needsUniv =
-      genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) &&
-      lsize > 1;
+  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
+  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
+  bool needsUniv = genInit(merger, codegen, rewriter, op, topSort, at,
+                           merger.lat(l0).bits) &&
+                   lsize > 1;
 
   // Emit a loop for every lattice point L0 >= Li.
-  for (unsigned li : merger.set(lts)) {
-    LatPoint lati = merger.lat(li);
+  for (unsigned i = 0; i < lsize; i++) {
+    unsigned li = merger.set(lts)[i];
 
     // Emit loop.
-    llvm::BitVector indices = lati.bits;
-    optimizeIndices(merger, lsize, indices);
+    llvm::BitVector indices = merger.lat(li).simple;
     Operation *loop =
         genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
-    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);
+    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
+              merger.lat(li).bits);
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
     bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
-    scf::IfOp ifOp;
-    for (unsigned lj : merger.set(lts)) {
+    for (unsigned j = 0; j < lsize; j++) {
+      unsigned lj = merger.set(lts)[j];
+      unsigned ej = merger.lat(lj).exp;
       if (li == lj || merger.latGT(li, lj)) {
-        LatPoint latj = merger.lat(lj);
-        llvm::BitVector tmp = latj.bits;
-        tmp ^= lati.bits;
-        if (merger.hasAnyOf(tmp, false))
-          continue; // dense exhausted within if/else
+        if (li != lj) {
+          llvm::BitVector tmp = merger.lat(lj).bits;
+          tmp ^= merger.lat(li).bits;
+          if (!merger.hasAnyDimOf(tmp, Dim::kSparse))
+            continue; // only dense exhausted within if/else
+        }
         // Recurse into body of each branch.
-        if (isWhile)
-          genIf(merger, codegen, rewriter, op, idx, latj.bits, ifOp);
-        genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1);
+        if (isWhile) {
+          scf::IfOp ifOp =
+              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
+          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
+          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
+        } else {
+          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
+        }
       }
     }
 
@@ -985,13 +1058,26 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
       scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
       rewriter.setInsertionPointToEnd(&whileOp.after().front());
       genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
-                        lati.bits, whileOp.results());
+                        merger.lat(li).bits, whileOp.results());
     } else {
       needsUniv = false;
+      if (codegen.redVal) {
+        rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
+        codegen.redVal = loop->getResult(0);
+      }
     }
     rewriter.setInsertionPointAfter(loop);
   }
+
+  // Wrap-up loop sequence.
+  Value red = codegen.redVal;
+  if (red) {
+    codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
+    unsigned lhs = op.getNumInputsAndOutputs() - 1;
+    genTensorStore(merger, codegen, rewriter, op, lhs, red);
+  }
   codegen.loops[idx] = Value();
+  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
 }
 
 namespace {
@@ -1012,7 +1098,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     unsigned numTensors = op.getNumInputsAndOutputs();
     unsigned numLoops = op.iterator_types().getValue().size();
     Merger merger(numTensors, numLoops);
-    findSparseAnnotations(op, merger.sparse());
+    findSparseAnnotations(merger, op);
 
     // Computes a topologically sorted iteration graph to ensure
     // tensors are visited in natural index order. Fails on cycles.

diff  --git a/mlir/test/Dialect/Linalg/sparse_1d.mlir b/mlir/test/Dialect/Linalg/sparse_1d.mlir
index e20cdbd62d64..4c14b2e89279 100644
--- a/mlir/test/Dialect/Linalg/sparse_1d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_1d.mlir
@@ -636,6 +636,198 @@ func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
   return %0 : tensor<32xf32>
 }
 
+#trait_two_way_inv = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a
+    affine_map<(i) -> (i)>, // b
+    affine_map<(i) -> (i)>  // x (out)
+  ],
+  sparse = [
+    [ "S" ], // a
+    [ "S" ], // b
+    [ "D" ]  // x
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) * c + b(i) * c"
+}
+
+// CHECK-LABEL:   func @two_way_inv(
+// CHECK-SAME:                      %[[VAL_0:.*0]]: tensor<16xf32>,
+// CHECK-SAME:                      %[[VAL_1:.*1]]: tensor<16xf32>,
+// CHECK-SAME:                      %[[VAL_2:.*2]]: f32) -> tensor<16xf32> {
+// CHECK:           %[[VAL_3:.*]] = constant 999 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = alloca(%[[VAL_3]]) : memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = alloca(%[[VAL_3]]) : memref<?xf32>
+// CHECK:           %[[VAL_12:.*]] = alloca() : memref<16xf32>
+// CHECK:           %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]]:3 = scf.while (%[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_15]], %[[VAL_20:.*]] = %[[VAL_4]]) : (index, index, index) -> (index, index, index) {
+// CHECK:             %[[VAL_21:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_14]] : index
+// CHECK:             %[[VAL_22:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_16]] : index
+// CHECK:             %[[VAL_23:.*]] = and %[[VAL_21]], %[[VAL_22]] : i1
+// CHECK:             scf.condition(%[[VAL_23]]) %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index):
+// CHECK:             %[[VAL_27:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:             %[[VAL_28:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK:             %[[VAL_29:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_30:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_31:.*]] = and %[[VAL_29]], %[[VAL_30]] : i1
+// CHECK:             scf.if %[[VAL_31]] {
+// CHECK:               %[[VAL_32:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:               %[[VAL_33:.*]] = mulf %[[VAL_32]], %[[VAL_2]] : f32
+// CHECK:               %[[VAL_34:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:               %[[VAL_35:.*]] = mulf %[[VAL_34]], %[[VAL_2]] : f32
+// CHECK:               %[[VAL_36:.*]] = addf %[[VAL_33]], %[[VAL_35]] : f32
+// CHECK:               store %[[VAL_36]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32>
+// CHECK:             } else {
+// CHECK:               %[[VAL_37:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:               scf.if %[[VAL_37]] {
+// CHECK:                 %[[VAL_38:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:                 %[[VAL_39:.*]] = mulf %[[VAL_38]], %[[VAL_2]] : f32
+// CHECK:                 store %[[VAL_39]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32>
+// CHECK:               } else {
+// CHECK:                 %[[VAL_40:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:                 scf.if %[[VAL_40]] {
+// CHECK:                   %[[VAL_41:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:                   %[[VAL_42:.*]] = mulf %[[VAL_41]], %[[VAL_2]] : f32
+// CHECK:                   store %[[VAL_42]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32>
+// CHECK:                 } else {
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_43:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_44:.*]] = addi %[[VAL_24]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_45:.*]] = select %[[VAL_43]], %[[VAL_44]], %[[VAL_24]] : index
+// CHECK:             %[[VAL_46:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_47:.*]] = addi %[[VAL_25]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_48:.*]] = select %[[VAL_46]], %[[VAL_47]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_49:.*]] = addi %[[VAL_26]], %[[VAL_5]] : index
+// CHECK:             scf.yield %[[VAL_45]], %[[VAL_48]], %[[VAL_49]] : index, index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#0 to %[[VAL_14]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_52:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_50]]] : memref<?xf32>
+// CHECK:             %[[VAL_53:.*]] = mulf %[[VAL_52]], %[[VAL_2]] : f32
+// CHECK:             store %[[VAL_53]], %[[VAL_12]]{{\[}}%[[VAL_51]]#2] : memref<16xf32>
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_54:.*]] = %[[VAL_55:.*]]#1 to %[[VAL_16]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_56:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_54]]] : memref<?xindex>
+// CHECK:             %[[VAL_57:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_54]]] : memref<?xf32>
+// CHECK:             %[[VAL_58:.*]] = mulf %[[VAL_57]], %[[VAL_2]] : f32
+// CHECK:             store %[[VAL_58]], %[[VAL_12]]{{\[}}%[[VAL_56]]] : memref<16xf32>
+// CHECK:           }
+// CHECK:           %[[VAL_59:.*]] = tensor_load %[[VAL_12]] : memref<16xf32>
+// CHECK:           return %[[VAL_59]] : tensor<16xf32>
+// CHECK:         }
+func @two_way_inv(%arga: tensor<16xf32>,
+                  %argb: tensor<16xf32>, %argc: f32) -> tensor<16xf32> {
+  %0 = linalg.generic #trait_two_way_inv
+    ins(%arga, %argb : tensor<16xf32>, tensor<16xf32>) {
+      ^bb(%a : f32, %b : f32):
+        %0 = mulf %a, %argc : f32
+        %1 = mulf %b, %argc : f32
+        %2 = addf %0, %1 : f32
+        linalg.yield %2: f32
+  } -> tensor<16xf32>
+  return %0 : tensor<16xf32>
+}
+
+// CHECK-LABEL:   func @two_way_inv_alt(
+// CHECK-SAME:                          %[[VAL_0:.*0]]: tensor<16xf32>,
+// CHECK-SAME:                          %[[VAL_1:.*1]]: tensor<16xf32>,
+// CHECK-SAME:                          %[[VAL_2:.*2]]: f32) -> tensor<16xf32> {
+// CHECK:           %[[VAL_3:.*]] = constant 999 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = alloca(%[[VAL_3]]) : memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = alloca(%[[VAL_3]]) : memref<?xf32>
+// CHECK:           %[[VAL_12:.*]] = alloca() : memref<16xf32>
+// CHECK:           %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]]:3 = scf.while (%[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_15]], %[[VAL_20:.*]] = %[[VAL_4]]) : (index, index, index) -> (index, index, index) {
+// CHECK:             %[[VAL_21:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_14]] : index
+// CHECK:             %[[VAL_22:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_16]] : index
+// CHECK:             %[[VAL_23:.*]] = and %[[VAL_21]], %[[VAL_22]] : i1
+// CHECK:             scf.condition(%[[VAL_23]]) %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index):
+// CHECK:             %[[VAL_27:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:             %[[VAL_28:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK:             %[[VAL_29:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_30:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_31:.*]] = and %[[VAL_29]], %[[VAL_30]] : i1
+// CHECK:             scf.if %[[VAL_31]] {
+// CHECK:               %[[VAL_32:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:               %[[VAL_33:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:               %[[VAL_34:.*]] = addf %[[VAL_32]], %[[VAL_33]] : f32
+// CHECK:               %[[VAL_35:.*]] = mulf %[[VAL_34]], %[[VAL_2]] : f32
+// CHECK:               store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32>
+// CHECK:             } else {
+// CHECK:               %[[VAL_36:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:               scf.if %[[VAL_36]] {
+// CHECK:                 %[[VAL_37:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:                 %[[VAL_38:.*]] = mulf %[[VAL_37]], %[[VAL_2]] : f32
+// CHECK:                 store %[[VAL_38]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32>
+// CHECK:               } else {
+// CHECK:                 %[[VAL_39:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:                 scf.if %[[VAL_39]] {
+// CHECK:                   %[[VAL_40:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:                   %[[VAL_41:.*]] = mulf %[[VAL_40]], %[[VAL_2]] : f32
+// CHECK:                   store %[[VAL_41]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32>
+// CHECK:                 } else {
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_42:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_43:.*]] = addi %[[VAL_24]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_44:.*]] = select %[[VAL_42]], %[[VAL_43]], %[[VAL_24]] : index
+// CHECK:             %[[VAL_45:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_46:.*]] = addi %[[VAL_25]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_47:.*]] = select %[[VAL_45]], %[[VAL_46]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_48:.*]] = addi %[[VAL_26]], %[[VAL_5]] : index
+// CHECK:             scf.yield %[[VAL_44]], %[[VAL_47]], %[[VAL_48]] : index, index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#0 to %[[VAL_14]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_51:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref<?xf32>
+// CHECK:             %[[VAL_52:.*]] = mulf %[[VAL_51]], %[[VAL_2]] : f32
+// CHECK:             store %[[VAL_52]], %[[VAL_12]]{{\[}}%[[VAL_50]]#2] : memref<16xf32>
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_53:.*]] = %[[VAL_54:.*]]#1 to %[[VAL_16]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_55:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_53]]] : memref<?xindex>
+// CHECK:             %[[VAL_56:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_53]]] : memref<?xf32>
+// CHECK:             %[[VAL_57:.*]] = mulf %[[VAL_56]], %[[VAL_2]] : f32
+// CHECK:             store %[[VAL_57]], %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref<16xf32>
+// CHECK:           }
+// CHECK:           %[[VAL_58:.*]] = tensor_load %[[VAL_12]] : memref<16xf32>
+// CHECK:           return %[[VAL_58]] : tensor<16xf32>
+// CHECK:         }
+func @two_way_inv_alt(%arga: tensor<16xf32>,
+                      %argb: tensor<16xf32>, %argc: f32) -> tensor<16xf32> {
+  // Same kernel, but now expressed as "x(i) = (a(i) + b(i)) * c".
+  %0 = linalg.generic #trait_two_way_inv
+    ins(%arga, %argb : tensor<16xf32>, tensor<16xf32>) {
+      ^bb(%a : f32, %b : f32):
+        %0 = addf %a, %b : f32
+        %1 = mulf %0, %argc : f32
+        linalg.yield %1: f32
+  } -> tensor<16xf32>
+  return %0 : tensor<16xf32>
+}
+
 #trait_sum_reduction = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
@@ -646,7 +838,7 @@ func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
     [  ]      // x
   ],
   iterator_types = ["reduction"],
-  doc = "x = SUM_i a(i)"
+  doc = "x += SUM_i a(i)"
 }
 
 // CHECK-LABEL:   func @sum_reduction(
@@ -661,14 +853,15 @@ func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK:           %[[VAL_8:.*]] = alloca() : memref<f32>
 // CHECK:           %[[VAL_9:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
-// CHECK:             %[[VAL_12:.*]] = load %[[VAL_8]][] : memref<f32>
-// CHECK:             %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf32>
-// CHECK:             %[[VAL_14:.*]] = addf %[[VAL_12]], %[[VAL_13]] : f32
-// CHECK:             store %[[VAL_14]], %[[VAL_8]][] : memref<f32>
+// CHECK:           %[[VAL_11:.*]] = load %[[VAL_8]][] : memref<f32>
+// CHECK:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
+// CHECK:             %[[VAL_15:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xf32>
+// CHECK:             %[[VAL_16:.*]] = addf %[[VAL_14]], %[[VAL_15]] : f32
+// CHECK:             scf.yield %[[VAL_16]] : f32
 // CHECK:           }
-// CHECK:           %[[VAL_15:.*]] = tensor_load %[[VAL_8]] : memref<f32>
-// CHECK:           return %[[VAL_15]] : tensor<f32>
+// CHECK:           store %[[VAL_17:.*]], %[[VAL_8]][] : memref<f32>
+// CHECK:           %[[VAL_18:.*]] = tensor_load %[[VAL_8]] : memref<f32>
+// CHECK:           return %[[VAL_18]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction(%arga: tensor<?xf32>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction
@@ -680,3 +873,233 @@ func @sum_reduction(%arga: tensor<?xf32>, %argx: tensor<f32>) -> tensor<f32> {
   } -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+#trait_sum_reduction_ss = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a
+    affine_map<(i) -> (i)>, // b
+    affine_map<(i)-> ()>    // x (scalar out)
+  ],
+  sparse = [
+    [ "S" ],  // a
+    [ "S" ],  // b
+    [     ]   // x
+  ],
+  iterator_types = ["reduction"],
+  doc = "x += SUM_i a(i) + b(i)"
+}
+
+// CHECK-LABEL:   func @sum_reduction_ss(
+// CHECK-SAME:                           %[[VAL_0:.*0]]: tensor<16xf32>,
+// CHECK-SAME:                           %[[VAL_1:.*1]]: tensor<16xf32>,
+// CHECK-SAME:                           %[[VAL_2:.*2]]: tensor<f32>) -> tensor<f32> {
+// CHECK:           %[[VAL_3:.*]] = constant 999 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_7:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = alloca(%[[VAL_3]]) : memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = alloca(%[[VAL_3]]) : memref<?xf32>
+// CHECK:           %[[VAL_12:.*]] = alloca() : memref<f32>
+// CHECK:           %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]]:3 = scf.while (%[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_15]], %[[VAL_20:.*]] = %[[VAL_4]]) : (index, index, index) -> (index, index, index) {
+// CHECK:             %[[VAL_21:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_14]] : index
+// CHECK:             %[[VAL_22:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_16]] : index
+// CHECK:             %[[VAL_23:.*]] = and %[[VAL_21]], %[[VAL_22]] : i1
+// CHECK:             scf.condition(%[[VAL_23]]) %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index):
+// CHECK:             %[[VAL_27:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:             %[[VAL_28:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK:             %[[VAL_29:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_30:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_31:.*]] = and %[[VAL_29]], %[[VAL_30]] : i1
+// CHECK:             scf.if %[[VAL_31]] {
+// CHECK:               %[[VAL_32:.*]] = load %[[VAL_12]][] : memref<f32>
+// CHECK:               %[[VAL_33:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:               %[[VAL_34:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:               %[[VAL_35:.*]] = addf %[[VAL_33]], %[[VAL_34]] : f32
+// CHECK:               %[[VAL_36:.*]] = addf %[[VAL_32]], %[[VAL_35]] : f32
+// CHECK:               store %[[VAL_36]], %[[VAL_12]][] : memref<f32>
+// CHECK:             } else {
+// CHECK:               %[[VAL_37:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:               scf.if %[[VAL_37]] {
+// CHECK:                 %[[VAL_38:.*]] = load %[[VAL_12]][] : memref<f32>
+// CHECK:                 %[[VAL_39:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:                 %[[VAL_40:.*]] = addf %[[VAL_38]], %[[VAL_39]] : f32
+// CHECK:                 store %[[VAL_40]], %[[VAL_12]][] : memref<f32>
+// CHECK:               } else {
+// CHECK:                 %[[VAL_41:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:                 scf.if %[[VAL_41]] {
+// CHECK:                   %[[VAL_42:.*]] = load %[[VAL_12]][] : memref<f32>
+// CHECK:                   %[[VAL_43:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:                   %[[VAL_44:.*]] = addf %[[VAL_42]], %[[VAL_43]] : f32
+// CHECK:                   store %[[VAL_44]], %[[VAL_12]][] : memref<f32>
+// CHECK:                 } else {
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_45:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_46:.*]] = addi %[[VAL_24]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_47:.*]] = select %[[VAL_45]], %[[VAL_46]], %[[VAL_24]] : index
+// CHECK:             %[[VAL_48:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index
+// CHECK:             %[[VAL_49:.*]] = addi %[[VAL_25]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_50:.*]] = select %[[VAL_48]], %[[VAL_49]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_51:.*]] = addi %[[VAL_26]], %[[VAL_5]] : index
+// CHECK:             scf.yield %[[VAL_47]], %[[VAL_50]], %[[VAL_51]] : index, index, index
+// CHECK:           }
+// CHECK:           %[[VAL_52:.*]] = load %[[VAL_12]][] : memref<f32>
+// CHECK:           %[[VAL_53:.*]] = scf.for %[[VAL_54:.*]] = %[[VAL_55:.*]]#0 to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_56:.*]] = %[[VAL_52]]) -> (f32) {
+// CHECK:             %[[VAL_57:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_54]]] : memref<?xf32>
+// CHECK:             %[[VAL_58:.*]] = addf %[[VAL_56]], %[[VAL_57]] : f32
+// CHECK:             scf.yield %[[VAL_58]] : f32
+// CHECK:           }
+// CHECK:           %[[VAL_59:.*]] = scf.for %[[VAL_60:.*]] = %[[VAL_61:.*]]#1 to %[[VAL_16]] step %[[VAL_5]] iter_args(%[[VAL_62:.*]] = %[[VAL_63:.*]]) -> (f32) {
+// CHECK:             %[[VAL_64:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_60]]] : memref<?xf32>
+// CHECK:             %[[VAL_65:.*]] = addf %[[VAL_62]], %[[VAL_64]] : f32
+// CHECK:             scf.yield %[[VAL_65]] : f32
+// CHECK:           }
+// CHECK:           store %[[VAL_66:.*]], %[[VAL_12]][] : memref<f32>
+// CHECK:           %[[VAL_67:.*]] = tensor_load %[[VAL_12]] : memref<f32>
+// CHECK:           return %[[VAL_67]] : tensor<f32>
+// CHECK:         }
+func @sum_reduction_ss(%arga: tensor<16xf32>,
+                       %argb: tensor<16xf32>,
+                       %argx: tensor<f32>) -> tensor<f32> {
+  // Just for testing. This case would be better expressed
+  // as two separate reduction kernels.
+  %0 = linalg.generic #trait_sum_reduction_ss
+    ins(%arga, %argb: tensor<16xf32>, tensor<16xf32>)
+    init(%argx : tensor<f32>) {
+      ^bb(%a : f32, %b : f32, %x : f32):
+        %0 = addf %a, %b  : f32
+        %1 = addf %x, %0  : f32
+        linalg.yield %1: f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+#trait_sum_reduction_inv_ss = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a
+    affine_map<(i) -> ()>,  // b
+    affine_map<(i) -> (i)>, // c
+    affine_map<(i) -> ()>   // x (out)
+  ],
+  sparse = [
+    [ "S" ], // a
+    [     ], // b
+    [ "S" ], // c
+    [     ]  // x
+  ],
+  iterator_types = ["reduction"],
+  doc = "x += SUM_i a(i) * b + c(i)"
+}
+
+// CHECK-LABEL:   func @sum_reduction_inv(
+// CHECK-SAME:                            %[[VAL_0:.*0]]: tensor<16xf32>,
+// CHECK-SAME:                            %[[VAL_1:.*1]]: tensor<f32>,
+// CHECK-SAME:                            %[[VAL_2:.*2]]: tensor<16xf32>,
+// CHECK-SAME:                            %[[VAL_3:.*3]]: tensor<f32>) -> tensor<f32> {
+// CHECK:           %[[VAL_4:.*]] = constant 999 : index
+// CHECK:           %[[VAL_5:.*]] = constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = alloca(%[[VAL_4]]) : memref<?xindex>
+// CHECK:           %[[VAL_8:.*]] = alloca(%[[VAL_4]]) : memref<?xindex>
+// CHECK:           %[[VAL_9:.*]] = alloca(%[[VAL_4]]) : memref<?xf32>
+// CHECK:           %[[VAL_10:.*]] = alloca() : memref<f32>
+// CHECK:           %[[VAL_11:.*]] = alloca(%[[VAL_4]]) : memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = alloca(%[[VAL_4]]) : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = alloca(%[[VAL_4]]) : memref<?xf32>
+// CHECK:           %[[VAL_14:.*]] = alloca() : memref<f32>
+// CHECK:           %[[VAL_15:.*]] = load %[[VAL_10]][] : memref<f32>
+// CHECK:           %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:           %[[VAL_19:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_20:.*]]:3 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
+// CHECK:             %[[VAL_24:.*]] = cmpi "ult", %[[VAL_21]], %[[VAL_17]] : index
+// CHECK:             %[[VAL_25:.*]] = cmpi "ult", %[[VAL_22]], %[[VAL_19]] : index
+// CHECK:             %[[VAL_26:.*]] = and %[[VAL_24]], %[[VAL_25]] : i1
+// CHECK:             scf.condition(%[[VAL_26]]) %[[VAL_21]], %[[VAL_22]], %[[VAL_23]] : index, index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index, %[[VAL_29:.*]]: index):
+// CHECK:             %[[VAL_30:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK:             %[[VAL_31:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK:             %[[VAL_32:.*]] = cmpi "eq", %[[VAL_30]], %[[VAL_29]] : index
+// CHECK:             %[[VAL_33:.*]] = cmpi "eq", %[[VAL_31]], %[[VAL_29]] : index
+// CHECK:             %[[VAL_34:.*]] = and %[[VAL_32]], %[[VAL_33]] : i1
+// CHECK:             scf.if %[[VAL_34]] {
+// CHECK:               %[[VAL_35:.*]] = load %[[VAL_14]][] : memref<f32>
+// CHECK:               %[[VAL_36:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xf32>
+// CHECK:               %[[VAL_37:.*]] = mulf %[[VAL_36]], %[[VAL_15]] : f32
+// CHECK:               %[[VAL_38:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_28]]] : memref<?xf32>
+// CHECK:               %[[VAL_39:.*]] = addf %[[VAL_37]], %[[VAL_38]] : f32
+// CHECK:               %[[VAL_40:.*]] = addf %[[VAL_35]], %[[VAL_39]] : f32
+// CHECK:               store %[[VAL_40]], %[[VAL_14]][] : memref<f32>
+// CHECK:             } else {
+// CHECK:               %[[VAL_41:.*]] = cmpi "eq", %[[VAL_30]], %[[VAL_29]] : index
+// CHECK:               scf.if %[[VAL_41]] {
+// CHECK:                 %[[VAL_42:.*]] = load %[[VAL_14]][] : memref<f32>
+// CHECK:                 %[[VAL_43:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref<?xf32>
+// CHECK:                 %[[VAL_44:.*]] = mulf %[[VAL_43]], %[[VAL_15]] : f32
+// CHECK:                 %[[VAL_45:.*]] = addf %[[VAL_42]], %[[VAL_44]] : f32
+// CHECK:                 store %[[VAL_45]], %[[VAL_14]][] : memref<f32>
+// CHECK:               } else {
+// CHECK:                 %[[VAL_46:.*]] = cmpi "eq", %[[VAL_31]], %[[VAL_29]] : index
+// CHECK:                 scf.if %[[VAL_46]] {
+// CHECK:                   %[[VAL_47:.*]] = load %[[VAL_14]][] : memref<f32>
+// CHECK:                   %[[VAL_48:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_28]]] : memref<?xf32>
+// CHECK:                   %[[VAL_49:.*]] = addf %[[VAL_47]], %[[VAL_48]] : f32
+// CHECK:                   store %[[VAL_49]], %[[VAL_14]][] : memref<f32>
+// CHECK:                 } else {
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_50:.*]] = cmpi "eq", %[[VAL_30]], %[[VAL_29]] : index
+// CHECK:             %[[VAL_51:.*]] = addi %[[VAL_27]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_52:.*]] = select %[[VAL_50]], %[[VAL_51]], %[[VAL_27]] : index
+// CHECK:             %[[VAL_53:.*]] = cmpi "eq", %[[VAL_31]], %[[VAL_29]] : index
+// CHECK:             %[[VAL_54:.*]] = addi %[[VAL_28]], %[[VAL_6]] : index
+// CHECK:             %[[VAL_55:.*]] = select %[[VAL_53]], %[[VAL_54]], %[[VAL_28]] : index
+// CHECK:             %[[VAL_56:.*]] = addi %[[VAL_29]], %[[VAL_6]] : index
+// CHECK:             scf.yield %[[VAL_52]], %[[VAL_55]], %[[VAL_56]] : index, index, index
+// CHECK:           }
+// CHECK:           %[[VAL_57:.*]] = load %[[VAL_14]][] : memref<f32>
+// CHECK:           %[[VAL_58:.*]] = scf.for %[[VAL_59:.*]] = %[[VAL_60:.*]]#0 to %[[VAL_17]] step %[[VAL_6]] iter_args(%[[VAL_61:.*]] = %[[VAL_57]]) -> (f32) {
+// CHECK:             %[[VAL_62:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_59]]] : memref<?xf32>
+// CHECK:             %[[VAL_63:.*]] = mulf %[[VAL_62]], %[[VAL_15]] : f32
+// CHECK:             %[[VAL_64:.*]] = addf %[[VAL_61]], %[[VAL_63]] : f32
+// CHECK:             scf.yield %[[VAL_64]] : f32
+// CHECK:           }
+// CHECK:           %[[VAL_65:.*]] = scf.for %[[VAL_66:.*]] = %[[VAL_67:.*]]#1 to %[[VAL_19]] step %[[VAL_6]] iter_args(%[[VAL_68:.*]] = %[[VAL_69:.*]]) -> (f32) {
+// CHECK:             %[[VAL_70:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_66]]] : memref<?xf32>
+// CHECK:             %[[VAL_71:.*]] = addf %[[VAL_68]], %[[VAL_70]] : f32
+// CHECK:             scf.yield %[[VAL_71]] : f32
+// CHECK:           }
+// CHECK:           store %[[VAL_72:.*]], %[[VAL_14]][] : memref<f32>
+// CHECK:           %[[VAL_73:.*]] = tensor_load %[[VAL_14]] : memref<f32>
+// CHECK:           return %[[VAL_73]] : tensor<f32>
+// CHECK:         }
+func @sum_reduction_inv(%arga: tensor<16xf32>,
+                        %argb: tensor<f32>,
+                        %argc: tensor<16xf32>,
+                        %argx: tensor<f32>) -> tensor<f32> {
+  // Just for testing. This case would be better expressed
+  // as two separate reduction kernels.
+  %0 = linalg.generic #trait_sum_reduction_inv_ss
+    ins(%arga, %argb, %argc : tensor<16xf32>, tensor<f32>, tensor<16xf32>)
+    init(%argx : tensor<f32>) {
+      ^bb(%a : f32, %b : f32, %c : f32, %x : f32):
+        %0 = mulf %a, %b  : f32
+        %1 = addf %0, %c  : f32
+        %2 = addf %x, %1  : f32
+        linalg.yield %2: f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}

diff  --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir
index bdd2de5e437a..dea7444cadae 100644
--- a/mlir/test/Dialect/Linalg/sparse_2d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir
@@ -1012,7 +1012,7 @@ func @mul_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32
     [ "D" ]        // x
   ],
   iterator_types = ["parallel", "reduction"],
-  doc = "x(i) += A(i,j) * b(j)"
+  doc = "x(i) += SUM_j A(i,j) * b(j)"
 }
 
 // CHECK-LABEL:   func @matvec(
@@ -1032,18 +1032,19 @@ func @mul_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32
 // CHECK:             %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xindex>
 // CHECK:             %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_6]] : index
 // CHECK:             %[[VAL_15:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_6]] {
-// CHECK:               %[[VAL_17:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
-// CHECK:               %[[VAL_18:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xf32>
-// CHECK:               %[[VAL_19:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref<32xf32>
-// CHECK:               %[[VAL_20:.*]] = mulf %[[VAL_18]], %[[VAL_19]] : f32
-// CHECK:               %[[VAL_21:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
-// CHECK:               %[[VAL_22:.*]] = addf %[[VAL_20]], %[[VAL_21]] : f32
-// CHECK:               store %[[VAL_22]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
+// CHECK:             %[[VAL_16:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
+// CHECK:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_6]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f32) {
+// CHECK:               %[[VAL_20:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:               %[[VAL_21:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf32>
+// CHECK:               %[[VAL_22:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<32xf32>
+// CHECK:               %[[VAL_23:.*]] = mulf %[[VAL_21]], %[[VAL_22]] : f32
+// CHECK:               %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_19]] : f32
+// CHECK:               scf.yield %[[VAL_24]] : f32
 // CHECK:             }
+// CHECK:             store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32>
 // CHECK:           }
-// CHECK:           %[[VAL_23:.*]] = tensor_load %[[VAL_11]] : memref<16xf32>
-// CHECK:           return %[[VAL_23]] : tensor<16xf32>
+// CHECK:           %[[VAL_26:.*]] = tensor_load %[[VAL_11]] : memref<16xf32>
+// CHECK:           return %[[VAL_26]] : tensor<16xf32>
 // CHECK:         }
 func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
   %0 = linalg.generic #trait_matvec
@@ -1059,20 +1060,20 @@ func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf
 
 #trait_sum_reduction = {
   indexing_maps = [
-    affine_map<(i,j) -> (i,j)>,  // a
-    affine_map<(i,j) -> ()>      // x (scalar out)
+    affine_map<(i,j) -> (i,j)>, // A
+    affine_map<(i,j) -> ()>     // x (scalar out)
   ],
   sparse = [
-    [ "D","S" ],  // a
+    [ "D", "S" ], // A
     [ ]           // x
   ],
   iterator_types = ["reduction", "reduction"],
-  doc = "x = SUM_ij a(i,j)"
+  doc = "x += SUM_ij A(i,j)"
 }
 
 // CHECK-LABEL:   func @sum_reduction(
-// CHECK-SAME:                        %[[VAL_0:.*0]]: tensor<10x20xf32>,
-// CHECK-SAME:                        %[[VAL_1:.*1]]: tensor<f32>) -> tensor<f32> {
+// CHECK-SAME:                        %[[VAL_0:.*]]: tensor<10x20xf32>,
+// CHECK-SAME:                        %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_2:.*]] = constant 999 : index
 // CHECK:           %[[VAL_3:.*]] = constant 10 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
@@ -1085,15 +1086,16 @@ func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf
 // CHECK:             %[[VAL_11:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xindex>
 // CHECK:             %[[VAL_12:.*]] = addi %[[VAL_10]], %[[VAL_5]] : index
 // CHECK:             %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
-// CHECK:             scf.for %[[VAL_14:.*]] = %[[VAL_11]] to %[[VAL_13]] step %[[VAL_5]] {
-// CHECK:               %[[VAL_15:.*]] = load %[[VAL_9]][] : memref<f32>
-// CHECK:               %[[VAL_16:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xf32>
-// CHECK:               %[[VAL_17:.*]] = addf %[[VAL_15]], %[[VAL_16]] : f32
-// CHECK:               store %[[VAL_17]], %[[VAL_9]][] : memref<f32>
+// CHECK:             %[[VAL_14:.*]] = load %[[VAL_9]][] : memref<f32>
+// CHECK:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_11]] to %[[VAL_13]] step %[[VAL_5]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
+// CHECK:               %[[VAL_18:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xf32>
+// CHECK:               %[[VAL_19:.*]] = addf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK:               scf.yield %[[VAL_19]] : f32
 // CHECK:             }
+// CHECK:             store %[[VAL_20:.*]], %[[VAL_9]][] : memref<f32>
 // CHECK:           }
-// CHECK:           %[[VAL_18:.*]] = tensor_load %[[VAL_9]] : memref<f32>
-// CHECK:           return %[[VAL_18]] : tensor<f32>
+// CHECK:           %[[VAL_21:.*]] = tensor_load %[[VAL_9]] : memref<f32>
+// CHECK:           return %[[VAL_21]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction
@@ -1170,7 +1172,7 @@ func @scale(%arga: tensor<?x?xf64>) -> tensor<?x?xf64> {
     [ "D", "D" ]   // X
   ],
   iterator_types = ["parallel", "parallel", "reduction"],
-  doc = "X(i,j) = S(i,j) SUM_k A(i,k) B(k,j)"
+  doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)"
 }
 
 // CHECK-LABEL:   func @sampled_dense_dense(
@@ -1234,3 +1236,235 @@ func @sampled_dense_dense(%args: tensor<?x?xf32>,
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+#trait_sum_kernel_with_inv = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>,  // B
+    affine_map<(i,j) -> (i,j)>,  // C
+    affine_map<(i,j) -> (i)>,    // d
+    affine_map<(i,j) -> ()>,     // e
+    affine_map<(i,j) -> (i)>     // x (out)
+  ],
+  sparse = [
+    [ "S", "S" ], // A
+    [ "D", "S" ], // B
+    [ "D", "S" ], // C
+    [ "D"  ],     // d
+    [      ],     // e
+    [ "D"  ]      // x
+  ],
+  iterator_types = ["parallel", "reduction"],
+  doc = "x(i) += SUM_j A(i,j) * B(i,j) * d(i) * e + C(i,j)"
+}
+
+// CHECK-LABEL:   func @sum_kernel_with_inv(
+// CHECK-SAME:                              %[[VAL_0:.*0]]: tensor<?x?xf32>,
+// CHECK-SAME:                              %[[VAL_1:.*1]]: tensor<?x?xf32>,
+// CHECK-SAME:                              %[[VAL_2:.*2]]: tensor<?x?xf32>,
+// CHECK-SAME:                              %[[VAL_3:.*3]]: tensor<?xf32>,
+// CHECK-SAME:                              %[[VAL_4:.*4]]: tensor<f32>,
+// CHECK-SAME:                              %[[VAL_5:.*5]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:           %[[VAL_6:.*]] = constant 999 : index
+// CHECK:           %[[VAL_7:.*]] = constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = constant true
+// CHECK:           %[[VAL_9:.*]] = constant 1 : index
+// CHECK:           %[[VAL_10:.*]] = alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = alloca(%[[VAL_6]]) : memref<?xf32>
+// CHECK:           %[[VAL_15:.*]] = alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = alloca(%[[VAL_6]]) : memref<?xf32>
+// CHECK:           %[[VAL_18:.*]] = alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:           %[[VAL_19:.*]] = alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = alloca(%[[VAL_6]]) : memref<?xf32>
+// CHECK:           %[[VAL_21:.*]] = dim %[[VAL_3]], %[[VAL_7]] : tensor<?xf32>
+// CHECK:           %[[VAL_22:.*]] = alloca(%[[VAL_21]]) : memref<?xf32>
+// CHECK:           %[[VAL_23:.*]] = alloca() : memref<f32>
+// CHECK:           %[[VAL_24:.*]] = dim %[[VAL_5]], %[[VAL_7]] : tensor<?xf32>
+// CHECK:           %[[VAL_25:.*]] = alloca(%[[VAL_24]]) : memref<?xf32>
+// CHECK:           %[[VAL_26:.*]] = load %[[VAL_23]][] : memref<f32>
+// CHECK:           %[[VAL_27:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           %[[VAL_28:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+// CHECK:           %[[VAL_29:.*]]:2 = scf.while (%[[VAL_30:.*]] = %[[VAL_27]], %[[VAL_31:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_32:.*]] = cmpi "ult", %[[VAL_30]], %[[VAL_28]] : index
+// CHECK:             scf.condition(%[[VAL_32]]) %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index):
+// CHECK:             %[[VAL_35:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref<?xindex>
+// CHECK:             %[[VAL_36:.*]] = cmpi "eq", %[[VAL_35]], %[[VAL_34]] : index
+// CHECK:             scf.if %[[VAL_36]] {
+// CHECK:               %[[VAL_37:.*]] = load %[[VAL_22]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:               %[[VAL_38:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_33]]] : memref<?xindex>
+// CHECK:               %[[VAL_39:.*]] = addi %[[VAL_33]], %[[VAL_9]] : index
+// CHECK:               %[[VAL_40:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_39]]] : memref<?xindex>
+// CHECK:               %[[VAL_41:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK:               %[[VAL_42:.*]] = addi %[[VAL_34]], %[[VAL_9]] : index
+// CHECK:               %[[VAL_43:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_42]]] : memref<?xindex>
+// CHECK:               %[[VAL_44:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK:               %[[VAL_45:.*]] = addi %[[VAL_34]], %[[VAL_9]] : index
+// CHECK:               %[[VAL_46:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_45]]] : memref<?xindex>
+// CHECK:               %[[VAL_47:.*]]:4 = scf.while (%[[VAL_48:.*]] = %[[VAL_38]], %[[VAL_49:.*]] = %[[VAL_41]], %[[VAL_50:.*]] = %[[VAL_44]], %[[VAL_51:.*]] = %[[VAL_7]]) : (index, index, index, index) -> (index, index, index, index) {
+// CHECK:                 %[[VAL_52:.*]] = cmpi "ult", %[[VAL_48]], %[[VAL_40]] : index
+// CHECK:                 %[[VAL_53:.*]] = cmpi "ult", %[[VAL_49]], %[[VAL_43]] : index
+// CHECK:                 %[[VAL_54:.*]] = and %[[VAL_52]], %[[VAL_53]] : i1
+// CHECK:                 %[[VAL_55:.*]] = cmpi "ult", %[[VAL_50]], %[[VAL_46]] : index
+// CHECK:                 %[[VAL_56:.*]] = and %[[VAL_54]], %[[VAL_55]] : i1
+// CHECK:                 scf.condition(%[[VAL_56]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]] : index, index, index, index
+// CHECK:               } do {
+// CHECK:               ^bb0(%[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index, %[[VAL_59:.*]]: index, %[[VAL_60:.*]]: index):
+// CHECK:                 %[[VAL_61:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_57]]] : memref<?xindex>
+// CHECK:                 %[[VAL_62:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK:                 %[[VAL_63:.*]] = load %[[VAL_19]]{{\[}}%[[VAL_59]]] : memref<?xindex>
+// CHECK:                 %[[VAL_64:.*]] = cmpi "eq", %[[VAL_61]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_65:.*]] = cmpi "eq", %[[VAL_62]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_66:.*]] = and %[[VAL_64]], %[[VAL_65]] : i1
+// CHECK:                 %[[VAL_67:.*]] = cmpi "eq", %[[VAL_63]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_68:.*]] = and %[[VAL_66]], %[[VAL_67]] : i1
+// CHECK:                 scf.if %[[VAL_68]] {
+// CHECK:                   %[[VAL_69:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                   %[[VAL_70:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_57]]] : memref<?xf32>
+// CHECK:                   %[[VAL_71:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_58]]] : memref<?xf32>
+// CHECK:                   %[[VAL_72:.*]] = mulf %[[VAL_70]], %[[VAL_71]] : f32
+// CHECK:                   %[[VAL_73:.*]] = mulf %[[VAL_72]], %[[VAL_37]] : f32
+// CHECK:                   %[[VAL_74:.*]] = mulf %[[VAL_73]], %[[VAL_26]] : f32
+// CHECK:                   %[[VAL_75:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_59]]] : memref<?xf32>
+// CHECK:                   %[[VAL_76:.*]] = addf %[[VAL_74]], %[[VAL_75]] : f32
+// CHECK:                   %[[VAL_77:.*]] = addf %[[VAL_69]], %[[VAL_76]] : f32
+// CHECK:                   store %[[VAL_77]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                 } else {
+// CHECK:                   %[[VAL_78:.*]] = cmpi "eq", %[[VAL_61]], %[[VAL_60]] : index
+// CHECK:                   %[[VAL_79:.*]] = cmpi "eq", %[[VAL_62]], %[[VAL_60]] : index
+// CHECK:                   %[[VAL_80:.*]] = and %[[VAL_78]], %[[VAL_79]] : i1
+// CHECK:                   scf.if %[[VAL_80]] {
+// CHECK:                     %[[VAL_81:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                     %[[VAL_82:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_57]]] : memref<?xf32>
+// CHECK:                     %[[VAL_83:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_58]]] : memref<?xf32>
+// CHECK:                     %[[VAL_84:.*]] = mulf %[[VAL_82]], %[[VAL_83]] : f32
+// CHECK:                     %[[VAL_85:.*]] = mulf %[[VAL_84]], %[[VAL_37]] : f32
+// CHECK:                     %[[VAL_86:.*]] = mulf %[[VAL_85]], %[[VAL_26]] : f32
+// CHECK:                     %[[VAL_87:.*]] = addf %[[VAL_81]], %[[VAL_86]] : f32
+// CHECK:                     store %[[VAL_87]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                   } else {
+// CHECK:                     %[[VAL_88:.*]] = cmpi "eq", %[[VAL_63]], %[[VAL_60]] : index
+// CHECK:                     scf.if %[[VAL_88]] {
+// CHECK:                       %[[VAL_89:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                       %[[VAL_90:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_59]]] : memref<?xf32>
+// CHECK:                       %[[VAL_91:.*]] = addf %[[VAL_89]], %[[VAL_90]] : f32
+// CHECK:                       store %[[VAL_91]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                     } else {
+// CHECK:                     }
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:                 %[[VAL_92:.*]] = cmpi "eq", %[[VAL_61]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_93:.*]] = addi %[[VAL_57]], %[[VAL_9]] : index
+// CHECK:                 %[[VAL_94:.*]] = select %[[VAL_92]], %[[VAL_93]], %[[VAL_57]] : index
+// CHECK:                 %[[VAL_95:.*]] = cmpi "eq", %[[VAL_62]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_96:.*]] = addi %[[VAL_58]], %[[VAL_9]] : index
+// CHECK:                 %[[VAL_97:.*]] = select %[[VAL_95]], %[[VAL_96]], %[[VAL_58]] : index
+// CHECK:                 %[[VAL_98:.*]] = cmpi "eq", %[[VAL_63]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_99:.*]] = addi %[[VAL_59]], %[[VAL_9]] : index
+// CHECK:                 %[[VAL_100:.*]] = select %[[VAL_98]], %[[VAL_99]], %[[VAL_59]] : index
+// CHECK:                 %[[VAL_101:.*]] = addi %[[VAL_60]], %[[VAL_9]] : index
+// CHECK:                 scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_100]], %[[VAL_101]] : index, index, index, index
+// CHECK:               }
+// CHECK:               %[[VAL_102:.*]]:3 = scf.while (%[[VAL_103:.*]] = %[[VAL_104:.*]]#0, %[[VAL_105:.*]] = %[[VAL_104]]#1, %[[VAL_106:.*]] = %[[VAL_104]]#3) : (index, index, index) -> (index, index, index) {
+// CHECK:                 %[[VAL_107:.*]] = cmpi "ult", %[[VAL_103]], %[[VAL_40]] : index
+// CHECK:                 %[[VAL_108:.*]] = cmpi "ult", %[[VAL_105]], %[[VAL_43]] : index
+// CHECK:                 %[[VAL_109:.*]] = and %[[VAL_107]], %[[VAL_108]] : i1
+// CHECK:                 scf.condition(%[[VAL_109]]) %[[VAL_103]], %[[VAL_105]], %[[VAL_106]] : index, index, index
+// CHECK:               } do {
+// CHECK:               ^bb0(%[[VAL_110:.*]]: index, %[[VAL_111:.*]]: index, %[[VAL_112:.*]]: index):
+// CHECK:                 %[[VAL_113:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_110]]] : memref<?xindex>
+// CHECK:                 %[[VAL_114:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_111]]] : memref<?xindex>
+// CHECK:                 %[[VAL_115:.*]] = cmpi "eq", %[[VAL_113]], %[[VAL_112]] : index
+// CHECK:                 %[[VAL_116:.*]] = cmpi "eq", %[[VAL_114]], %[[VAL_112]] : index
+// CHECK:                 %[[VAL_117:.*]] = and %[[VAL_115]], %[[VAL_116]] : i1
+// CHECK:                 scf.if %[[VAL_117]] {
+// CHECK:                   %[[VAL_118:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                   %[[VAL_119:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_110]]] : memref<?xf32>
+// CHECK:                   %[[VAL_120:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_111]]] : memref<?xf32>
+// CHECK:                   %[[VAL_121:.*]] = mulf %[[VAL_119]], %[[VAL_120]] : f32
+// CHECK:                   %[[VAL_122:.*]] = mulf %[[VAL_121]], %[[VAL_37]] : f32
+// CHECK:                   %[[VAL_123:.*]] = mulf %[[VAL_122]], %[[VAL_26]] : f32
+// CHECK:                   %[[VAL_124:.*]] = addf %[[VAL_118]], %[[VAL_123]] : f32
+// CHECK:                   store %[[VAL_124]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                 } else {
+// CHECK:                 }
+// CHECK:                 %[[VAL_125:.*]] = cmpi "eq", %[[VAL_113]], %[[VAL_112]] : index
+// CHECK:                 %[[VAL_126:.*]] = addi %[[VAL_110]], %[[VAL_9]] : index
+// CHECK:                 %[[VAL_127:.*]] = select %[[VAL_125]], %[[VAL_126]], %[[VAL_110]] : index
+// CHECK:                 %[[VAL_128:.*]] = cmpi "eq", %[[VAL_114]], %[[VAL_112]] : index
+// CHECK:                 %[[VAL_129:.*]] = addi %[[VAL_111]], %[[VAL_9]] : index
+// CHECK:                 %[[VAL_130:.*]] = select %[[VAL_128]], %[[VAL_129]], %[[VAL_111]] : index
+// CHECK:                 %[[VAL_131:.*]] = addi %[[VAL_112]], %[[VAL_9]] : index
+// CHECK:                 scf.yield %[[VAL_127]], %[[VAL_130]], %[[VAL_131]] : index, index, index
+// CHECK:               }
+// CHECK:               %[[VAL_132:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:               %[[VAL_133:.*]] = scf.for %[[VAL_134:.*]] = %[[VAL_135:.*]]#2 to %[[VAL_46]] step %[[VAL_9]] iter_args(%[[VAL_136:.*]] = %[[VAL_132]]) -> (f32) {
+// CHECK:                 %[[VAL_137:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_134]]] : memref<?xf32>
+// CHECK:                 %[[VAL_138:.*]] = addf %[[VAL_136]], %[[VAL_137]] : f32
+// CHECK:                 scf.yield %[[VAL_138]] : f32
+// CHECK:               }
+// CHECK:               store %[[VAL_139:.*]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:             } else {
+// CHECK:               scf.if %[[VAL_8]] {
+// CHECK:                 %[[VAL_140:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK:                 %[[VAL_141:.*]] = addi %[[VAL_34]], %[[VAL_9]] : index
+// CHECK:                 %[[VAL_142:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_141]]] : memref<?xindex>
+// CHECK:                 %[[VAL_143:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:                 %[[VAL_144:.*]] = scf.for %[[VAL_145:.*]] = %[[VAL_140]] to %[[VAL_142]] step %[[VAL_9]] iter_args(%[[VAL_146:.*]] = %[[VAL_143]]) -> (f32) {
+// CHECK:                   %[[VAL_147:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_145]]] : memref<?xf32>
+// CHECK:                   %[[VAL_148:.*]] = addf %[[VAL_146]], %[[VAL_147]] : f32
+// CHECK:                   scf.yield %[[VAL_148]] : f32
+// CHECK:                 }
+// CHECK:                 store %[[VAL_149:.*]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref<?xf32>
+// CHECK:               } else {
+// CHECK:               }
+// CHECK:             }
+// CHECK:             %[[VAL_150:.*]] = cmpi "eq", %[[VAL_35]], %[[VAL_34]] : index
+// CHECK:             %[[VAL_151:.*]] = addi %[[VAL_33]], %[[VAL_9]] : index
+// CHECK:             %[[VAL_152:.*]] = select %[[VAL_150]], %[[VAL_151]], %[[VAL_33]] : index
+// CHECK:             %[[VAL_153:.*]] = addi %[[VAL_34]], %[[VAL_9]] : index
+// CHECK:             scf.yield %[[VAL_152]], %[[VAL_153]] : index, index
+// CHECK:           }
+// CHECK:           scf.for %[[VAL_154:.*]] = %[[VAL_155:.*]]#1 to %[[VAL_24]] step %[[VAL_9]] {
+// CHECK:             %[[VAL_156:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_154]]] : memref<?xindex>
+// CHECK:             %[[VAL_157:.*]] = addi %[[VAL_154]], %[[VAL_9]] : index
+// CHECK:             %[[VAL_158:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_157]]] : memref<?xindex>
+// CHECK:             %[[VAL_159:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_154]]] : memref<?xf32>
+// CHECK:             %[[VAL_160:.*]] = scf.for %[[VAL_161:.*]] = %[[VAL_156]] to %[[VAL_158]] step %[[VAL_9]] iter_args(%[[VAL_162:.*]] = %[[VAL_159]]) -> (f32) {
+// CHECK:               %[[VAL_163:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_161]]] : memref<?xf32>
+// CHECK:               %[[VAL_164:.*]] = addf %[[VAL_162]], %[[VAL_163]] : f32
+// CHECK:               scf.yield %[[VAL_164]] : f32
+// CHECK:             }
+// CHECK:             store %[[VAL_165:.*]], %[[VAL_25]]{{\[}}%[[VAL_154]]] : memref<?xf32>
+// CHECK:           }
+// CHECK:           %[[VAL_166:.*]] = tensor_load %[[VAL_25]] : memref<?xf32>
+// CHECK:           return %[[VAL_166]] : tensor<?xf32>
+// CHECK:         }
+func @sum_kernel_with_inv(%arga: tensor<?x?xf32>,
+                          %argb: tensor<?x?xf32>,
+                          %argc: tensor<?x?xf32>,
+                          %argd: tensor<?xf32>,
+                          %arge: tensor<f32>,
+                          %argx: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic #trait_sum_kernel_with_inv
+    ins(%arga, %argb, %argc, %argd, %arge : tensor<?x?xf32>,
+                                            tensor<?x?xf32>,
+                                            tensor<?x?xf32>,
+                                            tensor<?xf32>,
+                                            tensor<f32>)
+    init(%argx : tensor<?xf32>) {
+      ^bb(%a : f32, %b : f32, %c : f32, %d : f32, %e : f32, %x : f32):
+        %0 = mulf %a, %b  : f32
+        %1 = mulf %0, %d  : f32
+        %2 = mulf %1, %e  : f32
+        %3 = addf %2, %c  : f32
+        %4 = addf %x, %3  : f32
+        linalg.yield %4: f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/sparse_3d.mlir b/mlir/test/Dialect/Linalg/sparse_3d.mlir
index a5cb834fd431..41818bb982b6 100644
--- a/mlir/test/Dialect/Linalg/sparse_3d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_3d.mlir
@@ -1160,7 +1160,7 @@ func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<
     [ "D", "D" ]        // A
   ],
   iterator_types = ["parallel", "parallel", "reduction", "reduction"],
-  doc = "A(i,j) = SUM_k,l B(i,k,l) * C(k,j) * D(l,j)"
+  doc = "A(i,j) += SUM_k,l B(i,k,l) * C(k,j) * D(l,j)"
 }
 
 // CHECK-LABEL:   func @kernel_3d(
@@ -1223,17 +1223,18 @@ func @kernel_3d(%arga: tensor<?x?xf32>,
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
 #trait_sum_reduction = {
   indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // a
+    affine_map<(i,j,k) -> (i,j,k)>,  // A
     affine_map<(i,j,k) -> ()>        // x (scalar out)
   ],
   sparse = [
-    [ "S", "S", "S" ],  // a
+    [ "S", "S", "S" ],  // A
     [ ]                 // x
   ],
   iterator_types = ["reduction", "reduction", "reduction"],
-  doc = "x = SUM_ijk a(i,j,k)"
+  doc = "x += SUM_ijk A(i,j,k)"
 }
 
 // CHECK-LABEL:   func @sum_reduction(
@@ -1260,16 +1261,17 @@ func @kernel_3d(%arga: tensor<?x?xf32>,
 // CHECK:               %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
 // CHECK:               %[[VAL_21:.*]] = addi %[[VAL_19]], %[[VAL_4]] : index
 // CHECK:               %[[VAL_22:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK:               scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_4]] {
-// CHECK:                 %[[VAL_24:.*]] = load %[[VAL_12]][] : memref<f32>
-// CHECK:                 %[[VAL_25:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref<?xf32>
-// CHECK:                 %[[VAL_26:.*]] = addf %[[VAL_24]], %[[VAL_25]] : f32
-// CHECK:                 store %[[VAL_26]], %[[VAL_12]][] : memref<f32>
+// CHECK:               %[[VAL_23:.*]] = load %[[VAL_12]][] : memref<f32>
+// CHECK:               %[[VAL_24:.*]] = scf.for %[[VAL_25:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_26:.*]] = %[[VAL_23]]) -> (f32) {
+// CHECK:                 %[[VAL_27:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK:                 %[[VAL_28:.*]] = addf %[[VAL_26]], %[[VAL_27]] : f32
+// CHECK:                 scf.yield %[[VAL_28]] : f32
 // CHECK:               }
+// CHECK:               store %[[VAL_29:.*]], %[[VAL_12]][] : memref<f32>
 // CHECK:             }
 // CHECK:           }
-// CHECK:           %[[VAL_27:.*]] = tensor_load %[[VAL_12]] : memref<f32>
-// CHECK:           return %[[VAL_27]] : tensor<f32>
+// CHECK:           %[[VAL_30:.*]] = tensor_load %[[VAL_12]] : memref<f32>
+// CHECK:           return %[[VAL_30]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction
@@ -1282,21 +1284,80 @@ func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor<f32>) -> tensor<f
   return %0 : tensor<f32>
 }
 
+#trait_sum_reduction_inv = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j,k)>,  // A
+    affine_map<(i,j,k) -> (i)>,      // b
+    affine_map<(i,j,k) -> ()>        // x (scalar out)
+  ],
+  sparse = [
+    [ "D", "D", "D" ], // A
+    [ "D" ],           // b
+    [ ]                // x
+  ],
+  iterator_types = ["reduction", "reduction", "reduction"],
+  doc = "x += SUM_ijk A(i,j,k) * b(i)"
+}
+
+// CHECK-LABEL:   func @sum_reduction_inv(
+// CHECK-SAME:                            %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                            %[[VAL_1:.*]]: tensor<?xf32>,
+// CHECK-SAME:                            %[[VAL_2:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:           %[[VAL_3:.*]] = constant 2 : index
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_7:.*]] = dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_9:.*]] = alloca(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) : memref<?x?x?xf32>
+// CHECK:           %[[VAL_10:.*]] = dim %[[VAL_1]], %[[VAL_4]] : tensor<?xf32>
+// CHECK:           %[[VAL_11:.*]] = alloca(%[[VAL_10]]) : memref<?xf32>
+// CHECK:           %[[VAL_12:.*]] = alloca() : memref<f32>
+// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_10]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_14:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_13]]] : memref<?xf32>
+// CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_4]] to %[[VAL_7]] step %[[VAL_5]] {
+// CHECK:               %[[VAL_16:.*]] = load %[[VAL_12]][] : memref<f32>
+// CHECK:               %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_4]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f32) {
+// CHECK:                 %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_13]], %[[VAL_15]], %[[VAL_18]]] : memref<?x?x?xf32>
+// CHECK:                 %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_14]] : f32
+// CHECK:                 %[[VAL_22:.*]] = addf %[[VAL_19]], %[[VAL_21]] : f32
+// CHECK:                 scf.yield %[[VAL_22]] : f32
+// CHECK:               }
+// CHECK:               store %[[VAL_23:.*]], %[[VAL_12]][] : memref<f32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_24:.*]] = tensor_load %[[VAL_12]] : memref<f32>
+// CHECK:           return %[[VAL_24]] : tensor<f32>
+// CHECK:         }
+func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
+                        %argb: tensor<?xf32>,
+                        %argx: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.generic #trait_sum_reduction_inv
+    ins(%arga, %argb : tensor<?x?x?xf32>, tensor<?xf32>)
+    init(%argx : tensor<f32>) {
+      ^bb(%a : f32, %b : f32, %x : f32):
+        %0 = mulf %a, %b  : f32
+        %1 = addf %x, %0  : f32
+        linalg.yield %1: f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
 #trait_invariants = {
   indexing_maps = [
     affine_map<(i,j,k) -> (i)>,      // a
     affine_map<(i,j,k) -> (j)>,      // b
     affine_map<(i,j,k) -> (k)>,      // c
-    affine_map<(i,j,k) -> (i,j,k)>   // x
+    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
   ],
   sparse = [
     [ "D" ],           // a
     [ "D" ],           // b
     [ "D" ],           // c
-    [ "D", "D", "D" ]  // x
+    [ "D", "D", "D" ]  // X
   ],
   iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "x(i,j,k) = a(i) * b(j) * c(k)"
+  doc = "X(i,j,k) = a(i) * b(j) * c(k)"
 }
 
 // CHECK-LABEL:   func @invariants(


        


More information about the Mlir-commits mailing list