[Mlir-commits] [mlir] 7919350 - [mlir][sparse] Implement sparse_tensor.select

Jim Kitchen llvmlistbot at llvm.org
Mon Oct 3 12:39:41 PDT 2022


Author: Jim Kitchen
Date: 2022-10-03T14:39:26-05:00
New Revision: 791935037b0b3b211bee54fae694aeb5b7b75125

URL: https://github.com/llvm/llvm-project/commit/791935037b0b3b211bee54fae694aeb5b7b75125
DIFF: https://github.com/llvm/llvm-project/commit/791935037b0b3b211bee54fae694aeb5b7b75125.diff

LOG: [mlir][sparse] Implement sparse_tensor.select

The region within sparse_tensor.select is used as the runtime criterion
for whether to keep the existing value in the sparse tensor.

Although only the sparse element's value is passed to the region, indices
(e.g. obtained via linalg.index) may also be used to decide whether to keep
the original value. This allows, for example, keeping only the upper
triangle of a matrix.

Reviewed by: aartbik

Differential Revision: https://reviews.llvm.org/D134761

Added: 
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index f9be4762ba640..a376d9a6f479f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -76,6 +76,7 @@ enum Kind {
   kBitCast,
   kBinaryBranch, // semiring unary branch created from a binary op
   kUnary,        // semiring unary op
+  kSelect,       // custom selection criteria
   // Binary operations.
   kMulF,
   kMulC,
@@ -129,8 +130,8 @@ struct TensorExp {
   /// this field may be used to cache "hoisted" loop invariant tensor loads.
   Value val;
 
-  /// Code blocks used by semirings. For the case of kUnary, kBinary, and
-  /// kReduce, this holds the original operation with all regions. For
+  /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce,
+  /// and kSelect, this holds the original operation with all regions. For
   /// kBinaryBranch, this holds the YieldOp for the left or right half
   /// to be merged into a nested scf loop.
   Operation *op;

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 88bd885393b7b..7cd9f7f31e425 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -878,6 +878,15 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       // Only unary and binary are allowed to return uninitialized rhs
       // to indicate missing output.
       assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
+    } else if (merger.exp(exp).kind == kSelect) {
+      scf::IfOp ifOp = builder.create<scf::IfOp>(loc, rhs);
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      // Existing value was preserved to be used here.
+      assert(merger.exp(exp).val);
+      Value v0 = merger.exp(exp).val;
+      genInsertionStore(codegen, builder, op, t, v0);
+      merger.exp(exp).val = Value();
+      builder.setInsertionPointAfter(ifOp);
     } else {
       genInsertionStore(codegen, builder, op, t, rhs);
     }
@@ -1037,9 +1046,15 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
   if (ee && (merger.exp(exp).kind == Kind::kUnary ||
              merger.exp(exp).kind == Kind::kBinary ||
              merger.exp(exp).kind == Kind::kBinaryBranch ||
-             merger.exp(exp).kind == Kind::kReduce))
+             merger.exp(exp).kind == Kind::kReduce ||
+             merger.exp(exp).kind == Kind::kSelect))
     ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
 
+  if (merger.exp(exp).kind == kSelect) {
+    assert(!merger.exp(exp).val);
+    merger.exp(exp).val = v0; // Preserve value for later use.
+  }
+
   if (merger.exp(exp).kind == Kind::kReduce) {
     assert(codegen.redCustom != -1u);
     codegen.redCustom = -1u;

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 7f132f90e81bb..3791971501b0a 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -78,6 +78,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
     children.e1 = y;
     break;
   case kBinaryBranch:
+  case kSelect:
     assert(x != -1u && y == -1u && !v && o);
     children.e0 = x;
     children.e1 = y;
@@ -212,7 +213,7 @@ unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
 }
 
 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
-  assert(kAbsF <= kind && kind <= kUnary);
+  assert(kAbsF <= kind && kind <= kSelect);
   unsigned s = addSet();
   for (unsigned p : latSets[s0]) {
     unsigned e = addExp(kind, latPoints[p].exp, v, op);
@@ -265,9 +266,8 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
   BitVector simple = latPoints[p0].bits;
   bool reset = isSingleton && hasAnySparse(simple);
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
-    if (simple[b] &&
-        (!isDimLevelType(b, DimLvlType::kCompressed) &&
-         !isDimLevelType(b, DimLvlType::kSingleton))) {
+    if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) &&
+                      !isDimLevelType(b, DimLvlType::kSingleton))) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -338,6 +338,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
     return isSingleCondition(t, tensorExps[e].children.e0);
   case kBinaryBranch:
   case kUnary:
+  case kSelect:
     return false;
   // Binary operations.
   case kDivF: // note: x / c only
@@ -449,6 +450,8 @@ static const char *kindToOpSymbol(Kind kind) {
     return "binary_branch";
   case kUnary:
     return "unary";
+  case kSelect:
+    return "select";
   // Binary operations.
   case kMulF:
   case kMulC:
@@ -537,6 +540,7 @@ void Merger::dumpExp(unsigned e) const {
   case kBitCast:
   case kBinaryBranch:
   case kUnary:
+  case kSelect:
     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
     dumpExp(tensorExps[e].children.e0);
     break;
@@ -684,6 +688,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                   tensorExps[e].val);
   case kBinaryBranch:
+  case kSelect:
     // The left or right half of a binary operation which has already
     // been split into separate operations for each region.
     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
@@ -978,6 +983,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
             isAdmissableBranch(unop, unop.getAbsentRegion()))
           return addExp(kUnary, e, Value(), def);
       }
+      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
+        if (isAdmissableBranch(selop, selop.getRegion()))
+          return addExp(kSelect, e, Value(), def);
+      }
     }
   }
   // Construct binary operations if subexpressions can be built.
@@ -1228,6 +1237,9 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                          *tensorExps[e].op->getBlock()->getParent(), {v0});
   case kUnary:
     return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
+  case kSelect:
+    return insertYieldOp(rewriter, loc,
+                         cast<SelectOp>(tensorExps[e].op).getRegion(), {v0});
   case kBinary:
     return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
   case kReduce: {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir
new file mode 100644
index 0000000000000..bbe94e9c71d14
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Traits for tensor operations.
+//
+#trait_vec_select = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // A
+    affine_map<(i) -> (i)>  // C (out)
+  ],
+  iterator_types = ["parallel"]
+}
+
+#trait_mat_select = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (in)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+module {
+  func.func @vecSelect(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %cf1 = arith.constant 1.0 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_vec_select
+      ins(%arga: tensor<?xf64, #SparseVector>)
+      outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64):
+          %1 = sparse_tensor.select %a : f64 {
+              ^bb0(%x: f64):
+                %keep = arith.cmpf "oge", %x, %cf1 : f64
+                sparse_tensor.yield %keep : i1
+            }
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @matUpperTriangle(%arga: tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #CSR>
+    %xv = bufferization.alloc_tensor(%d0, %d1): tensor<?x?xf64, #CSR>
+    %0 = linalg.generic #trait_mat_select
+      ins(%arga: tensor<?x?xf64, #CSR>)
+      outs(%xv: tensor<?x?xf64, #CSR>) {
+        ^bb(%a: f64, %b: f64):
+          %row = linalg.index 0 : index
+          %col = linalg.index 1 : index
+          %1 = sparse_tensor.select %a : f64 {
+              ^bb0(%x: f64):
+                %keep = arith.cmpi "ugt", %col, %row : index
+                sparse_tensor.yield %keep : i1
+            }
+          linalg.yield %1 : f64
+    } -> tensor<?x?xf64, #CSR>
+    return %0 : tensor<?x?xf64, #CSR>
+  }
+
+  // Dumps a sparse vector of type f64.
+  func.func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<8xf64>
+    vector.print %1 : vector<8xf64>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
+    %2 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<16xf64>
+    vector.print %2 : vector<16xf64>
+    return
+  }
+
+  // Dumps a sparse matrix of type f64.
+  func.func @dump_mat(%arg0: tensor<?x?xf64, #CSR>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
+    vector.print %1 : vector<16xf64>
+    %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #CSR> to tensor<?x?xf64>
+    %2 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<5x5xf64>
+    vector.print %2 : vector<5x5xf64>
+    return
+  }
+
+  // Driver method to call and verify the vector and matrix kernels.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+
+    // Set up a sparse vector and a sparse matrix.
+    %v1 = arith.constant sparse<
+        [ [1], [3], [5], [7], [9] ],
+        [ 1.0, 2.0, -4.0, 0.0, 5.0 ]
+    > : tensor<10xf64>
+    %m1 = arith.constant sparse<
+        [ [0, 3], [1, 4], [2, 1], [2, 3], [3, 3], [3, 4], [4, 2] ],
+        [ 1., 2., 3., 4., 5., 6., 7.]
+    > : tensor<5x5xf64>
+    %sv1 = sparse_tensor.convert %v1 : tensor<10xf64> to tensor<?xf64, #SparseVector>
+    %sm1 = sparse_tensor.convert %m1 : tensor<5x5xf64> to tensor<?x?xf64, #CSR>
+
+    // Call the sparse vector and matrix kernels.
+    %1 = call @vecSelect(%sv1) : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %2 = call @matUpperTriangle(%sm1) : (tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR>
+
+    //
+    // Verify the results.
+    //
+    // CHECK:      ( 1, 2, -4, 0, 5, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 1, 0, 2, 0, -4, 0, 0, 0, 5, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 0, 0, 0, 1, 0 ), ( 0, 0, 0, 0, 2 ), ( 0, 3, 0, 4, 0 ), ( 0, 0, 0, 5, 6 ), ( 0, 0, 7, 0, 0 ) )
+    // CHECK-NEXT: ( 1, 2, 5, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 1, 0, 2, 0, 0, 0, 0, 0, 5, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 0, 0, 0, 1, 0 ), ( 0, 0, 0, 0, 2 ), ( 0, 0, 0, 4, 0 ), ( 0, 0, 0, 0, 6 ), ( 0, 0, 0, 0, 0 ) )
+    //
+    call @dump_vec(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_mat(%2) : (tensor<?x?xf64, #CSR>) -> ()
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sv1 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %sm1 : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %2 : tensor<?x?xf64, #CSR>
+    return
+  }
+}

diff  --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 8d41558d31e98..c0e75dc3f0e78 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -259,6 +259,7 @@ class MergerTestBase : public ::testing::Test {
     case kCIm:
     case kCRe:
     case kBitCast:
+    case kSelect:
     case kBinaryBranch:
     case kUnary:
       return compareExpression(tensorExp.children.e0, pattern->e0);


        


More information about the Mlir-commits mailing list