[Mlir-commits] [mlir] [mlir][sparse] support sparsifying 2:4 block sparsity (PR #71749)
Peiming Liu
llvmlistbot at llvm.org
Thu Nov 9 11:22:27 PST 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/71749
From b6c7492f636aa6d587cf149653ee878649097afb Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 9 Nov 2023 00:13:26 +0000
Subject: [PATCH] [mlir][sparse] support sparsifying 2:4 block sparsity
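
2:4 ("two out of four") structured sparsity, as used by NVIDIA sparse
tensor cores, requires that at most two of every four consecutive
elements along the innermost dimension are nonzero; the format stores
exactly two entries per block. Because the count per block is fixed,
no positions buffer is needed: the block at parent position p always
owns the coordinate and value slots [2*p, 2*p + 2). A minimal C++
sketch of that layout (illustrative only; the struct and names below
are hypothetical, not part of this patch):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical 2:4 storage for one level (names are illustrative, not
// from MLIR): exactly two entries are stored per block of four, each
// paired with its in-block coordinate (0..3).
struct NV24Level {
  std::vector<double> values;   // two entries per block
  std::vector<uint8_t> coords;  // in-block position of each entry
};

// The two entries of block p occupy the half-open range [2p, 2p + 2)
// in both arrays -- the same arithmetic this patch emits.
inline std::pair<std::size_t, std::size_t> blockRange(std::size_t p) {
  return {2 * p, 2 * p + 2};
}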
---
.../mlir/Dialect/SparseTensor/Utils/Merger.h | 3 +-
.../SparseTensor/Transforms/LoopEmitter.cpp | 17 ++++--
.../Transforms/Sparsification.cpp | 6 +-
.../lib/Dialect/SparseTensor/Utils/Merger.cpp | 4 +-
.../SparseTensor/CPU/sparse_block_matmul.mlir | 58 ++++++++++++++-----
5 files changed, 62 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 215920f8b4607b2..cde6b2d13e58217 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -540,7 +540,8 @@ class Merger {
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
if (isLvlWithNonTrivialIdxExp(b)) {
auto dlt = getLoopDependentLevelType(b);
- return isCompressedDLT(dlt) || isSingletonDLT(dlt);
+ return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
}
return false;
}
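
The predicate above now treats every non-dense level format uniformly.
The same four-way disjunction recurs in LoopEmitter, Sparsification,
and Merger below; a hypothetical helper (not introduced by this patch)
would capture it once:

// Hypothetical factoring, not part of this patch: every non-dense
// DimLevelType now participates in sparse iteration the same way, so
// the recurring disjunction could live in one helper.
static bool hasSparseSemantics(DimLevelType dlt) {
  return isCompressedDLT(dlt) || isLooseCompressedDLT(dlt) ||
         isSingletonDLT(dlt) || is2OutOf4DLT(dlt);
}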
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index bb3c6fb56f692d9..6facc87d1b5a029 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -448,7 +448,7 @@ void LoopEmitter::initializeLoopEmit(
positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
coordinatesBuffers[t][l] =
genToCoordinates(builder, loc, tensor, l, cooStart);
- } else if (isSingletonDLT(lvlTp)) {
+ } else if (isSingletonDLT(lvlTp) || is2OutOf4DLT(lvlTp)) {
// Singleton or 2:4 level, fetch coordinates.
coordinatesBuffers[t][l] =
genToCoordinates(builder, loc, tensor, l, cooStart);
@@ -540,7 +540,8 @@ void LoopEmitter::categorizeLoopCondition(
auto lvlType = lvlTypes[t][l];
// Must be a recognizable DLT.
assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
- isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType));
+ isLooseCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
+ is2OutOf4DLT(lvlType));
bool isSparse = !isDenseDLT(lvlType);
bool isSlice = isSparseSlices[t];
@@ -637,6 +638,7 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
bool isSparseCond = isCompressedDLT(lvlTypes[tid][lvl]) ||
isLooseCompressedDLT(lvlTypes[tid][lvl]) ||
+ is2OutOf4DLT(lvlTypes[tid][lvl]) ||
isSingletonDLT(lvlTypes[tid][lvl]);
// TODO: support dynamic slices.
// Uses the first dimension here to build the loop bound (which is also the
@@ -1240,6 +1242,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
const Value c0 = C_IDX(0);
const Value c1 = C_IDX(1);
+ const Value c2 = C_IDX(2);
// Either the first level, or the previous level has been set.
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
assert(lvl == 0 || posits[tid][lvl - 1]);
@@ -1248,7 +1251,7 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
if (isLooseCompressedDLT(lvlTp))
- pLo = builder.create<arith::MulIOp>(loc, pLo, C_IDX(2));
+ pLo = builder.create<arith::MulIOp>(loc, pLo, c2);
posits[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);
const Value pHi = ADDI(pLo, c1);
@@ -1271,7 +1274,13 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
: ADDI(pLo, c1);
return;
}
-
+ if (is2OutOf4DLT(lvlTp)) {
+ const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
+ // Each 2:4 block has exactly two specified elements.
+ posits[tid][lvl] = MULI(pLo, c2);
+ highs[tid][lvl] = ADDI(posits[tid][lvl], c2);
+ return;
+ }
llvm_unreachable("Unrecognized level-type!");
}
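
The interesting part is the new branch in prepareLoopOverTensorAtLvl:
compressed levels must load their bounds from the positions buffer,
whereas a 2:4 level derives both bounds by pure index arithmetic,
which is why initializeLoopEmit only fetches a coordinates buffer for
it. A side-by-side sketch (assumed C++ simplification of the emitted
IR):

#include <cstddef>
#include <utility>
#include <vector>

// Loop bounds over one block at parent position p. The compressed
// path mirrors the existing genIndexLoad calls; the 2:4 path mirrors
// the new MULI/ADDI pair.
std::pair<std::size_t, std::size_t>
levelBounds(bool compressed, const std::vector<std::size_t> &positions,
            std::size_t p) {
  if (compressed)
    return {positions[p], positions[p + 1]}; // two memory loads
  return {2 * p, 2 * p + 2};                 // 2:4: index arithmetic only
}

For example, the block at parent position p = 2 always iterates over
slots 4 and 5, with no loads at all.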
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 85d6a6ddabf9eb6..dd121cb05c2184d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -816,7 +816,7 @@ static bool computeIterationGraph(CodegenEnv &env, SortMask mask,
for (LoopId i = 0; i < numLoops; i++) {
const auto dltI = env.dlt(tid, i);
if (isCompressedDLT(dltI) || isLooseCompressedDLT(dltI) ||
- isSingletonDLT(dltI)) {
+ isSingletonDLT(dltI) || is2OutOf4DLT(dltI)) {
for (LoopId j = 0; j < numLoops; j++)
if (isUndefDLT(env.dlt(tid, j))) {
addIterOrdering(i, j, adjM, inDegree);
@@ -1508,7 +1508,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
assert(ldx == env.merger().loop(b));
Value clause;
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
- isLooseCompressedDLT(dlt)) {
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt)) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoords()[tid][*lvl];
const Value lvar = env.getLoopVar(ldx);
@@ -1593,7 +1593,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
needsUniv = true;
}
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
- isLooseCompressedDLT(dlt) || isIdxReduc) {
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt) || isIdxReduc) {
// Only when this is an index reduction loop can the dlt be undefined.
assert(!isUndefDLT(dlt) || isIdxReduc);
// sparse/singleton levels, or a dense/sparse index reduction loop.
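
All three Sparsification changes are the same predicate extension: 2:4
levels constrain the iteration graph, guard branches, and open loop
sequences exactly like compressed and singleton ones. In genIf, for
instance, the guard for any sparse level compares the stored
coordinate against the current loop index. A sketch (assuming the
clause is an arith::CmpIOp eq on the crd/lvar values fetched above):

#include <cstdint>

// Sketch of the guard genIf builds for a sparse (incl. 2:4) level:
// run the body only when the level's stored coordinate matches the
// loop induction variable.
inline bool sparseClause(int64_t crd, int64_t loopVar) {
  return crd == loopVar;
}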
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 18ebd608608bdcb..033b61fc872a312 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -490,7 +490,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto dlt = getLvlType(b);
if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) &&
- !isLooseCompressedDLT(dlt)) {
+ !isLooseCompressedDLT(dlt) && !is2OutOf4DLT(dlt)) {
if (reset)
simple.reset(b);
reset = true;
@@ -671,7 +671,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto dlt = getLvlType(b);
if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
- isLooseCompressedDLT(dlt))
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt))
return true;
}
return hasSparseIdxReduction(bits);
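
With all predicates extended, the test below exercises a 4x8 matmul in
both BSR and 2:4 form. For the #NV_24 encoding, dims (i, j) map to
levels (i, j floordiv 4, j mod 4); element (1, 6), say, lands at level
coordinates (1, 1, 2). The loop nest sparsification produces has
roughly this shape (C++ rendering, an assumed simplification of the
emitted scf/arith ops):

#include <cstddef>
#include <vector>

// Rough rendering of the loop nest for the 4x8 #NV_24 operand below;
// `values`/`coords` stand in for the buffers initializeLoopEmit
// fetches.
void nv24Traverse(const std::vector<double> &values,
                  const std::vector<std::size_t> &coords /* j mod 4 */) {
  for (std::size_t i = 0; i < 4; i++)         // dense level: i
    for (std::size_t jb = 0; jb < 2; jb++) {  // dense level: j floordiv 4
      std::size_t p = i * 2 + jb;             // linearized parent position
      for (std::size_t k = 2 * p; k < 2 * p + 2; k++) { // 2:4 level
        std::size_t j = jb * 4 + coords[k];   // recover the column index
        (void)j;
        (void)values[k];                      // multiply-accumulate here
      }
    }
}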
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 7e9c0ae71a7b7e6..c2087d626b6a9aa 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -47,19 +47,41 @@
#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
- j floordiv 3 : compressed,
+ j floordiv 2 : compressed,
i mod 2 : dense,
- j mod 3 : dense
+ j mod 2 : dense
)
}>
+#NV_24 = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i : dense,
+ j floordiv 4 : dense,
+ j mod 4 : block2_4
+ ),
+}>
+
module {
-func.func @mul(%arg0: tensor<4x6xf64>,
- %arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> {
+func.func @mul(%arg0: tensor<4x8xf64>,
+ %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
+ %out = arith.constant dense<0.0> : tensor<4x4xf64>
+ %0 = linalg.generic #trait_mul
+ ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #BSR>)
+ outs(%out: tensor<4x4xf64>) {
+ ^bb(%x: f64, %y : f64, %z : f64):
+ %1 = arith.mulf %x, %y : f64
+ %2 = arith.addf %1, %z : f64
+ linalg.yield %2 : f64
+ } -> tensor<4x4xf64>
+ return %0 : tensor<4x4xf64>
+}
+
+func.func @mul_24(%arg0: tensor<4x8xf64>,
+ %arg1: tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64> {
%out = arith.constant dense<0.0> : tensor<4x4xf64>
%0 = linalg.generic #trait_mul
- ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
+ ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #NV_24>)
outs(%out: tensor<4x4xf64>) {
^bb(%x: f64, %y : f64, %z : f64):
%1 = arith.mulf %x, %y : f64
@@ -69,11 +91,11 @@ func.func @mul(%arg0: tensor<4x6xf64>,
return %0 : tensor<4x4xf64>
}
-func.func @mul_dense(%arg0: tensor<4x6xf64>,
- %arg1: tensor<4x6xf64>) -> tensor<4x4xf64> {
+func.func @mul_dense(%arg0: tensor<4x8xf64>,
+ %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
%out = arith.constant dense<0.0> : tensor<4x4xf64>
%0 = linalg.generic #trait_mul
- ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>)
+ ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64>)
outs(%out: tensor<4x4xf64>) {
^bb(%x: f64, %y : f64, %z : f64):
%1 = arith.mulf %x, %y : f64
@@ -104,22 +126,26 @@ func.func @mul_dense(%arg0: tensor<4x6xf64>,
%c2 = arith.constant 2 : index
- %td = arith.constant dense<[[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
- [ 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
- [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
- [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]]> : tensor<4x6xf64>
+ %td = arith.constant dense<[[ 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0],
+ [ 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 10.0, 11.0],
+ [ 0.0, 0.0, 12.0, 13.0, 16.0, 17.0, 0.0, 0.0],
+ [ 0.0, 0.0, 18.0, 19.0, 22.0, 23.0, 0.0, 0.0]]> : tensor<4x8xf64>
- %2 = sparse_tensor.convert %td : tensor<4x6xf64> to tensor<4x6xf64, #BSR>
+ %2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
+ %3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
%d = call @mul_dense(%td, %td)
- : (tensor<4x6xf64>, tensor<4x6xf64>) -> tensor<4x4xf64>
+ : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
%s = call @mul(%td, %2)
- : (tensor<4x6xf64>, tensor<4x6xf64, #BSR>) -> tensor<4x4xf64>
+ : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
+ %s24 = call @mul_24(%td, %3)
+ : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
- // CHECK-COUNT-2: ( ( 55, 145, 235, 325 ), ( 145, 451, 757, 1063 ), ( 235, 757, 1279, 1801 ), ( 325, 1063, 1801, 2539 ) )
+ // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
+ call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
return
}
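
As a spot check of the new CHECK line: the result equals td * td^T
(per #trait_mul, which is outside this hunk), so entry (0, 1) is
1*6 + 2*7 + 4*10 + 5*11 = 6 + 14 + 40 + 55 = 115, matching the second
value of the first row, and entry (2, 3) is
12*18 + 13*19 + 16*22 + 17*23 = 216 + 247 + 352 + 391 = 1206. Note
also that every row of %td has at most two nonzeros per block of four
columns, as #NV_24 requires.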