[Mlir-commits] [mlir] c8bb235 - [mlir][sparse] Custom reduce with identity
Jim Kitchen
llvmlistbot at llvm.org
Wed Aug 17 09:22:31 PDT 2022
Author: Jim Kitchen
Date: 2022-08-17T11:21:46-05:00
New Revision: c8bb23547f2138beb5997caaaf1f4be46bfc30a3
URL: https://github.com/llvm/llvm-project/commit/c8bb23547f2138beb5997caaaf1f4be46bfc30a3
DIFF: https://github.com/llvm/llvm-project/commit/c8bb23547f2138beb5997caaaf1f4be46bfc30a3.diff
LOG: [mlir][sparse] Custom reduce with identity
Implement the new sparse_tensor.reduce operation, which accepts a
starting identity value and a code block describing how two values
are combined during the reduction.
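
For context, the snippet below is a minimal usage sketch taken almost verbatim
from the integration test added by this commit (the #CSR, #SparseVector, and
#trait_mat_reduce_rowwise attributes are the ones defined in that test). The
identity value is passed as the third operand, and the region describes how two
values are combined:

  %cf1 = arith.constant 1.0 : f64
  %0 = linalg.generic #trait_mat_reduce_rowwise
    ins(%arga: tensor<?x?xf64, #CSR>)
    outs(%xv: tensor<?xf64, #SparseVector>) {
    ^bb(%a: f64, %b: f64):
      // Row-wise product; %cf1 (1.0) is the reduction identity.
      %1 = sparse_tensor.reduce %a, %b, %cf1 : f64 {
        ^bb0(%x: f64, %y: f64):
          %2 = arith.mulf %x, %y : f64
          sparse_tensor.yield %2 : f64
      }
      linalg.yield %1 : f64
  } -> tensor<?xf64, #SparseVector>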
Reviewed by: aartbik
Differential Revision: https://reviews.llvm.org/D130573
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 0e3f5d0c03080..ea00b958f1b95 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -84,6 +84,7 @@ enum Kind {
kShrU, // unsigned
kShlI,
kBinary, // semiring binary op
+ kReduce, // semiring reduction op
};
/// Children subexpressions of tensor operations.
@@ -115,8 +116,8 @@ struct TensorExp {
/// this field may be used to cache "hoisted" loop invariant tensor loads.
Value val;
- /// Code blocks used by semirings. For the case of kUnary and
- /// kBinary, this holds the original operation with all regions. For
+ /// Code blocks used by semirings. For the case of kUnary, kBinary, and
+ /// kReduce, this holds the original operation with all regions. For
/// kBinaryBranch, this holds the YieldOp for the left or right half
/// to be merged into a nested scf loop.
Operation *op;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e5714c6a5cdf6..8c56e43433533 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -50,7 +50,7 @@ enum SortMask {
};
// Reduction kinds.
-enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
+enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
// Code generation.
struct CodeGen {
@@ -87,6 +87,7 @@ struct CodeGen {
unsigned redExp = -1u;
Value redVal;
Reduction redKind = kNoReduc;
+ unsigned redCustom = -1u;
// Sparse tensor as output. Implemented either through direct injective
// insertion in lexicographic index order (where indices are updated
// in the temporary array `lexIdx`) or through access pattern expansion
@@ -373,6 +374,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
static vector::CombiningKind getCombiningKind(Reduction kind) {
switch (kind) {
case kNoReduc:
+ case kCustom:
break;
case kSum:
return vector::CombiningKind::ADD;
@@ -408,6 +410,8 @@ static Reduction getReduction(Kind kind) {
return kOr;
case Kind::kXorI:
return kXor;
+ case Kind::kReduce:
+ return kCustom;
default:
llvm_unreachable("unexpected reduction operator");
}
@@ -422,6 +426,7 @@ static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder,
Value r = codegen.redVal;
switch (codegen.redKind) {
case kNoReduc:
+ case kCustom:
break;
case kSum:
case kXor:
@@ -454,6 +459,11 @@ static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}
+/// Extracts identity from custom reduce.
+static Value getCustomRedId(Operation *op) {
+ return dyn_cast<sparse_tensor::ReduceOp>(op).getIdentity();
+}
+
//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//
@@ -726,6 +736,25 @@ static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder,
return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
}
+/// Generates insertion code to implement dynamic tensor load for reduction.
+static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen,
+ OpBuilder &builder, linalg::GenericOp op,
+ OpOperand *t) {
+ Location loc = op.getLoc();
+ Value identity = getCustomRedId(merger.exp(codegen.redCustom).op);
+ // Direct lexicographic index order, tensor loads as identity.
+ if (!codegen.expValues) {
+ return identity;
+ }
+ // Load from expanded access pattern if filled, identity otherwise.
+ Value index = genIndex(codegen, op, t);
+ Value isFilled =
+ builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
+ Value valAtIndex =
+ builder.create<memref::LoadOp>(loc, codegen.expValues, index);
+ return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
+}
+
/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
linalg::GenericOp op, OpOperand *t, Value rhs) {
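
As a rough sketch of what genInsertionLoadReduce emits for the expanded access
pattern (SSA names and buffer types here are illustrative, not the compiler's
actual value names): the filled flag selects between the stored value and the
reduce op's identity:

  %isFilled   = memref.load %expFilled[%index] : memref<?xi1>
  %valAtIndex = memref.load %expValues[%index] : memref<?xf64>
  %v          = arith.select %isFilled, %valAtIndex, %identity : f64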
@@ -780,8 +809,11 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
}
// Load during insertion.
OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
- if (t == codegen.sparseOut)
+ if (t == codegen.sparseOut) {
+ if (codegen.redCustom != -1u)
+ return genInsertionLoadReduce(merger, codegen, builder, op, t);
return genInsertionLoad(codegen, builder, op, t);
+ }
// Actual load.
SmallVector<Value, 4> args;
Value ptr = genSubscript(codegen, builder, op, t, args);
@@ -953,6 +985,11 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
return genInvariantValue(merger, codegen, rewriter, exp);
if (merger.exp(exp).kind == Kind::kIndex)
return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
+ if (merger.exp(exp).kind == Kind::kReduce) {
+ // Make custom reduction identity accessible for expanded access pattern.
+ assert(codegen.redCustom == -1u);
+ codegen.redCustom = exp;
+ }
Value v0 =
genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
Value v1 =
@@ -960,8 +997,11 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
if (ee && (merger.exp(exp).kind == Kind::kUnary ||
merger.exp(exp).kind == Kind::kBinary ||
- merger.exp(exp).kind == Kind::kBinaryBranch))
+ merger.exp(exp).kind == Kind::kBinaryBranch ||
+ merger.exp(exp).kind == Kind::kReduce))
ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
+ if (merger.exp(exp).kind == Kind::kReduce)
+ codegen.redCustom = -1u;
return ee;
}
@@ -989,7 +1029,7 @@ static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
linalg::GenericOp op, unsigned exp, unsigned ldx,
- bool atStart, Kind last = Kind::kTensor) {
+ bool atStart, unsigned last = 0) {
if (exp == -1u)
return;
if (merger.exp(exp).kind == Kind::kTensor) {
@@ -1010,8 +1050,11 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
if (lhs == t) {
// Start or end a scalarized reduction
if (atStart) {
- Value load = genTensorLoad(merger, codegen, builder, op, exp);
- codegen.redKind = getReduction(last);
+ Kind kind = merger.exp(last).kind;
+ Value load = kind == Kind::kReduce
+ ? getCustomRedId(merger.exp(last).op)
+ : genTensorLoad(merger, codegen, builder, op, exp);
+ codegen.redKind = getReduction(kind);
codegen.redExp = exp;
updateReduc(merger, codegen, load);
} else {
@@ -1031,11 +1074,10 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
// Traverse into the binary operations. Note that we only hoist
// tensor loads, since subsequent MLIR/LLVM passes know how to
// deal with all other kinds of derived loop invariants.
- Kind last = merger.exp(exp).kind;
unsigned e0 = merger.exp(exp).children.e0;
unsigned e1 = merger.exp(exp).children.e1;
- genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last);
- genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last);
+ genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp);
+ genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp);
}
}
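
The genInvariants change above seeds a scalarized custom reduction with the
reduce op's identity rather than a load from the output tensor. Conceptually,
and ignoring the sparse iteration machinery, the resulting loop resembles the
sketch below (illustrative names, not the literal sparse-compiler output):

  %red = scf.for %k = %lo to %hi step %c1 iter_args(%acc = %identity) -> (f64) {
    %a = memref.load %values[%k] : memref<?xf64>
    // Body of the sparse_tensor.reduce region, inlined by buildExp().
    %new = arith.mulf %acc, %a : f64
    scf.yield %new : f64
  }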
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index a5e7a37776c95..eeaaa2e8e2e9d 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -114,6 +114,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
children.e1 = y;
break;
case kBinary:
+ case kReduce:
assert(x != -1u && y != -1u && !v && o);
children.e0 = x;
children.e1 = y;
@@ -376,6 +377,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kOrI:
case kXorI:
case kBinary:
+ case kReduce:
return false;
}
llvm_unreachable("unexpected kind");
@@ -476,6 +478,8 @@ static const char *kindToOpSymbol(Kind kind) {
return "<<";
case kBinary:
return "binary";
+ case kReduce:
+ return "reduce";
}
llvm_unreachable("unexpected kind for symbol");
}
@@ -554,6 +558,7 @@ void Merger::dumpExp(unsigned e) const {
case kShrU:
case kShlI:
case kBinary:
+ case kReduce:
llvm::dbgs() << "(";
dumpExp(tensorExps[e].children.e0);
llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -794,6 +799,11 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
kBinaryBranch, leftYield, includeRight, kBinaryBranch,
rightYield);
}
+ case kReduce:
+ // A custom reduce operation.
+ return takeConj(kind, buildLattices(tensorExps[e].children.e0, i),
+ buildLattices(tensorExps[e].children.e1, i),
+ tensorExps[e].op);
}
llvm_unreachable("unexpected expression kind");
}
@@ -965,7 +975,7 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
// Construct binary operations if subexpressions can be built.
// See buildLattices() for an explanation of rejecting certain
- // division and shift operations
+ // division and shift operations.
if (def->getNumOperands() == 2) {
auto x = buildTensorExp(op, def->getOperand(0));
auto y = buildTensorExp(op, def->getOperand(1));
@@ -1020,6 +1030,21 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
}
}
}
+ // Construct ternary operations if subexpressions can be built.
+ if (def->getNumOperands() == 3) {
+ auto x = buildTensorExp(op, def->getOperand(0));
+ auto y = buildTensorExp(op, def->getOperand(1));
+ auto z = buildTensorExp(op, def->getOperand(2));
+ if (x.has_value() && y.has_value() && z.has_value()) {
+ unsigned e0 = x.value();
+ unsigned e1 = y.value();
+ // unsigned e2 = z.getValue();
+ if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
+ if (isAdmissableBranch(redop, redop.getRegion()))
+ return addExp(kReduce, e0, e1, Value(), def);
+ }
+ }
+ }
// Cannot build.
return None;
}
@@ -1199,6 +1224,10 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
case kBinary:
return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
+ case kReduce: {
+ ReduceOp redOp = cast<ReduceOp>(tensorExps[e].op);
+ return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
+ }
}
llvm_unreachable("unexpected expression kind in build");
}
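
Two details of the Merger changes above are worth noting. First, buildLattices()
maps kReduce onto takeConj(), so the custom reduction only fires where both
subexpressions have stored values; elsewhere the dynamic loads fall back to the
identity (see genInsertionLoadReduce above). Second, buildExp() lowers kReduce
by splicing the reduce region's body in at the insertion point via
insertYieldOp(), so a region that multiplies its block arguments effectively
becomes a single arithmetic op on the two reduction operands, roughly:

  // ^bb0(%x: f64, %y: f64): yield mulf(%x, %y)  ==>  with %x := %v0, %y := %v1:
  %ee = arith.mulf %v0, %v1 : f64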
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
new file mode 100644
index 0000000000000..443e597292dd3
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
@@ -0,0 +1,234 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#CSC = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed" ],
+ dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Traits for tensor operations.
+//
+#trait_matmul = {
+ indexing_maps = [
+ affine_map<(i,j,k) -> (i,k)>, // A
+ affine_map<(i,j,k) -> (k,j)>, // B
+ affine_map<(i,j,k) -> (i,j)> // C (out)
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ doc = "C(i,j) = SUM_k A(i,k) * B(k,j)"
+}
+
+#trait_mat_reduce_rowwise = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A (in)
+ affine_map<(i,j) -> (i)> // X (out)
+ ],
+ iterator_types = ["parallel", "reduction"],
+ doc = "X(i) = PROD_j A(i,j)"
+}
+
+#trait_mat_reduce_colwise = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A (in)
+ affine_map<(i,j) -> (j)> // X (out)
+ ],
+ iterator_types = ["reduction", "parallel"],
+ doc = "X(j) = PROD_i A(i,j)"
+}
+
+module {
+ func.func @redProdLex(%arga: tensor<?x?xf64, #CSR>) -> tensor<?xf64, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %cf1 = arith.constant 1.0 : f64
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+ %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector>
+ %0 = linalg.generic #trait_mat_reduce_rowwise
+ ins(%arga: tensor<?x?xf64, #CSR>)
+ outs(%xv: tensor<?xf64, #SparseVector>) {
+ ^bb(%a: f64, %b: f64):
+ %1 = sparse_tensor.reduce %a, %b, %cf1 : f64 {
+ ^bb0(%x: f64, %y: f64):
+ %2 = arith.mulf %x, %y : f64
+ sparse_tensor.yield %2 : f64
+ }
+ linalg.yield %1 : f64
+ } -> tensor<?xf64, #SparseVector>
+ return %0 : tensor<?xf64, #SparseVector>
+ }
+
+ func.func @redProdExpand(%arga: tensor<?x?xf64, #CSC>) -> tensor<?xf64, #SparseVector> {
+ %c0 = arith.constant 0 : index
+ %cf1 = arith.constant 1.0 : f64
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSC>
+ %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector>
+ %0 = linalg.generic #trait_mat_reduce_rowwise
+ ins(%arga: tensor<?x?xf64, #CSC>)
+ outs(%xv: tensor<?xf64, #SparseVector>) {
+ ^bb(%a: f64, %b: f64):
+ %1 = sparse_tensor.reduce %a, %b, %cf1 : f64 {
+ ^bb0(%x: f64, %y: f64):
+ %2 = arith.mulf %x, %y : f64
+ sparse_tensor.yield %2 : f64
+ }
+ linalg.yield %1 : f64
+ } -> tensor<?xf64, #SparseVector>
+ return %0 : tensor<?xf64, #SparseVector>
+ }
+
+ func.func @min_plus_csrcsr(%arga: tensor<?x?xf64, #CSR>,
+ %argb: tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %maxf = arith.constant 1.0e999 : f64
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+ %d1 = tensor.dim %argb, %c1 : tensor<?x?xf64, #CSR>
+ %xm = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #CSR>
+ %0 = linalg.generic #trait_matmul
+ ins(%arga, %argb: tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSR>)
+ outs(%xm: tensor<?x?xf64, #CSR>) {
+ ^bb(%a: f64, %b: f64, %output: f64):
+ %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+ overlap = {
+ ^bb0(%x: f64, %y: f64):
+ %3 = arith.addf %x, %y : f64
+ sparse_tensor.yield %3 : f64
+ }
+ left={}
+ right={}
+ %2 = sparse_tensor.reduce %1, %output, %maxf : f64 {
+ ^bb0(%x: f64, %y: f64):
+ %cmp = arith.cmpf "olt", %x, %y : f64
+ %3 = arith.select %cmp, %x, %y : f64
+ sparse_tensor.yield %3 : f64
+ }
+ linalg.yield %2 : f64
+ } -> tensor<?x?xf64, #CSR>
+ return %0 : tensor<?x?xf64, #CSR>
+ }
+
+ func.func @min_plus_csrcsc(%arga: tensor<?x?xf64, #CSR>,
+ %argb: tensor<?x?xf64, #CSC>) -> tensor<?x?xf64, #CSR> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %maxf = arith.constant 1.0e999 : f64
+ %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+ %d1 = tensor.dim %argb, %c1 : tensor<?x?xf64, #CSC>
+ %xm = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #CSR>
+ %0 = linalg.generic #trait_matmul
+ ins(%arga, %argb: tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSC>)
+ outs(%xm: tensor<?x?xf64, #CSR>) {
+ ^bb(%a: f64, %b: f64, %output: f64):
+ %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+ overlap = {
+ ^bb0(%x: f64, %y: f64):
+ %3 = arith.addf %x, %y : f64
+ sparse_tensor.yield %3 : f64
+ }
+ left={}
+ right={}
+ %2 = sparse_tensor.reduce %1, %output, %maxf : f64 {
+ ^bb0(%x: f64, %y: f64):
+ %cmp = arith.cmpf "olt", %x, %y : f64
+ %3 = arith.select %cmp, %x, %y : f64
+ sparse_tensor.yield %3 : f64
+ }
+ linalg.yield %2 : f64
+ } -> tensor<?x?xf64, #CSR>
+ return %0 : tensor<?x?xf64, #CSR>
+ }
+
+ // Dumps a sparse vector of type f64.
+ func.func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) {
+ // Dump the values array to verify only sparse contents are stored.
+ %c0 = arith.constant 0 : index
+ %d0 = arith.constant -1.0 : f64
+ %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+ %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<8xf64>
+ vector.print %1 : vector<8xf64>
+ // Dump the dense vector to verify structure is correct.
+ %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
+ %2 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<16xf64>
+ vector.print %2 : vector<16xf64>
+ return
+ }
+
+ // Dump a sparse matrix.
+ func.func @dump_mat(%arg0: tensor<?x?xf64, #CSR>) {
+ // Dump the values array to verify only sparse contents are stored.
+ %c0 = arith.constant 0 : index
+ %d0 = arith.constant -1.0 : f64
+ %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #CSR> to memref<?xf64>
+ %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
+ vector.print %1 : vector<16xf64>
+ %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #CSR> to tensor<?x?xf64>
+ %2 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<5x5xf64>
+ vector.print %2 : vector<5x5xf64>
+ return
+ }
+
+ // Driver method to call and verify vector kernels.
+ func.func @entry() {
+ %c0 = arith.constant 0 : index
+
+ // Setup sparse matrices.
+ %m1 = arith.constant sparse<
+ [ [0,0], [0,1], [1,0], [2,2], [2,3], [2,4], [3,0], [3,2], [3,3] ],
+ [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
+ > : tensor<4x5xf64>
+ %m2 = arith.constant sparse<
+ [ [0,0], [1,3], [2,0], [2,3], [3,1], [4,1] ],
+ [6.0, 5.0, 4.0, 3.0, 2.0, 11.0 ]
+ > : tensor<5x4xf64>
+ %sm1 = sparse_tensor.convert %m1 : tensor<4x5xf64> to tensor<?x?xf64, #CSR>
+ %sm2r = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor<?x?xf64, #CSR>
+ %sm2c = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor<?x?xf64, #CSC>
+
+ // Call sparse matrix kernels.
+ %1 = call @redProdLex(%sm1) : (tensor<?x?xf64, #CSR>) -> tensor<?xf64, #SparseVector>
+ %2 = call @redProdExpand(%sm2c) : (tensor<?x?xf64, #CSC>) -> tensor<?xf64, #SparseVector>
+ %5 = call @min_plus_csrcsr(%sm1, %sm2r)
+ : (tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR>
+ %6 = call @min_plus_csrcsc(%sm1, %sm2c)
+ : (tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSC>) -> tensor<?x?xf64, #CSR>
+
+ //
+ // Verify the results.
+ //
+ // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 )
+ // CHECK-NEXT: ( ( 1, 2, 0, 0, 0 ), ( 3, 0, 0, 0, 0 ), ( 0, 0, 4, 5, 6 ), ( 7, 0, 8, 9, 0 ), ( -1, -1, -1, -1, -1 ) )
+ // CHECK-NEXT: ( 6, 5, 4, 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+ // CHECK-NEXT: ( ( 6, 0, 0, 0, -1 ), ( 0, 0, 0, 5, -1 ), ( 4, 0, 0, 3, -1 ), ( 0, 2, 0, 0, -1 ), ( 0, 11, 0, 0, -1 ) )
+ // CHECK-NEXT: ( 2, 3, 120, 504, -1, -1, -1, -1 )
+ // CHECK-NEXT: ( 2, 3, 120, 504, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+ // CHECK-NEXT: ( 6, 5, 12, 2, 11, -1, -1, -1 )
+ // CHECK-NEXT: ( 6, 5, 12, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+ // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, -1, -1, -1, -1, -1, -1, -1 )
+ // CHECK-NEXT: ( ( 7, 0, 0, 7, -1 ), ( 9, 0, 0, 0, -1 ), ( 8, 7, 0, 7, -1 ), ( 12, 11, 0, 11, -1 ), ( -1, -1, -1, -1, -1 ) )
+ // TODO: Update once identity values are no longer inserted for non-overlapping dot product
+ // CHECK-NEXT: ( 7, inf, inf, 7, 9, inf, inf, inf, 8, 7, inf, 7, 12, 11, inf, 11 )
+ // CHECK-NEXT: ( ( 7, inf, inf, 7, -1 ), ( 9, inf, inf, inf, -1 ), ( 8, 7, inf, 7, -1 ), ( 12, 11, inf, 11, -1 ), ( -1, -1, -1, -1, -1 ) )
+ //
+ call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> ()
+ call @dump_mat(%sm2r) : (tensor<?x?xf64, #CSR>) -> ()
+ call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
+ call @dump_vec(%2) : (tensor<?xf64, #SparseVector>) -> ()
+ call @dump_mat(%5) : (tensor<?x?xf64, #CSR>) -> ()
+ call @dump_mat(%6) : (tensor<?x?xf64, #CSR>) -> ()
+
+ // Release the resources.
+ bufferization.dealloc_tensor %sm1 : tensor<?x?xf64, #CSR>
+ bufferization.dealloc_tensor %sm2r : tensor<?x?xf64, #CSR>
+ bufferization.dealloc_tensor %sm2c : tensor<?x?xf64, #CSC>
+ bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector>
+ bufferization.dealloc_tensor %2 : tensor<?xf64, #SparseVector>
+ bufferization.dealloc_tensor %5 : tensor<?x?xf64, #CSR>
+ bufferization.dealloc_tensor %6 : tensor<?x?xf64, #CSR>
+ return
+ }
+}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 91101c4d77db5..3bf424c40a30d 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -283,6 +283,7 @@ class MergerTestBase : public ::testing::Test {
case kShrU:
case kShlI:
case kBinary:
+ case kReduce:
return compareExpression(tensorExp.children.e0, pattern->e0) &&
compareExpression(tensorExp.children.e1, pattern->e1);
}