[Mlir-commits] [mlir] b1d44e5 - [mlir][sparse] add affine subscripts to sparse compilation pass
Aart Bik
llvmlistbot at llvm.org
Wed Sep 15 20:28:17 PDT 2021
Author: Aart Bik
Date: 2021-09-15T20:28:04-07:00
New Revision: b1d44e59020a2a7adfc81bffb07577fd091d0778
URL: https://github.com/llvm/llvm-project/commit/b1d44e59020a2a7adfc81bffb07577fd091d0778
DIFF: https://github.com/llvm/llvm-project/commit/b1d44e59020a2a7adfc81bffb07577fd091d0778.diff
LOG: [mlir][sparse] add affine subscripts to sparse compilation pass
This enables the sparsification of more kernels, such as convolutions
with an x(i+j) subscript. It also enables more tensor invariants such as
x(1) and other affine subscripts such as x(i+1). For now, we still reject
such compound subscripts on annotated (sparse) tensors. Despite this
restriction, however, we can already handle many more kernels with compound
subscripts for dense access (viz. convolution with dense input and sparse
filter). Some unit tests and an integration test demonstrate the new
capability.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D109783
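To make the accepted subscript forms concrete, here is a minimal standalone
C++ sketch (not from the patch; it uses a small mock expression type in place
of mlir::AffineExpr) of the classification rule that the new findAffine helper
below implements: a bare loop index is accepted on any tensor, while sums,
products, and constants are accepted only when the tensor is dense.

// Standalone sketch (mock Expr instead of mlir::AffineExpr) of the
// subscript classification: bare dims are always admissible, compound
// and constant subscripts only on dense tensors.
#include <cstdint>
#include <iostream>
#include <memory>

struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

struct Expr {
  enum Kind { Dim, Add, Mul, Const } kind = Dim;
  unsigned pos = 0;    // Dim: loop index position
  int64_t value = 0;   // Const: constant value
  ExprPtr lhs, rhs;    // Add/Mul: operands
};

static ExprPtr dim(unsigned p) { auto e = std::make_shared<Expr>(); e->kind = Expr::Dim; e->pos = p; return e; }
static ExprPtr cst(int64_t v) { auto e = std::make_shared<Expr>(); e->kind = Expr::Const; e->value = v; return e; }
static ExprPtr add(ExprPtr l, ExprPtr r) { auto e = std::make_shared<Expr>(); e->kind = Expr::Add; e->lhs = l; e->rhs = r; return e; }

// Mirrors the acceptance rule of findAffine (without the merger bookkeeping).
static bool acceptSubscript(const ExprPtr &a, bool isDense) {
  switch (a->kind) {
  case Expr::Dim:
    return true;                                   // x(i): always accepted
  case Expr::Add:
  case Expr::Mul:
    return isDense && acceptSubscript(a->lhs, isDense) &&
           acceptSubscript(a->rhs, isDense);       // x(i+j): dense tensors only
  case Expr::Const:
    return isDense;                                // x(3): dense tensors only
  }
  return false;
}

int main() {
  ExprPtr iPlus1 = add(dim(0), cst(1));            // subscript i + 1
  std::cout << "x(i+1) on a dense tensor:  " << acceptSubscript(iPlus1, true) << "\n";   // prints 1
  std::cout << "x(i+1) on a sparse tensor: " << acceptSubscript(iPlus1, false) << "\n";  // prints 0
  return 0;
}

The real helper additionally records the dimension kind in the merger and
rejects an index that appears in more than one dimension of the same tensor.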
Added:
mlir/test/Dialect/SparseTensor/sparse_affine.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6fb09cabda80b..f733856443a9a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -99,21 +99,53 @@ static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
return Dim::kDense;
}
+/// Helper method to inspect affine expressions. Rejects cases where the
+/// same index is used in more than one dimension of a tensor. Also rejects
+/// affine expressions that are not a direct index for annotated tensors.
+/// TODO: accept more affine cases for sparse tensors
+static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
+ bool isDense) {
+ switch (a.getKind()) {
+ case AffineExprKind::DimId: {
+ unsigned idx = a.cast<AffineDimExpr>().getPosition();
+ if (!merger.isDim(tensor, idx, Dim::kUndef))
+ return false; // used more than once
+ merger.setDim(tensor, idx, dim);
+ return true;
+ }
+ case AffineExprKind::Add:
+ case AffineExprKind::Mul: {
+ if (!isDense)
+ return false;
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
+ findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
+ }
+ case AffineExprKind::Constant:
+ return isDense;
+ default:
+ return false;
+ }
+}
+
/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
+/// Returns true if the sparse annotations and affine subscript
+/// expressions of all tensors are admissable. Returns false if
+/// no annotations are found or inadmissable constructs occur.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
bool annotated = false;
for (OpOperand *t : op.getInputAndOutputOperands()) {
auto map = op.getTiedIndexingMap(t);
- if (!map.isProjectedPermutation())
- return false;
auto enc = getSparseTensorEncoding(t->get().getType());
if (enc)
annotated = true;
assert(map.getNumResults() == op.getRank(t));
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
- unsigned idx = map.getDimPosition(perm(enc, d));
- merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
+ unsigned tensor = t->getOperandNumber();
+ AffineExpr a = map.getResult(perm(enc, d));
+ if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
+ return false; // inadmissable affine expression
}
}
return annotated;
@@ -137,6 +169,32 @@ static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
return true;
}
+/// Helper method to add all constraints from the indices in one affine
+/// expression before all indices in the other affine expression. For
+/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
+static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
+ AffineExpr a, AffineExpr b, unsigned fidx) {
+ switch (a.getKind()) {
+ case AffineExprKind::DimId: {
+ unsigned idx = a.cast<AffineDimExpr>().getPosition();
+ if (b)
+ addAffineOrderings(adjM, b, AffineExpr(), idx);
+ else
+ adjM[fidx][idx] = true;
+ break;
+ }
+ case AffineExprKind::Add:
+ case AffineExprKind::Mul: {
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
+ addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
+ break;
+ }
+ default:
+ break;
+ }
+}
+
/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
@@ -163,9 +221,9 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
// example, the tensor expression A_ijk forces the ordering i < j < k
// on the loop indices if no explicit dimension ordering is given.
for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
- unsigned f = map.getDimPosition(perm(enc, d - 1));
- unsigned t = map.getDimPosition(perm(enc, d));
- adjM[f][t] = true;
+ AffineExpr f = map.getResult(perm(enc, d - 1));
+ AffineExpr t = map.getResult(perm(enc, d));
+ addAffineOrderings(adjM, f, t, 0);
}
// Push unrelated loops into sparse iteration space, so these
// will be skipped more often.
@@ -201,7 +259,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
unsigned tensor = lhs->getOperandNumber();
auto enc = getSparseTensorEncoding(lhs->get().getType());
// An non-annotated output tensor is assumed dense, and becomes a random
- // access n-dim memref. Admissable since inserstions cannot occur.
+ // access n-dim memref. Admissable since insertions cannot occur.
if (!enc)
return true;
// An all-dense annotated "sparse" output tensor becomes a linearized random
@@ -282,7 +340,10 @@ static bool genBuffers(Merger &merger, CodeGen &codegen,
// Scan all dimensions of current tensor.
args.clear();
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
- unsigned idx = map.getDimPosition(perm(enc, d));
+ AffineExpr a = map.getResult(perm(enc, d));
+ if (a.getKind() != AffineExprKind::DimId)
+ continue; // compound
+ unsigned idx = a.cast<AffineDimExpr>().getPosition();
// Handle sparse storage schemes.
if (merger.isDim(tensor, idx, Dim::kSparse)) {
auto dynShape = {ShapedType::kDynamicSize};
@@ -414,6 +475,61 @@ static Value genVectorInvariantValue(CodeGen &codegen,
return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}
+/// Generates an affine expression.
+//
+// TODO: generalize for sparse tensor subscripts
+//
+static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
+ AffineExpr a, Location loc) {
+ switch (a.getKind()) {
+ case AffineExprKind::DimId: {
+ unsigned idx = a.cast<AffineDimExpr>().getPosition();
+ return codegen.loops[idx]; // universal dense index
+ }
+ case AffineExprKind::Add: {
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ return rewriter.create<AddIOp>(
+ loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
+ genAffine(codegen, rewriter, binOp.getRHS(), loc));
+ }
+ case AffineExprKind::Mul: {
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ return rewriter.create<MulIOp>(
+ loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
+ genAffine(codegen, rewriter, binOp.getRHS(), loc));
+ }
+ case AffineExprKind::Constant: {
+ int64_t c = a.cast<AffineConstantExpr>().getValue();
+ return rewriter.create<ConstantIndexOp>(loc, c);
+ }
+ default:
+ llvm_unreachable("unexpected affine subscript");
+ }
+}
+
+/// Generates subscript for load/store on a dense or sparse tensor.
+static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
+ linalg::GenericOp op, OpOperand *t,
+ SmallVector<Value, 4> &args) {
+ unsigned tensor = t->getOperandNumber();
+ auto map = op.getTiedIndexingMap(t);
+ auto enc = getSparseTensorEncoding(t->get().getType());
+ unsigned rank = map.getNumResults();
+ if (enc) {
+ // Note that currently, all sparse subscripts are simple.
+ // TODO: accept affine too?
+ unsigned idx = map.getDimPosition(perm(enc, rank - 1));
+ assert(codegen.pidxs[tensor][idx] != nullptr);
+ args.push_back(codegen.pidxs[tensor][idx]); // position index
+ } else {
+ for (unsigned d = 0; d < rank; d++) {
+ AffineExpr a = map.getResult(perm(enc, d));
+ args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
+ }
+ }
+ return codegen.buffers[tensor];
+}
+
/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
@@ -428,62 +544,32 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen,
// Actual load.
SmallVector<Value, 4> args;
OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
- unsigned tensor = t->getOperandNumber();
- auto map = op.getTiedIndexingMap(t);
- auto enc = getSparseTensorEncoding(t->get().getType());
- unsigned rank = map.getNumResults();
- if (enc) {
- unsigned idx = map.getDimPosition(perm(enc, rank - 1));
- assert(codegen.pidxs[tensor][idx] != nullptr);
- args.push_back(codegen.pidxs[tensor][idx]); // position index
- } else {
- for (unsigned d = 0; d < rank; d++) {
- unsigned idx = map.getDimPosition(d);
- args.push_back(codegen.loops[idx]); // universal dense index
- }
- }
- Location loc = op.getLoc();
- Value ptr = codegen.buffers[tensor];
+ Value ptr = genSubscript(codegen, rewriter, op, t, args);
if (codegen.curVecLength > 1)
return genVectorLoad(codegen, rewriter, ptr, args);
- return rewriter.create<memref::LoadOp>(loc, ptr, args);
+ return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}
/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
- OpOperand *t, Value rhs) {
- Location loc = op.getLoc();
+ Value rhs) {
// Test if this is a scalarized reduction.
- OpOperand *lhs = op.getOutputOperand(0);
- if (lhs == t && codegen.redVal) {
+ if (codegen.redVal) {
if (codegen.curVecLength > 1)
- rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
+ rhs = rewriter.create<SelectOp>(op.getLoc(), codegen.curVecMask, rhs,
codegen.redVal);
codegen.redVal = rhs;
return;
}
// Actual store.
SmallVector<Value, 4> args;
- unsigned tensor = t->getOperandNumber();
- auto map = op.getTiedIndexingMap(t);
- auto enc = getSparseTensorEncoding(t->get().getType());
- unsigned rank = map.getNumResults();
- if (enc) {
- unsigned idx = map.getDimPosition(perm(enc, rank - 1));
- assert(codegen.pidxs[tensor][idx] != nullptr);
- args.push_back(codegen.pidxs[tensor][idx]); // position index
- } else {
- for (unsigned d = 0; d < rank; d++) {
- unsigned idx = map.getDimPosition(d);
- args.push_back(codegen.loops[idx]); // universal dense index
- }
- }
- Value ptr = codegen.buffers[tensor];
+ OpOperand *t = op.getOutputOperand(0);
+ Value ptr = genSubscript(codegen, rewriter, op, t, args);
if (codegen.curVecLength > 1)
genVectorStore(codegen, rewriter, rhs, ptr, args);
else
- rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
+ rewriter.create<memref::StoreOp>(op.getLoc(), rhs, ptr, args);
}
/// Generates a pointer/index load from the sparse storage scheme. Narrower
@@ -575,7 +661,6 @@ static void genReductionEnd(Merger &merger, CodeGen &codegen,
return;
assert(codegen.curVecLength == 1);
codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
- OpOperand *lhs = op.getOutputOperand(0);
if (auto vtp = red.getType().dyn_cast<VectorType>()) {
// TODO: assumes + reductions for now
StringAttr kind = rewriter.getStringAttr("add");
@@ -590,7 +675,7 @@ static void genReductionEnd(Merger &merger, CodeGen &codegen,
kind, red, ld);
}
}
- genTensorStore(merger, codegen, rewriter, op, lhs, red);
+ genTensorStore(merger, codegen, rewriter, op, red);
}
/// Recursively generates tensor expression.
@@ -616,6 +701,27 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
return merger.buildExp(rewriter, loc, exp, v0, v1);
}
+/// Determines if affine expression is invariant.
+static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
+ unsigned ldx, bool &atLevel) {
+ switch (a.getKind()) {
+ case AffineExprKind::DimId: {
+ unsigned idx = a.cast<AffineDimExpr>().getPosition();
+ if (idx == ldx)
+ atLevel = true;
+ return codegen.loops[idx] != nullptr; // no longer in play?
+ }
+ case AffineExprKind::Add:
+ case AffineExprKind::Mul: {
+ auto binOp = a.cast<AffineBinaryOpExpr>();
+ return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
+ isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
+ }
+ default:
+ return true;
+ }
+}
+
/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
PatternRewriter &rewriter, linalg::GenericOp op,
@@ -629,11 +735,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
auto map = op.getTiedIndexingMap(t);
auto enc = getSparseTensorEncoding(t->get().getType());
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
- unsigned idx = map.getDimPosition(perm(enc, d));
- if (!codegen.loops[idx])
+ AffineExpr a = map.getResult(perm(enc, d));
+ if (!isInvariantAffine(codegen, a, ldx, atLevel))
return; // still in play
- else if (idx == ldx)
- atLevel = true;
}
// All exhausted at this level (atLevel denotes exactly at this level).
OpOperand *lhs = op.getOutputOperand(0);
@@ -736,12 +840,16 @@ static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
- unsigned idx) {
+ unsigned ldx) {
for (OpOperand *t : op.getInputAndOutputOperands()) {
if (!getSparseTensorEncoding(t->get().getType())) {
auto map = op.getTiedIndexingMap(t);
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
- if (map.getDimPosition(d) == idx && d != rank - 1)
+ AffineExpr a = map.getResult(d);
+ if (a.getKind() != AffineExprKind::DimId)
+ return false; // very conservative
+ unsigned idx = a.cast<AffineDimExpr>().getPosition();
+ if (idx == ldx && d != rank - 1)
return false;
}
}
@@ -1004,9 +1112,8 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
unsigned exp, unsigned at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
if (at == topSort.size()) {
- OpOperand *lhs = op.getOutputOperand(0);
Value rhs = genExp(merger, codegen, rewriter, op, exp);
- genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
+ genTensorStore(merger, codegen, rewriter, op, rhs);
return;
}
assert(codegen.curVecLength == 1);
@@ -1107,7 +1214,10 @@ static void genResult(Merger &merger, CodeGen &codegen,
// (even though lowering should never need this eventually).
SmallVector<Value, 4> args;
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
- unsigned idx = map.getDimPosition(perm(enc, d));
+ AffineExpr a = map.getResult(perm(enc, d));
+ if (a.getKind() != AffineExprKind::DimId)
+ continue; // compound
+ unsigned idx = a.cast<AffineDimExpr>().getPosition();
if (merger.isDim(tensor, idx, Dim::kSparse)) {
args.push_back(codegen.pointers[tensor][idx]);
args.push_back(codegen.indices[tensor][idx]);
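The iteration-graph change above replaces the single f < t edge between
consecutive subscripts by pairwise edges over all loop indices that occur on
either side (addAffineOrderings). The following standalone C++ sketch (not
from the patch) reduces each subscript to the set of loop indices it
references, which is all the ordering needs since constants drop out, and
reproduces the doc-comment example i0+i1 < i2+i3+1.

// Standalone sketch of the pairwise ordering constraints added by
// addAffineOrderings: every loop index in the earlier subscript must
// precede every loop index in the later subscript.
#include <iostream>
#include <vector>

int main() {
  const unsigned numLoops = 4;
  std::vector<std::vector<bool>> adjM(numLoops, std::vector<bool>(numLoops, false));

  std::vector<unsigned> earlier = {0, 1}; // loop indices in i0+i1
  std::vector<unsigned> later = {2, 3};   // loop indices in i2+i3+1 (the +1 drops out)

  for (unsigned f : earlier)
    for (unsigned t : later)
      adjM[f][t] = true; // adds the constraint f < t

  for (unsigned f = 0; f < numLoops; f++)
    for (unsigned t = 0; t < numLoops; t++)
      if (adjM[f][t])
        std::cout << "i" << f << " < i" << t << "\n";
  return 0;
}

This prints i0 < i2, i0 < i3, i1 < i2, and i1 < i3, exactly the constraints
listed in the comment.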
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
new file mode 100644
index 0000000000000..097738d581a9f
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -0,0 +1,166 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SpVec = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+#trait1 = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a
+ affine_map<(i) -> (3)>, // b
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) += a(i) * b(3)"
+}
+
+// CHECK-LABEL: func @mul_inv_dense1d(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-DAG: %[[VAL_3:.*]] = constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = constant 3 : index
+// CHECK-DAG: %[[VAL_5:.*]] = constant 1 : index
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<4xf32>
+// CHECK: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf32>
+// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<32xf32>
+// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xf32> to memref<32xf32>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<4xf32>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_5]] {
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_16]]] : memref<32xf32>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf32>
+// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_12]] : f32
+// CHECK: %[[VAL_20:.*]] = addf %[[VAL_17]], %[[VAL_19]] : f32
+// CHECK: memref.store %[[VAL_20]], %[[VAL_11]]{{\[}}%[[VAL_16]]] : memref<32xf32>
+// CHECK: }
+// CHECK: %[[VAL_21:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf32>
+// CHECK: return %[[VAL_21]] : tensor<32xf32>
+// CHECK: }
+func @mul_inv_dense1d(%arga: tensor<32xf32, #SpVec>,
+ %argb: tensor<4xf32>,
+ %argx: tensor<32xf32>) -> tensor<32xf32> {
+ %0 = linalg.generic #trait1
+ ins(%arga, %argb: tensor<32xf32, #SpVec>, tensor<4xf32>)
+ outs(%argx: tensor<32xf32>) {
+ ^bb(%a: f32, %b: f32, %x: f32):
+ %0 = mulf %a, %b : f32
+ %1 = addf %x, %0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<32xf32>
+ return %0 : tensor<32xf32>
+}
+
+#trait2 = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a
+ affine_map<(i) -> (i+2)>, // b
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) = a(i) & b(i+2)"
+}
+
+// CHECK-LABEL: func @and_affine_dense1d(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<34xi32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xi32>) -> tensor<32xi32> {
+// CHECK-DAG: %[[VAL_3:.*]] = constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = constant 2 : index
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<34xi32>
+// CHECK: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi32>
+// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<32xi32>
+// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<32xi32> to memref<32xi32>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_4]] {
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xi32>
+// CHECK: %[[VAL_17:.*]] = addi %[[VAL_15]], %[[VAL_5]] : index
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<34xi32>
+// CHECK: %[[VAL_19:.*]] = and %[[VAL_16]], %[[VAL_18]] : i32
+// CHECK: memref.store %[[VAL_19]], %[[VAL_11]]{{\[}}%[[VAL_15]]] : memref<32xi32>
+// CHECK: }
+// CHECK: %[[VAL_20:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xi32>
+// CHECK: return %[[VAL_20]] : tensor<32xi32>
+// CHECK: }
+func @and_affine_dense1d(%arga: tensor<32xi32, #SpVec>,
+ %argb: tensor<34xi32>,
+ %argx: tensor<32xi32>) -> tensor<32xi32> {
+ %0 = linalg.generic #trait2
+ ins(%arga, %argb: tensor<32xi32, #SpVec>, tensor<34xi32>)
+ outs(%argx: tensor<32xi32>) {
+ ^bb(%a: i32, %b: i32, %x: i32):
+ %0 = and %a, %b : i32
+ linalg.yield %0 : i32
+ } -> tensor<32xi32>
+ return %0 : tensor<32xi32>
+}
+
+#trait3 = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // a
+ affine_map<(i,j) -> (i+2,j+3)>, // b
+ affine_map<(i,j) -> (i,j)> // x (out)
+ ],
+ iterator_types = ["parallel","parallel"],
+ doc = "x(i,j) += a(i,j) * b(i+2,j+3)"
+}
+
+// CHECK-LABEL: func @mul_affine_dense2d(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<34x19xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
+// CHECK-DAG: %[[VAL_3:.*]] = constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = constant 32 : index
+// CHECK-DAG: %[[VAL_5:.*]] = constant 0 : index
+// CHECK-DAG: %[[VAL_6:.*]] = constant 2 : index
+// CHECK-DAG: %[[VAL_7:.*]] = constant 3 : index
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<34x19xf64>
+// CHECK: %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32x16xf64>
+// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<32x16xf64>
+// CHECK: memref.copy %[[VAL_12]], %[[VAL_13]] : memref<32x16xf64> to memref<32x16xf64>
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_3]] {
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = addi %[[VAL_14]], %[[VAL_3]] : index
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_3]] {
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_19]]] : memref<32x16xf64>
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK: %[[VAL_22:.*]] = addi %[[VAL_14]], %[[VAL_6]] : index
+// CHECK: %[[VAL_23:.*]] = addi %[[VAL_19]], %[[VAL_7]] : index
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]], %[[VAL_23]]] : memref<34x19xf64>
+// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_21]], %[[VAL_24]] : f64
+// CHECK: %[[VAL_26:.*]] = addf %[[VAL_20]], %[[VAL_25]] : f64
+// CHECK: memref.store %[[VAL_26]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_19]]] : memref<32x16xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_27:.*]] = memref.tensor_load %[[VAL_13]] : memref<32x16xf64>
+// CHECK: return %[[VAL_27]] : tensor<32x16xf64>
+// CHECK: }
+func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
+ %argb: tensor<34x19xf64>,
+ %argx: tensor<32x16xf64>) -> tensor<32x16xf64> {
+ %0 = linalg.generic #trait3
+ ins(%arga, %argb: tensor<32x16xf64, #CSR>, tensor<34x19xf64>)
+ outs(%argx: tensor<32x16xf64>) {
+ ^bb(%a: f64, %b: f64, %x: f64):
+ %0 = mulf %a, %b : f64
+ %1 = addf %x, %0 : f64
+ linalg.yield %1 : f64
+ } -> tensor<32x16xf64>
+ return %0 : tensor<32x16xf64>
+}
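In the and_affine_dense1d and mul_affine_dense2d tests above, the compound
subscripts on the dense operands show up as explicit addi instructions in the
CHECK lines; that arithmetic is produced by the new genAffine helper. Below is
a standalone C++ sketch of that recursive lowering, with a mock expression
type in place of mlir::AffineExpr and plain strings standing in for SSA values.

// Standalone sketch of genAffine-style lowering: a compound affine
// subscript is recursively turned into explicit index arithmetic;
// strings stand in for SSA values, and Expr mocks mlir::AffineExpr.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>

struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

struct Expr {
  enum Kind { Dim, Const, Add, Mul } kind = Dim;
  unsigned pos = 0;    // Dim: loop index
  int64_t value = 0;   // Const
  ExprPtr lhs, rhs;    // Add/Mul
};

static ExprPtr dim(unsigned p) { auto e = std::make_shared<Expr>(); e->kind = Expr::Dim; e->pos = p; return e; }
static ExprPtr cst(int64_t v) { auto e = std::make_shared<Expr>(); e->kind = Expr::Const; e->value = v; return e; }
static ExprPtr add(ExprPtr l, ExprPtr r) { auto e = std::make_shared<Expr>(); e->kind = Expr::Add; e->lhs = l; e->rhs = r; return e; }

// Emits one add/mul per binary node and returns the "SSA value" holding
// the subscript, much like the recursion in genAffine.
static std::string lower(const ExprPtr &a, int &temp) {
  switch (a->kind) {
  case Expr::Dim:
    return "%i" + std::to_string(a->pos);     // current loop index
  case Expr::Const:
    return "%c" + std::to_string(a->value);   // constant index
  case Expr::Add:
  case Expr::Mul: {
    std::string l = lower(a->lhs, temp), r = lower(a->rhs, temp);
    std::string res = "%t" + std::to_string(temp++);
    std::cout << res << " = " << (a->kind == Expr::Add ? "addi " : "muli ")
              << l << ", " << r << " : index\n";
    return res;
  }
  }
  return "";
}

int main() {
  int temp = 0;
  ExprPtr iPlus2 = add(dim(0), cst(2));       // the b(i+2) subscript above
  std::cout << "subscript: " << lower(iPlus2, temp) << "\n";
  return 0;
}

For i+2 this emits a single addi of the index value and the constant 2, which
matches the addi operations on the subscripts in the CHECK lines above.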
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
new file mode 100644
index 0000000000000..42a6068644d9f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_filter_conv2d.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s \
+// RUN: --linalg-generalize-named-ops \
+// RUN: --sparsification --sparse-tensor-conversion \
+// RUN: --convert-vector-to-scf --convert-scf-to-std \
+// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN: --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN: --convert-vector-to-llvm --convert-memref-to-llvm \
+// RUN: --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+//
+// Do the same run, but now with SIMDization as well. This should not change the outcome.
+//
+// RUN: mlir-opt %s \
+// RUN: --linalg-generalize-named-ops \
+// RUN: --sparsification="vectorization-strategy=2 vl=2" --sparse-tensor-conversion \
+// RUN: --convert-vector-to-scf --convert-scf-to-std \
+// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN: --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN: --convert-vector-to-llvm --convert-memref-to-llvm \
+// RUN: --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// An example of a 2D convolution with a sparse filter.
+module {
+
+ func @conv2d(%input: tensor<8x8xi32>,
+ %filter: tensor<3x3xi32, #DCSR>,
+ %output: tensor<6x6xi32>) -> tensor<6x6xi32> {
+ %0 = linalg.conv_2d
+ ins (%input, %filter: tensor<8x8xi32>, tensor<3x3xi32, #DCSR>)
+ outs (%output: tensor<6x6xi32>) -> tensor<6x6xi32>
+ return %0 : tensor<6x6xi32>
+ }
+
+ func @entry() {
+ %c0 = constant 0 : index
+ %i0 = constant 0 : i32
+
+ // A typical edge detection filter.
+ %filter = constant dense<[
+ [ 1, 0, -1 ],
+ [ 0, 0, 0 ],
+ [ -1, 0, 1 ]
+ ]> : tensor<3x3xi32>
+ %sparse_filter = sparse_tensor.convert %filter
+ : tensor<3x3xi32> to tensor<3x3xi32, #DCSR>
+
+ %input = constant dense<[
+ [ 1, 2, 3, 4, 0, 6, 7, 8 ],
+ [ 2, 2, 4, 4, 0, 0, 6, 8 ],
+ [ 2, 2, 4, 4, 0, 0, 6, 8 ],
+ [ 2, 2, 3, 4, 0, 0, 7, 8 ],
+ [ 1, 3, 3, 4, 0, 0, 6, 8 ],
+ [ 3, 2, 3, 4, 0, 0, 7, 8 ],
+ [ 1, 3, 3, 4, 3, 6, 6, 8 ],
+ [ 1, 3, 3, 4, 3, 0, 7, 8 ]
+ ]> : tensor<8x8xi32>
+
+ // Call the kernel.
+ %output = constant dense<0> : tensor<6x6xi32>
+ %0 = call @conv2d(%input, %sparse_filter, %output)
+ : (tensor<8x8xi32>,
+ tensor<3x3xi32, #DCSR>, tensor<6x6xi32>) -> tensor<6x6xi32>
+
+ // Verify the output.
+ //
+ // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
+ // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
+ // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
+ // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
+ // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
+ // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
+ //
+ %m = memref.buffer_cast %0 : memref<6x6xi32>
+ %v = vector.transfer_read %m[%c0, %c0], %i0
+ : memref<6x6xi32>, vector<6x6xi32>
+ vector.print %v : vector<6x6xi32>
+
+ return
+ }
+}
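As a sanity check on the expected values in the CHECK lines, the following
standalone C++ program (not part of the commit) recomputes the 6x6 output,
assuming linalg.conv_2d performs a plain 2-D correlation of the 8x8 input with
the 3x3 filter (no filter flip).

// Standalone check that recomputes the expected output above, assuming
//   out(i,j) = sum_{k,l} input(i+k, j+l) * filter(k,l)
#include <cstdio>

int main() {
  const int input[8][8] = {
      {1, 2, 3, 4, 0, 6, 7, 8}, {2, 2, 4, 4, 0, 0, 6, 8},
      {2, 2, 4, 4, 0, 0, 6, 8}, {2, 2, 3, 4, 0, 0, 7, 8},
      {1, 3, 3, 4, 0, 0, 6, 8}, {3, 2, 3, 4, 0, 0, 7, 8},
      {1, 3, 3, 4, 3, 6, 6, 8}, {1, 3, 3, 4, 3, 0, 7, 8}};
  const int filter[3][3] = {{1, 0, -1}, {0, 0, 0}, {-1, 0, 1}};

  for (int i = 0; i < 6; i++) {
    for (int j = 0; j < 6; j++) {
      int acc = 0;
      for (int k = 0; k < 3; k++)
        for (int l = 0; l < 3; l++)
          acc += input[i + k][j + l] * filter[k][l];
      std::printf("%d ", acc); // should reproduce the CHECK matrix row by row
    }
    std::printf("\n");
  }
  return 0;
}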