[Mlir-commits] [mlir] [mlir][sparse] support sparsifying 2:4 block sparsity (PR #71749)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 8 16:26:40 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
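<summary>Changes</summary>

Teach the sparsifier to handle `block2_4` (2:4 structured sparsity) levels. The Merger, LoopEmitter, and Sparsification code paths that previously recognized only compressed, loose-compressed, and singleton levels as sparse now also accept `is2OutOf4DLT`. For a 2:4 level, the loop emitter fetches the coordinates buffer and iterates over the position range `[2 * pLo, 2 * pLo + 2)`, since each 2:4 block stores exactly two specified elements. The `sparse_block_matmul.mlir` integration test gains an `#NV_24` encoding and a `@mul_24` kernel, and its operands change from 4x6 to 4x8.
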
---
Full diff: https://github.com/llvm/llvm-project/pull/71749.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+2-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp (+13-4) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+3-3) 
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-2) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir (+44-18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 215920f8b4607b2..cde6b2d13e58217 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -540,7 +540,8 @@ class Merger {
   bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
     if (isLvlWithNonTrivialIdxExp(b)) {
       auto dlt = getLoopDependentLevelType(b);
-      return isCompressedDLT(dlt) || isSingletonDLT(dlt);
+      return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
+             isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
     }
     return false;
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index bb3c6fb56f692d9..6facc87d1b5a029 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -448,7 +448,7 @@ void LoopEmitter::initializeLoopEmit(
         positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
         coordinatesBuffers[t][l] =
             genToCoordinates(builder, loc, tensor, l, cooStart);
-      } else if (isSingletonDLT(lvlTp)) {
+      } else if (isSingletonDLT(lvlTp) || is2OutOf4DLT(lvlTp)) {
         // Singleton level, fetch coordinates.
         coordinatesBuffers[t][l] =
             genToCoordinates(builder, loc, tensor, l, cooStart);
@@ -540,7 +540,8 @@ void LoopEmitter::categorizeLoopCondition(
     auto lvlType = lvlTypes[t][l];
     // Must be a recognizable DLT.
     assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
-           isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType));
+           isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
+           is2OutOf4DLT(lvlType));
 
     bool isSparse = !isDenseDLT(lvlType);
     bool isSlice = isSparseSlices[t];
@@ -637,6 +638,7 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
     Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
   bool isSparseCond = isCompressedDLT(lvlTypes[tid][lvl]) ||
                       isLooseCompressedDLT(lvlTypes[tid][lvl]) ||
+                      is2OutOf4DLT(lvlTypes[tid][lvl]) ||
                       isSingletonDLT(lvlTypes[tid][lvl]);
   // TODO: support dynamic slices.
   // Uses the first dimension here to build the loop bound (which is also the
@@ -1240,6 +1242,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
 
   const Value c0 = C_IDX(0);
   const Value c1 = C_IDX(1);
+  const Value c2 = C_IDX(2);
   // Either the first level, or the previous level has been set.
   /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
   assert(lvl == 0 || posits[tid][lvl - 1]);
@@ -1248,7 +1251,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
 
     Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
     if (isLooseCompressedDLT(lvlTp))
-      pLo = builder.create<arith::MulIOp>(loc, pLo, C_IDX(2));
+      pLo = builder.create<arith::MulIOp>(loc, pLo, c2);
     posits[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);
 
     const Value pHi = ADDI(pLo, c1);
@@ -1271,7 +1274,13 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                           : ADDI(pLo, c1);
     return;
   }
-
+  if (is2OutOf4DLT(lvlTp)) {
+    const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
+    // Each 2:4 block has exactly two specified elements.
+    posits[tid][lvl] = MULI(pLo, c2);
+    highs[tid][lvl] = ADDI(posits[tid][lvl], c2);
+    return;
+  }
   llvm_unreachable("Unrecognized level-type!");
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 85d6a6ddabf9eb6..dd121cb05c2184d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -816,7 +816,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
       for (LoopId i = 0; i < numLoops; i++) {
         const auto dltI = env.dlt(tid, i);
         if (isCompressedDLT(dltI) || isLooseCompressedDLT(dltI) ||
-            isSingletonDLT(dltI)) {
+            isSingletonDLT(dltI) || is2OutOf4DLT(dltI)) {
           for (LoopId j = 0; j < numLoops; j++)
             if (isUndefDLT(env.dlt(tid, j))) {
               addIterOrdering(i, j, adjM, inDegree);
@@ -1508,7 +1508,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
         assert(ldx == env.merger().loop(b));
         Value clause;
         if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
-            isLooseCompressedDLT(dlt)) {
+            isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt)) {
           assert(lvl.has_value());
           const Value crd = env.emitter().getCoords()[tid][*lvl];
           const Value lvar = env.getLoopVar(ldx);
@@ -1593,7 +1593,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
       needsUniv = true;
     }
     if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
-        isLooseCompressedDLT(dlt) || isIdxReduc) {
+        isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt) || isIdxReduc) {
       // Only when this is a index reduction loop, can the dlt be undefined.
       assert(!isUndefDLT(dlt) || isIdxReduc);
       // sparse/singleton levels, or a dense/sparse index reduction loop.
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 18ebd608608bdcb..033b61fc872a312 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -490,7 +490,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto dlt = getLvlType(b);
       if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
-          !isLooseCompressedDLT(dlt)) {
+          !isLooseCompressedDLT(dlt) && !is2OutOf4DLT(dlt)) {
         if (reset)
           simple.reset(b);
         reset = true;
@@ -671,7 +671,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto dlt = getLvlType(b);
     if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
-        isLooseCompressedDLT(dlt))
+        isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt))
       return true;
   }
   return hasSparseIdxReduction(bits);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 7e9606b1caedee9..0c420f2fd426fb2 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -44,19 +44,41 @@
 #BSR = #sparse_tensor.encoding<{
   map = ( i, j ) ->
   ( i floordiv 2 : dense,
-    j floordiv 3 : compressed,
+    j floordiv 2 : compressed,
     i mod 2      : dense,
-    j mod 3      : dense
+    j mod 2      : dense
   )
 }>
 
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i            : dense,
+    j floordiv 4 : dense,
+    j mod 4      : block2_4
+  ),
+}>
+
 module {
 
-func.func @mul(%arg0: tensor<4x6xf64>,
-               %arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> {
-  %out = tensor.empty() : tensor<4x4xf64>
+func.func @mul(%arg0: tensor<4x8xf64>,
+               %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #BSR>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
+func.func @mul_24(%arg0: tensor<4x8xf64>,
+                  %arg1: tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
   %0 = linalg.generic #trait_mul
-    ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
+    ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #NV_24>)
     outs(%out: tensor<4x4xf64>) {
       ^bb(%x: f64, %y : f64, %z : f64):
         %1 = arith.mulf %x, %y : f64
@@ -66,11 +88,11 @@ func.func @mul(%arg0: tensor<4x6xf64>,
   return %0 : tensor<4x4xf64>
 }
 
-func.func @mul_dense(%arg0: tensor<4x6xf64>,
-                     %arg1: tensor<4x6xf64>) -> tensor<4x4xf64> {
-  %out = tensor.empty() : tensor<4x4xf64>
+func.func @mul_dense(%arg0: tensor<4x8xf64>,
+                     %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
   %0 = linalg.generic #trait_mul
-    ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>)
+    ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64>)
     outs(%out: tensor<4x4xf64>) {
       ^bb(%x: f64, %y : f64, %z : f64):
         %1 = arith.mulf %x, %y : f64
@@ -101,22 +123,26 @@ func.func @mul_dense(%arg0: tensor<4x6xf64>,
     %c2 = arith.constant 2 : index
 
 
-    %td = arith.constant dense<[[ 0.0,  1.0,  2.0,  3.0,  4.0,  5.0],
-                                [ 6.0,  7.0,  8.0,  9.0, 10.0, 11.0],
-                                [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
-                                [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]]> : tensor<4x6xf64>
+    %td = arith.constant dense<[[ 1.0, 2.0,  0.0,  0.0,  0.0,  0.0,  4.0,  5.0],
+                                [ 6.0, 7.0,  0.0,  0.0,  0.0,  0.0, 10.0, 11.0],
+                                [ 0.0, 0.0, 12.0, 13.0, 16.0, 17.0,  0.0,  0.0],
+                                [ 0.0, 0.0, 18.0, 19.0, 22.0, 23.0,  0.0,  0.0]]> : tensor<4x8xf64>
 
 
-    %2 = sparse_tensor.convert %td : tensor<4x6xf64> to tensor<4x6xf64, #BSR>
+    %2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
+    %3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
 
     %d = call @mul_dense(%td, %td)
-         : (tensor<4x6xf64>, tensor<4x6xf64>) -> tensor<4x4xf64>
+         : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
     %s = call @mul(%td, %2)
-         : (tensor<4x6xf64>, tensor<4x6xf64, #BSR>) -> tensor<4x4xf64>
+         : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
+    %s24 = call @mul_24(%td, %3)
+         : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
 
-    // CHECK-COUNT-2: ( ( 55, 145, 235, 325 ), ( 145, 451, 757, 1063 ), ( 235, 757, 1279, 1801 ), ( 325, 1063, 1801, 2539 ) )
+    // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
     call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
     call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
+    call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
 
     return
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/71749


More information about the Mlir-commits mailing list