[Mlir-commits] [mlir] [mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computation (PR #69540)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 18 17:28:01 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

<details>
<summary>Changes</summary>

This makes sure that a block-sparse map such as the following (dimToLvl, then lvlToDim after the `--` separator)

- GEN MAP dim=2 lvl=4
  (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
--
  (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3)

is indeed encoded as

MAP-REF (dim=2, lvl=4) isperm=0
  d2l = [ d0/2 d1/2 d0%2 d1%2 ]
  ld2 = [ l2+2*l0 l3+2*l1 ]

---
Full diff: https://github.com/llvm/llvm-project/pull/69540.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (+42-20) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 98b412c8ec9eb5b..b1b1d67ac2d420d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
   // This code deals with permutations as well as non-permutations that
   // arise from rank changing blocking.
   const auto dimToLvl = stt.getDimToLvl();
+  const auto lvlToDim = stt.getLvlToDim();
   SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
   SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
   SmallVector<Value> lvlSizesValues(lvlRank);
@@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
     Dimension d = 0;
     uint64_t cf = 0, cm = 0;
     switch (exp.getKind()) {
-    case AffineExprKind::DimId:
+    case AffineExprKind::DimId: {
       d = exp.cast<AffineDimExpr>().getPosition();
       break;
-    case AffineExprKind::FloorDiv:
-      d = exp.cast<AffineBinaryOpExpr>()
-              .getLHS()
-              .cast<AffineDimExpr>()
-              .getPosition();
-      cf = exp.cast<AffineBinaryOpExpr>()
-               .getRHS()
-               .cast<AffineConstantExpr>()
-               .getValue();
+    }
+    case AffineExprKind::FloorDiv: {
+      auto floor = exp.cast<AffineBinaryOpExpr>();
+      d = floor.getLHS().cast<AffineDimExpr>().getPosition();
+      cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
       break;
-    case AffineExprKind::Mod:
-      d = exp.cast<AffineBinaryOpExpr>()
-              .getLHS()
-              .cast<AffineDimExpr>()
-              .getPosition();
-      cm = exp.cast<AffineBinaryOpExpr>()
-               .getRHS()
-               .cast<AffineConstantExpr>()
-               .getValue();
+    }
+    case AffineExprKind::Mod: {
+      auto mod = exp.cast<AffineBinaryOpExpr>();
+      d = mod.getLHS().cast<AffineDimExpr>().getPosition();
+      cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
       break;
+    }
     default:
       llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
     }
     dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
-    lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim
     // Compute the level sizes.
     //    (1) l = d        : size(d)
     //    (2) l = d / c    : size(d) / c
@@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
     }
     lvlSizesValues[l] = lvlSz;
   }
+  // Generate lvl2dim.
+  assert(dimRank == lvlToDim.getNumResults());
+  for (Dimension d = 0; d < dimRank; d++) {
+    AffineExpr exp = lvlToDim.getResult(d);
+    // We expect:
+    //    (1) d = l
+    //    (2) d = l' * c + l
+    Level l = 0, ll = 0;
+    uint64_t c = 0;
+    switch (exp.getKind()) {
+    case AffineExprKind::DimId: {
+      l = exp.cast<AffineDimExpr>().getPosition();
+      break;
+    }
+    case AffineExprKind::Add: {
+      // Always mul on lhs, symbol/constant on rhs.
+      auto add = exp.cast<AffineBinaryOpExpr>();
+      assert(add.getLHS().getKind() == AffineExprKind::Mul);
+      auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
+      ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
+      c = mul.getRHS().cast<AffineConstantExpr>().getValue();
+      l = add.getRHS().cast<AffineDimExpr>().getPosition();
+      break;
+    }
+    default:
+      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
+    }
+    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
+  }
   // Return buffers.
   dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
   lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);

``````````

</details>


https://github.com/llvm/llvm-project/pull/69540


More information about the Mlir-commits mailing list