[Mlir-commits] [mlir] [mlir][sparse] implement non-permutation MapRef encoding (PR #69406)

Yinying Li llvmlistbot at llvm.org
Wed Oct 18 09:12:23 PDT 2023


================
@@ -688,25 +688,70 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
     return dimSizesBuffer;
   }
   // Otherwise, some code needs to be generated to set up the buffers.
-  // TODO: use the lvl2dim once available and deal with non-permutations!
+  // This code deals with permutations as well as non-permutations that
+  // arise from rank changing blocking.
   const auto dimToLvl = stt.getDimToLvl();
-  assert(dimToLvl.isPermutation());
-  SmallVector<Value> dim2lvlValues(dimRank);
-  SmallVector<Value> lvl2dimValues(lvlRank);
+  SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
+  SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
   SmallVector<Value> lvlSizesValues(lvlRank);
+  // Generate dim2lvl.
+  assert(lvlRank == dimToLvl.getNumResults());
   for (Level l = 0; l < lvlRank; l++) {
-    // The `d`th source variable occurs in the `l`th result position.
-    Dimension d = dimToLvl.getDimPosition(l);
-    Value lvl = constantIndex(builder, loc, l);
-    Value dim = constantIndex(builder, loc, d);
-    dim2lvlValues[d] = lvl;
-    lvl2dimValues[l] = dim;
-    if (stt.isDynamicDim(d))
-      lvlSizesValues[l] =
-          builder.create<memref::LoadOp>(loc, dimSizesBuffer, dim);
-    else
-      lvlSizesValues[l] = dimShapesValues[d];
+    AffineExpr exp = dimToLvl.getResult(l);
+    // We expect:
+    //    (1) l = d
+    //    (2) l = d / c
+    //    (3) l = d % c
+    Dimension d = 0;
+    uint64_t cf = 0, cm = 0;
+    switch (exp.getKind()) {
+    case AffineExprKind::DimId:
+      d = exp.cast<AffineDimExpr>().getPosition();
+      break;
+    case AffineExprKind::FloorDiv:
+      d = exp.cast<AffineBinaryOpExpr>()
+              .getLHS()
+              .cast<AffineDimExpr>()
+              .getPosition();
+      cf = exp.cast<AffineBinaryOpExpr>()
+               .getRHS()
+               .cast<AffineConstantExpr>()
+               .getValue();
+      break;
+    case AffineExprKind::Mod:
+      d = exp.cast<AffineBinaryOpExpr>()
+              .getLHS()
+              .cast<AffineDimExpr>()
+              .getPosition();
+      cm = exp.cast<AffineBinaryOpExpr>()
+               .getRHS()
+               .cast<AffineConstantExpr>()
+               .getValue();
+      break;
----------------
yinying-lisa-li wrote:

This is nice! I might "borrow" this and refactor my code. ;)

https://github.com/llvm/llvm-project/pull/69406


More information about the Mlir-commits mailing list