[Mlir-commits] [mlir] c8bb235 - [mlir][sparse] Custom reduce with identity

Jim Kitchen llvmlistbot at llvm.org
Wed Aug 17 09:22:31 PDT 2022


Author: Jim Kitchen
Date: 2022-08-17T11:21:46-05:00
New Revision: c8bb23547f2138beb5997caaaf1f4be46bfc30a3

URL: https://github.com/llvm/llvm-project/commit/c8bb23547f2138beb5997caaaf1f4be46bfc30a3
DIFF: https://github.com/llvm/llvm-project/commit/c8bb23547f2138beb5997caaaf1f4be46bfc30a3.diff

LOG: [mlir][sparse] Custom reduce with identity

Implement the new sparse_tensor.reduce operation, which
accepts a starting identity value and a code block
describing how to perform the reduction.

Reviewed by: aartbik

Differential Revision: https://reviews.llvm.org/D130573
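
For reference, a minimal usage sketch of the new operation, taken from the
redProdLex kernel in the integration test added below (the #CSR and
#SparseVector encodings and the #trait_mat_reduce_rowwise trait are defined
in that test); the third operand is the identity value used to seed the
row-wise product reduction:

    %cf1 = arith.constant 1.0 : f64
    %0 = linalg.generic #trait_mat_reduce_rowwise
      ins(%arga: tensor<?x?xf64, #CSR>)
      outs(%xv: tensor<?xf64, #SparseVector>) {
        ^bb(%a: f64, %b: f64):
          %1 = sparse_tensor.reduce %a, %b, %cf1 : f64 {
              ^bb0(%x: f64, %y: f64):
                %2 = arith.mulf %x, %y : f64
                sparse_tensor.yield %2 : f64
            }
          linalg.yield %1 : f64
    } -> tensor<?xf64, #SparseVector>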

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 0e3f5d0c03080..ea00b958f1b95 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -84,6 +84,7 @@ enum Kind {
   kShrU, // unsigned
   kShlI,
   kBinary, // semiring binary op
+  kReduce, // semiring reduction op
 };
 
 /// Children subexpressions of tensor operations.
@@ -115,8 +116,8 @@ struct TensorExp {
   /// this field may be used to cache "hoisted" loop invariant tensor loads.
   Value val;
 
-  /// Code blocks used by semirings. For the case of kUnary and
-  /// kBinary, this holds the original operation with all regions. For
+  /// Code blocks used by semirings. For the case of kUnary, kBinary, and
+  /// kReduce, this holds the original operation with all regions. For
   /// kBinaryBranch, this holds the YieldOp for the left or right half
   /// to be merged into a nested scf loop.
   Operation *op;

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e5714c6a5cdf6..8c56e43433533 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -50,7 +50,7 @@ enum SortMask {
 };
 
 // Reduction kinds.
-enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
+enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
 
 // Code generation.
 struct CodeGen {
@@ -87,6 +87,7 @@ struct CodeGen {
   unsigned redExp = -1u;
   Value redVal;
   Reduction redKind = kNoReduc;
+  unsigned redCustom = -1u;
   // Sparse tensor as output. Implemented either through direct injective
   // insertion in lexicographic index order (where indices are updated
   // in the temporary array `lexIdx`) or through access pattern expansion
@@ -373,6 +374,7 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
 static vector::CombiningKind getCombiningKind(Reduction kind) {
   switch (kind) {
   case kNoReduc:
+  case kCustom:
     break;
   case kSum:
     return vector::CombiningKind::ADD;
@@ -408,6 +410,8 @@ static Reduction getReduction(Kind kind) {
     return kOr;
   case Kind::kXorI:
     return kXor;
+  case Kind::kReduce:
+    return kCustom;
   default:
     llvm_unreachable("unexpected reduction operator");
   }
@@ -422,6 +426,7 @@ static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder,
   Value r = codegen.redVal;
   switch (codegen.redKind) {
   case kNoReduc:
+  case kCustom:
     break;
   case kSum:
   case kXor:
@@ -454,6 +459,11 @@ static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
   codegen.redVal = merger.exp(codegen.redExp).val = reduc;
 }
 
+/// Extracts identity from custom reduce.
+static Value getCustomRedId(Operation *op) {
+  return dyn_cast<sparse_tensor::ReduceOp>(op).getIdentity();
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse compiler synthesis methods (statements and expressions).
 //===----------------------------------------------------------------------===//
@@ -726,6 +736,25 @@ static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder,
   return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
 }
 
+/// Generates insertion code to implement dynamic tensor load for reduction.
+static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen,
+                                    OpBuilder &builder, linalg::GenericOp op,
+                                    OpOperand *t) {
+  Location loc = op.getLoc();
+  Value identity = getCustomRedId(merger.exp(codegen.redCustom).op);
+  // Direct lexicographic index order, tensor loads as identity.
+  if (!codegen.expValues) {
+    return identity;
+  }
+  // Load from expanded access pattern if filled, identity otherwise.
+  Value index = genIndex(codegen, op, t);
+  Value isFilled =
+      builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
+  Value valAtIndex =
+      builder.create<memref::LoadOp>(loc, codegen.expValues, index);
+  return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
+}
+
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
                               linalg::GenericOp op, OpOperand *t, Value rhs) {
@@ -780,8 +809,11 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   }
   // Load during insertion.
   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
-  if (t == codegen.sparseOut)
+  if (t == codegen.sparseOut) {
+    if (codegen.redCustom != -1u)
+      return genInsertionLoadReduce(merger, codegen, builder, op, t);
     return genInsertionLoad(codegen, builder, op, t);
+  }
   // Actual load.
   SmallVector<Value, 4> args;
   Value ptr = genSubscript(codegen, builder, op, t, args);
@@ -953,6 +985,11 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
     return genInvariantValue(merger, codegen, rewriter, exp);
   if (merger.exp(exp).kind == Kind::kIndex)
     return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
+  if (merger.exp(exp).kind == Kind::kReduce) {
+    // Make custom reduction identity accessible for expanded access pattern.
+    assert(codegen.redCustom == -1u);
+    codegen.redCustom = exp;
+  }
   Value v0 =
       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
   Value v1 =
@@ -960,8 +997,11 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
   Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
   if (ee && (merger.exp(exp).kind == Kind::kUnary ||
              merger.exp(exp).kind == Kind::kBinary ||
-             merger.exp(exp).kind == Kind::kBinaryBranch))
+             merger.exp(exp).kind == Kind::kBinaryBranch ||
+             merger.exp(exp).kind == Kind::kReduce))
     ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
+  if (merger.exp(exp).kind == Kind::kReduce)
+    codegen.redCustom = -1u;
   return ee;
 }
 
@@ -989,7 +1029,7 @@ static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, unsigned ldx,
-                          bool atStart, Kind last = Kind::kTensor) {
+                          bool atStart, unsigned last = 0) {
   if (exp == -1u)
     return;
   if (merger.exp(exp).kind == Kind::kTensor) {
@@ -1010,8 +1050,11 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     if (lhs == t) {
       // Start or end a scalarized reduction
       if (atStart) {
-        Value load = genTensorLoad(merger, codegen, builder, op, exp);
-        codegen.redKind = getReduction(last);
+        Kind kind = merger.exp(last).kind;
+        Value load = kind == Kind::kReduce
+                         ? getCustomRedId(merger.exp(last).op)
+                         : genTensorLoad(merger, codegen, builder, op, exp);
+        codegen.redKind = getReduction(kind);
         codegen.redExp = exp;
         updateReduc(merger, codegen, load);
       } else {
@@ -1031,11 +1074,10 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
-    Kind last = merger.exp(exp).kind;
     unsigned e0 = merger.exp(exp).children.e0;
     unsigned e1 = merger.exp(exp).children.e1;
-    genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last);
-    genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last);
+    genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp);
+    genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp);
   }
 }
 

diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index a5e7a37776c95..eeaaa2e8e2e9d 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -114,6 +114,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
     children.e1 = y;
     break;
   case kBinary:
+  case kReduce:
     assert(x != -1u && y != -1u && !v && o);
     children.e0 = x;
     children.e1 = y;
@@ -376,6 +377,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   case kOrI:
   case kXorI:
   case kBinary:
+  case kReduce:
     return false;
   }
   llvm_unreachable("unexpected kind");
@@ -476,6 +478,8 @@ static const char *kindToOpSymbol(Kind kind) {
     return "<<";
   case kBinary:
     return "binary";
+  case kReduce:
+    return "reduce";
   }
   llvm_unreachable("unexpected kind for symbol");
 }
@@ -554,6 +558,7 @@ void Merger::dumpExp(unsigned e) const {
   case kShrU:
   case kShlI:
   case kBinary:
+  case kReduce:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -794,6 +799,11 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
                        kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                        rightYield);
     }
+  case kReduce:
+    // A custom reduce operation.
+    return takeConj(kind, buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i),
+                    tensorExps[e].op);
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -965,7 +975,7 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   }
   // Construct binary operations if subexpressions can be built.
   // See buildLattices() for an explanation of rejecting certain
-  // division and shift operations
+  // division and shift operations.
   if (def->getNumOperands() == 2) {
     auto x = buildTensorExp(op, def->getOperand(0));
     auto y = buildTensorExp(op, def->getOperand(1));
@@ -1020,6 +1030,21 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
       }
     }
   }
+  // Construct ternary operations if subexpressions can be built.
+  if (def->getNumOperands() == 3) {
+    auto x = buildTensorExp(op, def->getOperand(0));
+    auto y = buildTensorExp(op, def->getOperand(1));
+    auto z = buildTensorExp(op, def->getOperand(2));
+    if (x.has_value() && y.has_value() && z.has_value()) {
+      unsigned e0 = x.value();
+      unsigned e1 = y.value();
+      // unsigned e2 = z.getValue();
+      if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
+        if (isAdmissableBranch(redop, redop.getRegion()))
+          return addExp(kReduce, e0, e1, Value(), def);
+      }
+    }
+  }
   // Cannot build.
   return None;
 }
@@ -1199,6 +1224,10 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
   case kBinary:
     return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
+  case kReduce: {
+    ReduceOp redOp = cast<ReduceOp>(tensorExps[e].op);
+    return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
+  }
   }
   llvm_unreachable("unexpected expression kind in build");
 }

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
new file mode 100644
index 0000000000000..443e597292dd3
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir
@@ -0,0 +1,234 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Traits for tensor operations.
+//
+#trait_matmul = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,k)>, // A
+    affine_map<(i,j,k) -> (k,j)>, // B
+    affine_map<(i,j,k) -> (i,j)>  // C (out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "C(i,j) = SUM_k A(i,k) * B(k,j)"
+}
+
+#trait_mat_reduce_rowwise = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (in)
+    affine_map<(i,j) -> (i)>   // X (out)
+  ],
+  iterator_types = ["parallel", "reduction"],
+  doc = "X(i) = PROD_j A(i,j)"
+}
+
+#trait_mat_reduce_colwise = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (in)
+    affine_map<(i,j) -> (j)>   // X (out)
+  ],
+  iterator_types = ["reduction", "parallel"],
+  doc = "X(j) = PROD_i A(i,j)"
+}
+
+module {
+  func.func @redProdLex(%arga: tensor<?x?xf64, #CSR>) -> tensor<?xf64, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %cf1 = arith.constant 1.0 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_mat_reduce_rowwise
+      ins(%arga: tensor<?x?xf64, #CSR>)
+      outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64):
+          %1 = sparse_tensor.reduce %a, %b, %cf1 : f64 {
+              ^bb0(%x: f64, %y: f64):
+                %2 = arith.mulf %x, %y : f64
+                sparse_tensor.yield %2 : f64
+            }
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @redProdExpand(%arga: tensor<?x?xf64, #CSC>) -> tensor<?xf64, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %cf1 = arith.constant 1.0 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSC>
+    %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_mat_reduce_rowwise
+      ins(%arga: tensor<?x?xf64, #CSC>)
+      outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64):
+          %1 = sparse_tensor.reduce %a, %b, %cf1 : f64 {
+              ^bb0(%x: f64, %y: f64):
+                %2 = arith.mulf %x, %y : f64
+                sparse_tensor.yield %2 : f64
+            }
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @min_plus_csrcsr(%arga: tensor<?x?xf64, #CSR>,
+                             %argb: tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %maxf = arith.constant 1.0e999 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %d1 = tensor.dim %argb, %c1 : tensor<?x?xf64, #CSR>
+    %xm = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #CSR>
+    %0 = linalg.generic #trait_matmul
+       ins(%arga, %argb: tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSR>)
+        outs(%xm: tensor<?x?xf64, #CSR>) {
+        ^bb(%a: f64, %b: f64, %output: f64):
+          %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap = {
+              ^bb0(%x: f64, %y: f64):
+                %3 = arith.addf %x, %y : f64
+                sparse_tensor.yield %3 : f64
+            }
+            left={}
+            right={}
+          %2 = sparse_tensor.reduce %1, %output, %maxf : f64 {
+              ^bb0(%x: f64, %y: f64):
+                %cmp = arith.cmpf "olt", %x, %y : f64
+                %3 = arith.select %cmp, %x, %y : f64
+                sparse_tensor.yield %3 : f64
+            }
+          linalg.yield %2 : f64
+    } -> tensor<?x?xf64, #CSR>
+    return %0 : tensor<?x?xf64, #CSR>
+  }
+
+  func.func @min_plus_csrcsc(%arga: tensor<?x?xf64, #CSR>,
+                             %argb: tensor<?x?xf64, #CSC>) -> tensor<?x?xf64, #CSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %maxf = arith.constant 1.0e999 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %d1 = tensor.dim %argb, %c1 : tensor<?x?xf64, #CSC>
+    %xm = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #CSR>
+    %0 = linalg.generic #trait_matmul
+       ins(%arga, %argb: tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSC>)
+        outs(%xm: tensor<?x?xf64, #CSR>) {
+        ^bb(%a: f64, %b: f64, %output: f64):
+          %1 = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap = {
+              ^bb0(%x: f64, %y: f64):
+                %3 = arith.addf %x, %y : f64
+                sparse_tensor.yield %3 : f64
+            }
+            left={}
+            right={}
+          %2 = sparse_tensor.reduce %1, %output, %maxf : f64 {
+              ^bb0(%x: f64, %y: f64):
+                %cmp = arith.cmpf "olt", %x, %y : f64
+                %3 = arith.select %cmp, %x, %y : f64
+                sparse_tensor.yield %3 : f64
+            }
+          linalg.yield %2 : f64
+    } -> tensor<?x?xf64, #CSR>
+    return %0 : tensor<?x?xf64, #CSR>
+  }
+
+  // Dumps a sparse vector of type f64.
+  func.func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<8xf64>
+    vector.print %1 : vector<8xf64>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
+    %2 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<16xf64>
+    vector.print %2 : vector<16xf64>
+    return
+  }
+
+  // Dump a sparse matrix.
+  func.func @dump_mat(%arg0: tensor<?x?xf64, #CSR>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
+    vector.print %1 : vector<16xf64>
+    %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #CSR> to tensor<?x?xf64>
+    %2 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<5x5xf64>
+    vector.print %2 : vector<5x5xf64>
+    return
+  }
+
+  // Driver method to call and verify vector kernels.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+
+    // Setup sparse matrices.
+    %m1 = arith.constant sparse<
+       [ [0,0], [0,1], [1,0], [2,2], [2,3], [2,4], [3,0], [3,2], [3,3] ],
+         [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ]
+    > : tensor<4x5xf64>
+    %m2 = arith.constant sparse<
+       [ [0,0], [1,3], [2,0], [2,3], [3,1], [4,1] ],
+         [6.0, 5.0, 4.0, 3.0, 2.0, 11.0 ]
+    > : tensor<5x4xf64>
+    %sm1 = sparse_tensor.convert %m1 : tensor<4x5xf64> to tensor<?x?xf64, #CSR>
+    %sm2r = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor<?x?xf64, #CSR>
+    %sm2c = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor<?x?xf64, #CSC>
+
+    // Call sparse matrix kernels.
+    %1 = call @redProdLex(%sm1) : (tensor<?x?xf64, #CSR>) -> tensor<?xf64, #SparseVector>
+    %2 = call @redProdExpand(%sm2c) : (tensor<?x?xf64, #CSC>) -> tensor<?xf64, #SparseVector>
+    %5 = call @min_plus_csrcsr(%sm1, %sm2r)
+      : (tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR>
+    %6 = call @min_plus_csrcsc(%sm1, %sm2c)
+      : (tensor<?x?xf64, #CSR>, tensor<?x?xf64, #CSC>) -> tensor<?x?xf64, #CSR>
+
+    //
+    // Verify the results.
+    //
+    // CHECK:      ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 1, 2, 0, 0, 0 ), ( 3, 0, 0, 0, 0 ), ( 0, 0, 4, 5, 6 ), ( 7, 0, 8, 9, 0 ), ( -1, -1, -1, -1, -1 ) )
+    // CHECK-NEXT: ( 6, 5, 4, 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 6, 0, 0, 0, -1 ), ( 0, 0, 0, 5, -1 ), ( 4, 0, 0, 3, -1 ), ( 0, 2, 0, 0, -1 ), ( 0, 11, 0, 0, -1 ) )
+    // CHECK-NEXT: ( 2, 3, 120, 504, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 2, 3, 120, 504, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 6, 5, 12, 2, 11, -1, -1, -1 )
+    // CHECK-NEXT: ( 6, 5, 12, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 7, 0, 0, 7, -1 ), ( 9, 0, 0, 0, -1 ), ( 8, 7, 0, 7, -1 ), ( 12, 11, 0, 11, -1 ), ( -1, -1, -1, -1, -1 ) )
+    // TODO: Update once identity values are no longer inserted for non-overlapping dot product
+    // CHECK-NEXT: ( 7, inf, inf, 7, 9, inf, inf, inf, 8, 7, inf, 7, 12, 11, inf, 11 )
+    // CHECK-NEXT: ( ( 7, inf, inf, 7, -1 ), ( 9, inf, inf, inf, -1 ), ( 8, 7, inf, 7, -1 ), ( 12, 11, inf, 11, -1 ), ( -1, -1, -1, -1, -1 ) )
+    //
+    call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_mat(%sm2r) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_vec(%2) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_mat(%5) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_mat(%6) : (tensor<?x?xf64, #CSR>) -> ()
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sm1 : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %sm2r : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %sm2c : tensor<?x?xf64, #CSC>
+    bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %2 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %5 : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %6 : tensor<?x?xf64, #CSR>
+    return
+  }
+}

diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 91101c4d77db5..3bf424c40a30d 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -283,6 +283,7 @@ class MergerTestBase : public ::testing::Test {
     case kShrU:
     case kShlI:
     case kBinary:
+    case kReduce:
       return compareExpression(tensorExp.children.e0, pattern->e0) &&
              compareExpression(tensorExp.children.e1, pattern->e1);
     }


        

