[Mlir-commits] [mlir] c8d5dcb - [mlir][sparse] refactor loop sequence codegen

Aart Bik llvmlistbot at llvm.org
Tue Oct 26 13:42:34 PDT 2021


Author: Aart Bik
Date: 2021-10-26T13:42:21-07:00
New Revision: c8d5dcb035284e22ee07ee5b77dec2c5b18e1fa1

URL: https://github.com/llvm/llvm-project/commit/c8d5dcb035284e22ee07ee5b77dec2c5b18e1fa1
DIFF: https://github.com/llvm/llvm-project/commit/c8d5dcb035284e22ee07ee5b77dec2c5b18e1fa1.diff

LOG: [mlir][sparse] refactor loop sequence codegen

This refactoring adds a few "event" functions (start/end loop-seq/loop) to
improve the readability of the core codegen function. It also prepares for
sparse tensor output codegen, where these "event" functions will provide
convenient placeholders to start or stop insertion bookkeeping.

This revision also includes a few minor changes that had been
pending in my local workspace.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D112506

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index c894ac866c5a6..cfab38616d55f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -113,9 +113,9 @@ static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
 }
 
 /// Helper method to inspect affine expressions. Rejects cases where the
-/// same index is used in more than one dimension of a tensor. Also rejects
-/// affine expressions that are not a direct index for annotated tensors.
-/// TODO: accept more affine cases for sparse tensors
+/// same index is used more than once. Also rejects affine expressions
+/// that are not a direct index for annotated tensors.
+// TODO: accept more affine cases for sparse tensors
 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                        bool isDense) {
   switch (a.getKind()) {
@@ -263,6 +263,22 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   return true;
 }
 
+/// Returns true if tensor is a function argument marked inplaceable=true.
+static bool isInPlace(Value val) {
+  if (auto arg = val.dyn_cast<BlockArgument>())
+    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
+      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
+              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
+        return attr.getValue();
+  return false; // not an annotated function argument
+}
+
+/// Returns true if tensor materializes (defined by InitTensorOp/InitOp).
+static bool isMaterializing(Value val) {
+  return val.getDefiningOp<linalg::InitTensorOp>() ||
+         val.getDefiningOp<InitOp>();
+}
+
 /// Returns true when the tensor expression is admissable for codegen.
 /// Since all sparse input tensors are admissable, we just need to check
 /// whether the output tensor in the tensor expression codegen is admissable.
@@ -288,16 +304,17 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
     return true;
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
-  // [Bik96,Ch9], is also admissable without special codegen.
+  // [Bik96,Ch9], is also admissable without special codegen, provided
+  // the tensor's underlying sparse storage scheme can be modified in place.
   if (merger.isConjunction(tensor, exp))
-    return true;
+    return isInPlace(lhs->get());
   // Reject for now since this requires changes to the nonzero structure.
   // TODO: implement "workspaces" [Kjolstad2019]
   return false;
 }
 
 //===----------------------------------------------------------------------===//
-// Sparse compiler synthesis methods.
+// Sparse compiler synthesis methods (statements and expressions).
 //===----------------------------------------------------------------------===//
 
 /// Maps reduction kind to name encoding.
@@ -350,7 +367,7 @@ static Value genReductionInit(PatternRewriter &rewriter, Location loc,
   case kXor: {
     // Initialize reduction vector to: | 0 | .. | 0 | r |
     Attribute zero = rewriter.getZeroAttr(vtp);
-    Value vec = rewriter.create<ConstantOp>(loc, vtp, zero);
+    Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
     return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
   }
   case kProduct: {
@@ -361,8 +378,8 @@ static Value genReductionInit(PatternRewriter &rewriter, Location loc,
       one = rewriter.getFloatAttr(etp, 1.0);
     else
       one = rewriter.getIntegerAttr(etp, 1);
-    Value vec =
-        rewriter.create<ConstantOp>(loc, vtp, DenseElementsAttr::get(vtp, one));
+    Value vec = rewriter.create<arith::ConstantOp>(
+        loc, vtp, DenseElementsAttr::get(vtp, one));
     return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
   }
   case kAnd:
@@ -380,16 +397,6 @@ static Type genIntType(PatternRewriter &rewriter, unsigned width) {
   return rewriter.getIntegerType(width);
 }
 
-/// Detects in-place annotation on tensor argument.
-static bool getInPlace(Value val) {
-  if (auto arg = val.dyn_cast<BlockArgument>())
-    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
-      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
-              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
-        return attr.getValue();
-  return false;
-}
-
 /// Generates buffer for the output tensor. Note that all sparse kernels
 /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
 /// the output buffer is already initialized to all zeroes and only nonzeroes
@@ -405,18 +412,19 @@ static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
   // be generated for the tensor present in the outs() clause. This has
   // the major advantage that the sparse kernel only updates the nonzero
   // positions for the output tensor.
-  if (getInPlace(tensor))
+  if (isInPlace(tensor))
     return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
   // By default, a new buffer is allocated which is initialized to the
   // tensor defined in the outs() clause. This is always correct but
   // introduces a dense initialization component that may negatively
   // impact the running complexity of the sparse kernel. If the tensor
-  // materializes within this method, we need to preserve the zero
+  // materializes into the computation, we need to preserve the zero
   // initialization assumption of all sparse output buffers.
-  if (auto init = tensor.getDefiningOp<linalg::InitTensorOp>()) {
+  if (isMaterializing(tensor)) {
     Type tp = denseTp.getElementType();
     Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
-    Value zero = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+    Value zero =
+        rewriter.create<arith::ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
     rewriter.create<linalg::FillOp>(loc, zero, alloc);
     return alloc;
   }
@@ -429,7 +437,7 @@ static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
 /// Local bufferization of all dense and sparse data structures.
 /// This code enables testing the first prototype sparse compiler.
 // TODO: replace this with a proliferated bufferization strategy
-static bool genBuffers(Merger &merger, CodeGen &codegen,
+static void genBuffers(Merger &merger, CodeGen &codegen,
                        PatternRewriter &rewriter, linalg::GenericOp op) {
   Location loc = op.getLoc();
   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
@@ -486,15 +494,12 @@ static bool genBuffers(Merger &merger, CodeGen &codegen,
             genOutputBuffer(codegen, rewriter, op, denseTp, args);
     } else {
       // Annotated sparse tensors.
-      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
-        return false; // reject output if not in-place
       auto dynShape = {ShapedType::kDynamicSize};
       auto sparseTp = MemRefType::get(dynShape, elementType);
       codegen.buffers[tensor] =
           rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
     }
   }
-  return true;
 }
 
 /// Constructs vector type.
@@ -623,7 +628,9 @@ static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
   if (enc) {
     // Note that currently, all sparse subscripts are simple.
     // TODO: accept affine too?
-    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
+    AffineExpr a = map.getResult(perm(enc, rank - 1));
+    assert(a.getKind() == AffineExprKind::DimId);
+    unsigned idx = a.cast<AffineDimExpr>().getPosition();
     assert(codegen.pidxs[tensor][idx] != nullptr);
     args.push_back(codegen.pidxs[tensor][idx]); // position index
   } else {
@@ -841,6 +848,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
     if (lhs == t) {
       codegen.redExp = hoist ? exp : -1u;
       codegen.redKind = getReduction(last);
+      assert(!codegen.redVal);
     } else if (atLevel) {
       merger.exp(exp).val =
           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
@@ -948,6 +956,7 @@ static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
         // dimension (true non-unit stride) or if the innermost index appears
         // in a compound subscript in the innermost dimension. Even if the
         // latter is unit stride, it does not play well with scatter/gather.
+        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
         if (a.isFunctionOfDim(idx) &&
             ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
           return false;
@@ -1209,6 +1218,83 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
   return ifOp;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse compiler synthesis methods (loop sequence).
+//===----------------------------------------------------------------------===//
+
+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(Merger &merger, CodeGen &codegen,
+                         PatternRewriter &rewriter, linalg::GenericOp op,
+                         std::vector<unsigned> &topSort, unsigned exp,
+                         unsigned at, unsigned idx, unsigned ldx,
+                         unsigned lts) {
+  assert(codegen.curVecLength == 1);
+  // Emit invariants at this loop sequence level.
+  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
+  // Emit further initialization at this loop sequence level.
+  unsigned l0 = merger.set(lts)[0]; // top lattice point L0
+  if (genInit(merger, codegen, rewriter, op, topSort, at,
+              merger.lat(l0).bits)) {
+    // Maintain the universal index only if it is actually
+    // consumed by a subsequent lattice point.
+    unsigned lsize = merger.set(lts).size();
+    for (unsigned i = 1; i < lsize; i++) {
+      unsigned li = merger.set(lts)[i];
+      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
+        return true;
+    }
+  }
+  return false;
+}
+
+/// Starts a single loop in current sequence and returns the loop op.
+static Operation *startLoop(Merger &merger, CodeGen &codegen,
+                            PatternRewriter &rewriter, linalg::GenericOp op,
+                            std::vector<unsigned> &topSort, unsigned at,
+                            unsigned li, bool needsUniv) {
+  assert(codegen.curVecLength == 1);
+  // Emit the for/while-loop control.
+  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
+                            needsUniv, merger.lat(li).simple);
+  // Emit the locals for this loop.
+  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
+            merger.lat(li).bits);
+  return loop;
+}
+
+/// Ends a single loop in current sequence. Returns updated needsUniv flag.
+static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
+                    linalg::GenericOp op, Operation *loop, unsigned idx,
+                    unsigned li, bool needsUniv) {
+  codegen.curVecLength = 1;
+  // End a while-loop.
+  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
+    rewriter.setInsertionPointToEnd(&whileOp.after().front());
+    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
+                      merger.lat(li).bits, whileOp.results());
+    return needsUniv;
+  }
+  // End a for-loop (yield any pending reduction value).
+  if (codegen.redVal) {
+    rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
+    codegen.redVal = loop->getResult(0);
+  }
+  return false;
+}
+
+/// Ends a loop sequence at given level and clears its loop index.
+static void endLoopSeq(Merger &merger, CodeGen &codegen,
+                       PatternRewriter &rewriter, linalg::GenericOp op,
+                       unsigned exp, unsigned idx, unsigned ldx) {
+  assert(codegen.curVecLength == 1);
+  // Finalize any pending reduction.
+  genReductionEnd(merger, codegen, rewriter, op);
+  // Unmark bookkeeping of invariants and loop index.
+  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
+  codegen.loops[idx] = Value();
+}
+
 /// Recursively generates code while computing iteration lattices in order
 /// to manage the complexity of implementing co-iteration over unions
 /// and intersections of sparse iterations spaces.
@@ -1221,45 +1307,23 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     genTensorStore(merger, codegen, rewriter, op, rhs);
     return;
   }
-  assert(codegen.curVecLength == 1);
 
   // Construct iteration lattices for current loop index, with L0 at top.
-  // Then emit initialization code for the loop sequence at this level.
-  // We maintain the universal dense index if dense indices are still
-  // in play for a non-singleton loop sequence.
-  Location loc = op.getLoc();
   unsigned idx = topSort[at];
-  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
-  unsigned lsize = merger.set(lts).size();
-  assert(lsize != 0);
-  unsigned l0 = merger.set(lts)[0];
   unsigned ldx = at == 0 ? -1u : topSort[at - 1];
-  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
-  bool needsUniv = false;
-  if (genInit(merger, codegen, rewriter, op, topSort, at,
-              merger.lat(l0).bits)) {
-    // Maintain the universal index only if it is actually
-    // consumed by a subsequent lattice point.
-    for (unsigned i = 1; i < lsize; i++) {
-      unsigned li = merger.set(lts)[i];
-      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
-        needsUniv = true;
-        break;
-      }
-    }
-  }
+  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
+
+  // Start a loop sequence.
+  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
+                                idx, ldx, lts);
 
-  // Emit a loop for every lattice point L0 >= Li.
+  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
+  unsigned lsize = merger.set(lts).size();
   for (unsigned i = 0; i < lsize; i++) {
+    // Start a loop.
     unsigned li = merger.set(lts)[i];
-
-    // Emit loop.
-    codegen.curVecLength = 1;
-    llvm::BitVector indices = merger.lat(li).simple;
     Operation *loop =
-        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
-    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
-              merger.lat(li).bits);
+        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);
 
     // Visit all lattices points with Li >= Lj to generate the
     // loop-body, possibly with if statements for coiteration.
@@ -1280,27 +1344,14 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
       }
     }
 
-    // Wrap-up induction and restore insertion point.
-    if (isWhile) {
-      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
-      rewriter.setInsertionPointToEnd(&whileOp.after().front());
-      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
-                        merger.lat(li).bits, whileOp.results());
-    } else {
-      needsUniv = false;
-      if (codegen.redVal) {
-        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
-        codegen.redVal = loop->getResult(0);
-      }
-    }
+    // End a loop.
+    needsUniv =
+        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
     rewriter.setInsertionPointAfter(loop);
   }
 
-  // Wrap-up loop sequence.
-  codegen.curVecLength = 1;
-  genReductionEnd(merger, codegen, rewriter, op);
-  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
-  codegen.loops[idx] = Value();
+  // End a loop sequence.
+  endLoopSeq(merger, codegen, rewriter, op, exp, idx, ldx);
 }
 
 /// Converts the result computed by the sparse kernel into the required form.
@@ -1385,8 +1436,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 
     // Recursively generates code.
     CodeGen codegen(options, numTensors, numLoops);
-    if (!genBuffers(merger, codegen, rewriter, op))
-      return failure(); // could not bufferize
+    genBuffers(merger, codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
     genResult(merger, codegen, rewriter, op);
     return success();


        


More information about the Mlir-commits mailing list