[Mlir-commits] [mlir] 49be68b - [mlir][sparse] make loop emitter API more concise.

Peiming Liu llvmlistbot at llvm.org
Thu Dec 22 13:17:35 PST 2022


Author: Peiming Liu
Date: 2022-12-22T21:17:29Z
New Revision: 49be68b8aa3adafef2e7556e2895ed7f7c3a1cf1

URL: https://github.com/llvm/llvm-project/commit/49be68b8aa3adafef2e7556e2895ed7f7c3a1cf1
DIFF: https://github.com/llvm/llvm-project/commit/49be68b8aa3adafef2e7556e2895ed7f7c3a1cf1.diff

LOG: [mlir][sparse] make loop emitter API more concise.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140583

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index d26e36464913..9fd74f7d3001 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -333,7 +333,7 @@ void SparseTensorLoopEmitter::initializeLoopEmit(
       auto sparseTp = MemRefType::get(dynShape, elementType);
       valBuffer[t] = builder.create<ToValuesOp>(loc, sparseTp, tensor);
     }
-    // NOTE: we can also prepares for 0 dim here in advance, this will hosit
+    // NOTE: we can also prepare for 0 dim here in advance, this will hoist
     // some loop preparation from tensor iteration, but will also (undesirably)
     // hosit the code ouside if conditions.
   }
@@ -380,22 +380,32 @@ Value SparseTensorLoopEmitter::genAffine(OpBuilder &builder, AffineExpr a,
 }
 
 Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
-    OpBuilder &builder, Location loc, size_t tid, size_t dim,
-    MutableArrayRef<Value> reduc, bool isParallel, ArrayRef<size_t> extraTids,
-    ArrayRef<size_t> extraDims) {
-
-  assert(dimTypes[tid].size() > dim);
-  // We can not re-enter the same level.
-  assert(!coord[tid][dim]);
+    OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
+    ArrayRef<size_t> dims, MutableArrayRef<Value> reduc, bool isParallel) {
   // TODO: support multiple return on parallel for?
   assert(!isParallel || reduc.size() <= 1);
 
-  Value step = constantIndex(builder, loc, 1);
-  auto dimType = dimTypes[tid][dim];
-  bool isSparseInput = isCompressedDLT(dimType) || isSingletonDLT(dimType);
-  assert(isDenseDLT(dimType) || isCompressedDLT(dimType) ||
-         isSingletonDLT(dimType));
+  bool isSparseInput = false;
+  size_t tid = tids.front(), dim = dims.front();
+  for (auto [t, d] : llvm::zip(tids, dims)) {
+    assert(dimTypes[t].size() > d); // Must be a valid tid, dim pair
+    assert(!coord[t][d]);           // We cannot re-enter the same level
+    auto dimType = dimTypes[t][d];
+    // Must be a recognizable DLT.
+    assert(isDenseDLT(dimType) || isCompressedDLT(dimType) ||
+           isSingletonDLT(dimType));
+    bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType);
+    // We can have at most one sparse input; otherwise, a while loop is required
+    // to co-iterate multiple sparse tensors.
+    assert(!isSparseInput || !isSparse);
+    if (isSparse) {
+      tid = t;
+      dim = d;
+    }
+    isSparseInput = isSparseInput || isSparse;
+  }
 
+  Value step = constantIndex(builder, loc, 1);
   Value lo = isSparseInput ? pidxs[tid][dim]      // current offset
                            : loopSeqStack.back(); // univeral tid
   Value hi = highs[tid][dim];
@@ -439,18 +449,13 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
   } else {
     // Dense tensor, the coordinates is the inducation variable.
     coord[tid][dim] = iv;
-    // generate pidx for dense dim (pidx = i * sz + j)
-    auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc && !isSparseOutput(tid))
-      pidxs[tid][dim] = genAddress(builder, loc, tid, dim, iv);
   }
-
-  // NOTE: we can also prepares for next dim here in advance
+  // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
                          coord[tid][dim], loopTag);
   // Emit extra locals.
-  emitExtraLocalsForTensorsAtDenseDims(builder, loc, extraTids, extraDims);
+  emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
 
   return loop;
 }
@@ -515,7 +520,7 @@ Operation *SparseTensorLoopEmitter::enterFilterLoopOverTensorAtDim(
   // Set the insert point to matched branch.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
 
-  // NOTE: we can also prepares for next dim here in advance
+  // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
                          coord[tid][dim], nullptr);
@@ -531,8 +536,7 @@ void SparseTensorLoopEmitter::genDenseAffineAddressAtCurLevel(
 
 Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
     OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
-    ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc,
-    ArrayRef<size_t> extraTids, ArrayRef<size_t> extraDims) {
+    ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc) {
   assert(tids.size() == dims.size());
   SmallVector<Type> types;
   SmallVector<Value> operands;
@@ -611,24 +615,12 @@ Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
     min = after->getArguments().back();
   }
 
-  for (auto [tid, dim] : llvm::zip(tids, dims)) {
-    // All dense dim (as well as sparse output tensor) shared the same pidx in
-    // the while loop.
-    if (isDenseDLT(dimTypes[tid][dim])) {
-      pidxs[tid][dim] = min;
-      // generate pidx for dense dim (pidx = i * sz + j)
-      auto enc = getSparseTensorEncoding(tensors[tid].getType());
-      if (enc && !isSparseOutput(tid))
-        pidxs[tid][dim] = genAddress(builder, loc, tid, dim, min);
-    }
-    // NOTE: we can also prepares for next dim here in advance
-  }
   // Sets up the loop stack.
   loopStack.emplace_back(tids, dims, whileOp, min, loopTag);
   assert(loopStack.size() == loopSeqStack.size());
 
   // Emits extra locals
-  emitExtraLocalsForTensorsAtDenseDims(builder, loc, extraTids, extraDims);
+  emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
 
   // Updates reduction variables
   assert(after->getNumArguments() == o + reduc.size() + (needsUniv ? 1 : 0));
@@ -682,18 +674,20 @@ void SparseTensorLoopEmitter::emitExtraLocalsForTensorsAtDenseDims(
   // output tensor unconditionally, since they may not appear in the lattice,
   // but may be needed for linearized codegen.
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
-    assert(isDenseDLT(dimTypes[tid][dim]));
-    auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc && !isSparseOutput(tid)) {
-      bool validPidx = dim == 0 || pidxs[tid][dim - 1];
-      if (!validPidx) {
-        // We might not find the pidx for the sparse output tensor as it is
-        // unconditionally required by the sparsification.
-        assert(isOutputTensor(tid));
-        continue;
+    if (isDenseDLT(dimTypes[tid][dim])) {
+      auto enc = getSparseTensorEncoding(tensors[tid].getType());
+      if (enc && !isSparseOutput(tid)) {
+        bool validPidx = dim == 0 || pidxs[tid][dim - 1];
+        if (!validPidx) {
+          // We might not find the pidx for the sparse output tensor as it is
+          // unconditionally required by the sparsification.
+          assert(isOutputTensor(tid));
+          continue;
+        }
+        pidxs[tid][dim] =
+            genAddress(builder, loc, tid, dim, loopStack.back().iv);
+        // NOTE: we can also prepare for next dim here in advance
       }
-      pidxs[tid][dim] = genAddress(builder, loc, tid, dim, loopStack.back().iv);
-      // NOTE: we can also prepares for next dim here in advance
     }
   }
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 7fd126f2364a..6e4ea83e1ba5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -623,11 +623,10 @@ class SparseTensorLoopEmitter {
   /// The function will also perform in-place update on the `reduc` vector to
   /// return the reduction variable used inside the generated loop.
   Operation *enterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
-                                      size_t tid, size_t dim,
+                                      ArrayRef<size_t> tids,
+                                      ArrayRef<size_t> dims,
                                       MutableArrayRef<Value> reduc = {},
-                                      bool isParallel = false,
-                                      ArrayRef<size_t> extraTids = {},
-                                      ArrayRef<size_t> extraDims = {});
+                                      bool isParallel = false);
 
   Operation *enterFilterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
                                             size_t tid, size_t dim,
@@ -641,8 +640,7 @@ class SparseTensorLoopEmitter {
   /// Emits a co-iteration loop over a set of tensors.
   Operation *enterCoIterationOverTensorsAtDims(
       OpBuilder &builder, Location loc, ArrayRef<size_t> tids,
-      ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc = {},
-      ArrayRef<size_t> extraTids = {}, ArrayRef<size_t> extraDims = {});
+      ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc = {});
 
   void exitCurrentLoop(RewriterBase &rewriter, Location loc,
                        MutableArrayRef<Value> reduc = {});

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 462dd7db8a19..7c558065b8d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1028,21 +1028,24 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
 
 /// Generates a for-loop on a single index.
 static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
-                         bool isInner, unsigned idx, size_t tid, size_t dim,
-                         ArrayRef<size_t> extraTids,
-                         ArrayRef<size_t> extraDims) {
+                         bool isInner, unsigned idx, ArrayRef<size_t> tids,
+                         ArrayRef<size_t> dims) {
   linalg::GenericOp op = env.op();
   Location loc = op.getLoc();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse =
-      isCompressedDLT(env.dlt(tid, idx)) || isSingletonDLT(env.dlt(tid, idx));
+  bool isSparse = llvm::any_of(tids, [idx, &env](size_t tid) {
+    return isCompressedDLT(env.dlt(tid, idx)) ||
+           isSingletonDLT(env.dlt(tid, idx));
+  });
+
   bool isParallel = isParallelFor(env, isOuter, isSparse);
 
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     if (env.merger().isFilterLoop(idx)) {
-      // extraTids/extraDims must be empty because filter loops only
+      size_t tid = tids.front(), dim = dims.front();
+      // tids/dims must only have one value because filter loops only
       // corresponding to the one and only sparse tensor level.
-      assert(isSparse && extraTids.empty() && extraDims.empty());
+      assert(isSparse && tids.size() == 1 && dims.size() == 1);
       OpOperand *t = &op->getOpOperand(tid);
       auto enc = getSparseTensorEncoding(t->get().getType());
       // Retrieves the affine expression for the filter loop.
@@ -1051,8 +1054,8 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
       return env.emitter()->enterFilterLoopOverTensorAtDim(builder, loc, tid,
                                                            dim, a, reduc);
     }
-    return env.emitter()->enterLoopOverTensorAtDim(
-        builder, loc, tid, dim, reduc, isParallel, extraTids, extraDims);
+    return env.emitter()->enterLoopOverTensorAtDim(builder, loc, tids, dims,
+                                                   reduc, isParallel);
   });
   assert(loop);
   return loop;
@@ -1060,16 +1063,13 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
 
 /// Emit a while-loop for co-iteration over multiple indices.
 static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx,
-                           bool needsUniv, ArrayRef<size_t> condTids,
-                           ArrayRef<size_t> condDims,
-                           ArrayRef<size_t> extraTids,
-                           ArrayRef<size_t> extraDims) {
+                           bool needsUniv, ArrayRef<size_t> tids,
+                           ArrayRef<size_t> dims) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     // Construct the while-loop with a parameter for each
     // index.
     return env.emitter()->enterCoIterationOverTensorsAtDims(
-        builder, env.op().getLoc(), condTids, condDims, needsUniv, reduc,
-        extraTids, extraDims);
+        builder, env.op().getLoc(), tids, dims, needsUniv, reduc);
   });
   assert(loop);
   return loop;
@@ -1078,20 +1078,16 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx,
 /// Generates a for-loop or a while-loop, depending on whether it implements
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
-                          bool needsUniv, ArrayRef<size_t> condTids,
-                          ArrayRef<size_t> condDims, ArrayRef<size_t> extraTids,
-                          ArrayRef<size_t> extraDims) {
-  assert(condTids.size() == condDims.size());
-  assert(extraTids.size() == extraDims.size());
+                          bool needsUniv, ArrayRef<size_t> tids,
+                          ArrayRef<size_t> dims, bool isFor) {
+  assert(tids.size() == dims.size());
   unsigned idx = env.topSortAt(at);
-  if (condTids.size() == 1) {
+  if (isFor) {
     bool isOuter = at == 0;
     bool isInner = at == env.topSortSize() - 1;
-    return genFor(env, builder, isOuter, isInner, idx, condTids.front(),
-                  condDims.front(), extraTids, extraDims);
+    return genFor(env, builder, isOuter, isInner, idx, tids, dims);
   }
-  return genWhile(env, builder, idx, needsUniv, condTids, condDims, extraTids,
-                  extraDims);
+  return genWhile(env, builder, idx, needsUniv, tids, dims);
 }
 
 /// Generates the induction structure for a while-loop.
@@ -1263,15 +1259,15 @@ static void genInitConstantDenseAddress(CodegenEnv &env,
     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }
 
-static void translateBitsToTidDimPairs(
-    CodegenEnv &env, unsigned li, unsigned idx,
-    SmallVectorImpl<size_t> &condTids, SmallVectorImpl<size_t> &condDims,
-    SmallVectorImpl<size_t> &extraTids, SmallVectorImpl<size_t> &extraDims,
-    SmallVectorImpl<size_t> &affineTids, SmallVectorImpl<size_t> &affineDims,
-    SmallVectorImpl<AffineExpr> &exps) {
+/// Returns true if the lattice bits can be iterated by a for loop.
+static bool translateBitsToTidDimPairs(
+    CodegenEnv &env, unsigned li, unsigned idx, SmallVectorImpl<size_t> &tids,
+    SmallVectorImpl<size_t> &dims, SmallVectorImpl<size_t> &affineTids,
+    SmallVectorImpl<size_t> &affineDims, SmallVectorImpl<AffineExpr> &exps) {
   const BitVector &all = env.lat(li).bits;
   const BitVector &simple = env.lat(li).simple;
 
+  unsigned numloopCond = 0;
   // Converts bits to array + dim pair
   env.merger().foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
                                                      Optional<unsigned> dim,
@@ -1290,12 +1286,12 @@ static void translateBitsToTidDimPairs(
         if (!dim)
           return;
       }
-      condTids.push_back(tid);
-      condDims.push_back(*dim);
+      tids.push_back(tid);
+      dims.push_back(*dim);
+      numloopCond++;
     } else if (isDenseDLT(dlt)) {
-      // TODO: get rid of extraTids and extraDims.
-      extraTids.push_back(tid);
-      extraDims.push_back(*dim);
+      tids.push_back(tid);
+      dims.push_back(*dim);
     } else {
       assert(isUndefDLT(dlt));
       linalg::GenericOp op = env.op();
@@ -1344,31 +1340,31 @@ static void translateBitsToTidDimPairs(
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
     auto dim = *env.merger().getDimNum(env.merger().getOutTensorID(), idx);
-    extraTids.push_back(env.merger().getOutTensorID());
-    extraDims.push_back(dim);
+    tids.push_back(env.merger().getOutTensorID());
+    dims.push_back(dim);
   }
+
+  assert(numloopCond > 0);
+  // If there is exactly one loop condition, the loop can be generated by a for
+  // loop.
+  return numloopCond == 1;
 }
 
 /// Starts a single loop in current sequence.
 static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
                             unsigned li, bool needsUniv) {
   // The set of tensors + dims to generate loops on
-  SmallVector<size_t> condTids, condDims;
-  // The set of (dense) tensors that is optimized from condition, yet still
-  // need extra locals to iterate on them.
-  SmallVector<size_t> extraTids, extraDims;
+  SmallVector<size_t> tids, dims;
   // The set of dense tensors with non-trivial affine expression that just
   // becomes invariant and the address shall now be generated at the current
   // level.
   SmallVector<size_t> affineTids, affineDims;
   SmallVector<AffineExpr> affines;
-  translateBitsToTidDimPairs(env, li, env.topSortAt(at), condTids, condDims,
-                             extraTids, extraDims, affineTids, affineDims,
-                             affines);
+  bool isFor = translateBitsToTidDimPairs(
+      env, li, env.topSortAt(at), tids, dims, affineTids, affineDims, affines);
 
   // Emit the for/while-loop control.
-  Operation *loop = genLoop(env, builder, at, needsUniv, condTids, condDims,
-                            extraTids, extraDims);
+  Operation *loop = genLoop(env, builder, at, needsUniv, tids, dims, isFor);
   for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) {
     env.emitter()->genDenseAffineAddressAtCurLevel(builder, env.op().getLoc(),
                                                    tid, dim, exp);
@@ -1377,8 +1373,8 @@ static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at,
   // Until now, we have entered every <tid, dim> pair in {cond, extra,
   // affine}Tids/Dims. The addresses of the upcoming levels which are dependent
   // on constant affines expression may now be determined.
-  auto allTids = llvm::concat<size_t>(condTids, extraTids, affineTids);
-  auto allDims = llvm::concat<size_t>(condDims, extraDims, affineDims);
+  auto allTids = llvm::concat<size_t>(tids, affineTids);
+  auto allDims = llvm::concat<size_t>(dims, affineDims);
   for (auto [tid, dim] : llvm::zip(allTids, allDims)) {
     if (tid != env.merger().getOutTensorID())
       genConstantDenseAddressFromLevel(env, builder, tid, dim + 1);


        


More information about the Mlir-commits mailing list