[Mlir-commits] [mlir] 53cc3a0 - [mlir][sparse] index support in sparse compiler codegen

Aart Bik llvmlistbot at llvm.org
Tue Mar 8 17:25:43 PST 2022


Author: Aart Bik
Date: 2022-03-08T17:25:36-08:00
New Revision: 53cc3a06378229f5b4713f0db39135e846609d0a

URL: https://github.com/llvm/llvm-project/commit/53cc3a06378229f5b4713f0db39135e846609d0a
DIFF: https://github.com/llvm/llvm-project/commit/53cc3a06378229f5b4713f0db39135e846609d0a.diff

LOG: [mlir][sparse] index support in sparse compiler codegen

This revision adds support for the linalg.index operation to the sparse
compiler pipeline. In essence, this adds the ability to refer to loop
indices in the tensor index expression, as illustrated below:

 Y[i, j, k, l, m] = T[i, j, k, l, m] * i * j

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D121251

Added: 
    mlir/test/Dialect/SparseTensor/sparse_index.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 304ba93737b5b..3ecd584c8fe83 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -28,6 +28,7 @@ enum Kind {
   // Leaf.
   kTensor = 0,
   kInvariant,
+  kIndex,
   // Unary operations.
   kAbsF,
   kCeilF,
@@ -42,6 +43,7 @@ enum Kind {
   kCastUF, // unsigned
   kCastS,  // signed
   kCastU,  // unsigned
+  kCastIdx,
   kTruncI,
   kBitCast,
   // Binary operations.
@@ -79,6 +81,9 @@ struct TensorExp {
     /// Expressions representing tensors simply have a tensor number.
     unsigned tensor;
 
+    /// Indices hold the index number.
+    unsigned index;
+
     /// Tensor operations hold the indices of their children.
     Children children;
   };

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 6945e54a73a05..e4faddd9b7e8a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -870,6 +871,13 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
   return rewriter.create<arith::AddIOp>(loc, mul, i);
 }
 
+/// Generates an index value.
+static Value genIndexValue(Merger &merger, CodeGen &codegen, unsigned exp) {
+  assert(codegen.curVecLength == 1); // TODO: implement vectorization!
+  unsigned idx = merger.exp(exp).index;
+  return codegen.loops[idx];
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
@@ -880,6 +888,8 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
     return genTensorLoad(merger, codegen, rewriter, op, exp);
   if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
+  if (merger.exp(exp).kind == Kind::kIndex)
+    return genIndexValue(merger, codegen, exp);
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
   return merger.buildExp(rewriter, loc, exp, v0, v1);
@@ -947,7 +957,8 @@ static void genInvariants(Merger &merger, CodeGen &codegen,
       merger.exp(exp).val =
           atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
     }
-  } else if (merger.exp(exp).kind != Kind::kInvariant) {
+  } else if (merger.exp(exp).kind != Kind::kInvariant &&
+             merger.exp(exp).kind != Kind::kIndex) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
@@ -1039,7 +1050,12 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
 /// Returns vectorization strategy. Any implicit inner loop in the Linalg
 /// operation is a candidate. Whether it is actually converted to SIMD code
 /// depends on the requested strategy.
-static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
+static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction,
+                        bool isSparse) {
+  // Reject vectorization of sparse output, unless innermost is reduction.
+  if (codegen.sparseOut && !isReduction)
+    return false;
+  // Inspect strategy.
   switch (codegen.options.vectorizationStrategy) {
   case SparseVectorizationStrategy::kNone:
     return false;
@@ -1056,6 +1072,10 @@ static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
 /// to a parallel operation depends on the requested strategy.
 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                           bool isSparse, bool isVector) {
+  // Reject parallelization of sparse output.
+  if (codegen.sparseOut)
+    return false;
+  // Inspect strategy.
   switch (codegen.options.parallelizationStrategy) {
   case SparseParallelizationStrategy::kNone:
     return false;
@@ -1107,11 +1127,9 @@ static Operation *genFor(Merger &merger, CodeGen &codegen,
   auto iteratorTypes = op.iterator_types().getValue();
   bool isReduction = isReductionIterator(iteratorTypes[idx]);
   bool isSparse = merger.isDim(fb, Dim::kSparse);
-  bool isVector = !codegen.sparseOut &&
-                  isVectorFor(codegen, isInner, isSparse) &&
+  bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                   denseUnitStrides(merger, op, idx);
   bool isParallel =
-      !codegen.sparseOut &&
       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
 
   // Prepare vector length.
@@ -1626,6 +1644,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
+
     // Detects sparse annotations and translate the per-dimension sparsity
     // information for all tensors to loop indices in the kernel.
     assert(op.getNumOutputs() == 1);

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 37e077acf06aa..005278ca70d22 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -29,6 +29,10 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
   case kInvariant:
     assert(x == -1u && y == -1u && v);
     break;
+  case kIndex:
+    assert(x != -1u && y == -1u && !v);
+    index = x;
+    break;
   case kAbsF:
   case kCeilF:
   case kFloorF:
@@ -46,6 +50,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     assert(x != -1u && y == -1u && v);
@@ -230,6 +235,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     return isSingleCondition(t, tensorExps[e].children.e0);
@@ -273,6 +279,8 @@ static const char *kindToOpSymbol(Kind kind) {
     return "tensor";
   case kInvariant:
     return "invariant";
+  case kIndex:
+    return "index";
   case kAbsF:
     return "abs";
   case kCeilF:
@@ -291,6 +299,7 @@ static const char *kindToOpSymbol(Kind kind) {
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     return "cast";
@@ -340,6 +349,9 @@ void Merger::dumpExp(unsigned e) const {
   case kInvariant:
     llvm::dbgs() << "invariant";
     break;
+  case kIndex:
+    llvm::dbgs() << "index_" << tensorExps[e].index;
+    break;
   case kAbsF:
   case kCeilF:
   case kFloorF:
@@ -353,6 +365,7 @@ void Merger::dumpExp(unsigned e) const {
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -420,16 +433,20 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   Kind kind = tensorExps[e].kind;
   switch (kind) {
   case kTensor:
-  case kInvariant: {
+  case kInvariant:
+  case kIndex: {
     // Either the index is really used in the tensor expression, or it is
-    // set to the undefined index in that dimension. An invariant expression
-    // and a truly dynamic sparse output tensor are set to a synthetic tensor
-    // with undefined indices only to ensure the iteration space is not
-    // skipped as a result of their contents.
+    // set to the undefined index in that dimension. An invariant expression,
+    // a proper index value, and a truly dynamic sparse output tensor are set
+    // to a synthetic tensor with undefined indices only to ensure the
+    // iteration space is not skipped as a result of their contents.
     unsigned s = addSet();
-    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
-    if (hasSparseOut && t == outTensor)
-      t = syntheticTensor;
+    unsigned t = syntheticTensor;
+    if (kind == kTensor) {
+      t = tensorExps[e].tensor;
+      if (hasSparseOut && t == outTensor)
+        t = syntheticTensor;
+    }
     latSets[s].push_back(addLat(t, i, e));
     return s;
   }
@@ -446,6 +463,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
@@ -569,6 +587,11 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.region().front())
     return addExp(kInvariant, v);
+  // Construct index operations.
+  if (def->getNumOperands() == 0) {
+    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
+      return addExp(kIndex, indexOp.dim());
+  }
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
     auto x = buildTensorExp(op, def->getOperand(0));
@@ -598,6 +621,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(kCastS, e, v);
       if (isa<arith::ExtUIOp>(def))
         return addExp(kCastU, e, v);
+      if (isa<arith::IndexCastOp>(def))
+        return addExp(kCastIdx, e, v);
       if (isa<arith::TruncIOp>(def))
         return addExp(kTruncI, e, v);
       if (isa<arith::BitcastOp>(def))
@@ -654,6 +679,7 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
   switch (tensorExps[e].kind) {
   case kTensor:
   case kInvariant:
+  case kIndex:
     llvm_unreachable("unexpected non-op");
   // Unary ops.
   case kAbsF:
@@ -686,6 +712,8 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
   case kCastU:
     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
+  case kCastIdx:
+    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
   case kTruncI:
     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
   case kBitCast:

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
new file mode 100644
index 0000000000000..f41c765376bbe
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -0,0 +1,128 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#DenseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "dense"]
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) * i * j"
+}
+
+// CHECK-LABEL:   func @dense_index(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.init{{\[}}%[[VAL_3]], %[[VAL_4]]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_5]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] {
+// CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] {
+// CHECK:               %[[VAL_12:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
+// CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index
+// CHECK:               %[[VAL_14:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
+// CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : index
+// CHECK:               %[[VAL_16:.*]] = arith.index_cast %[[VAL_11]] : index to i64
+// CHECK:               %[[VAL_17:.*]] = arith.index_cast %[[VAL_10]] : index to i64
+// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xi64>
+// CHECK:               %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_18]] : i64
+// CHECK:               %[[VAL_20:.*]] = arith.muli %[[VAL_16]], %[[VAL_19]] : i64
+// CHECK:               memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xi64>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_21:.*]] = sparse_tensor.load %[[VAL_5]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK:           return %[[VAL_21]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK:         }
+func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
+                      -> tensor<?x?xi64, #DenseMatrix> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 0 : index
+  %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
+  %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
+  %init = sparse_tensor.init [%0, %1] : tensor<?x?xi64, #DenseMatrix>
+  %r = linalg.generic #trait
+      ins(%arga: tensor<?x?xi64, #DenseMatrix>)
+     outs(%init: tensor<?x?xi64, #DenseMatrix>) {
+      ^bb(%a: i64, %x: i64):
+        %i = linalg.index 0 : index
+        %j = linalg.index 1 : index
+        %ii = arith.index_cast %i : index to i64
+        %jj = arith.index_cast %j : index to i64
+        %m1 = arith.muli %ii, %a : i64
+        %m2 = arith.muli %jj, %m1 : i64
+        linalg.yield %m2 : i64
+  } -> tensor<?x?xi64, #DenseMatrix>
+  return %r : tensor<?x?xi64, #DenseMatrix>
+}
+
+// CHECK-LABEL:   func @sparse_index(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.init{{\[}}%[[VAL_4]], %[[VAL_5]]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK:           %[[VAL_12:.*]] = memref.alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_2]] {
+// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:             memref.store %[[VAL_16]], %[[VAL_12]]{{\[}}%[[VAL_1]]] : memref<?xindex>
+// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK:             %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_2]] {
+// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK:               memref.store %[[VAL_21]], %[[VAL_12]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:               %[[VAL_22:.*]] = arith.index_cast %[[VAL_21]] : index to i64
+// CHECK:               %[[VAL_23:.*]] = arith.index_cast %[[VAL_16]] : index to i64
+// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xi64>
+// CHECK:               %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_24]] : i64
+// CHECK:               %[[VAL_26:.*]] = arith.muli %[[VAL_22]], %[[VAL_25]] : i64
+// CHECK:               sparse_tensor.lex_insert %[[VAL_6]], %[[VAL_12]], %[[VAL_26]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_27:.*]] = sparse_tensor.load %[[VAL_6]] hasInserts : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK:           return %[[VAL_27]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK:         }
+func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
+                       -> tensor<?x?xi64, #SparseMatrix> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 0 : index
+  %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
+  %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
+  %init = sparse_tensor.init [%0, %1] : tensor<?x?xi64, #SparseMatrix>
+  %r = linalg.generic #trait
+      ins(%arga: tensor<?x?xi64, #SparseMatrix>)
+     outs(%init: tensor<?x?xi64, #SparseMatrix>) {
+      ^bb(%a: i64, %x: i64):
+        %i = linalg.index 0 : index
+        %j = linalg.index 1 : index
+        %ii = arith.index_cast %i : index to i64
+        %jj = arith.index_cast %j : index to i64
+        %m1 = arith.muli %ii, %a : i64
+        %m2 = arith.muli %jj, %m1 : i64
+        linalg.yield %m2 : i64
+  } -> tensor<?x?xi64, #SparseMatrix>
+  return %r : tensor<?x?xi64, #SparseMatrix>
+}
+

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
new file mode 100644
index 0000000000000..36a052155a591
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) * i * j"
+}
+
+module {
+
+  //
+  // Kernel that uses indices in the index notation.
+  //
+  func @sparse_index(%arga: tensor<3x4xi64, #SparseMatrix>)
+                         -> tensor<3x4xi64, #SparseMatrix> {
+    %d0 = arith.constant 3 : index
+    %d1 = arith.constant 4 : index
+    %init = sparse_tensor.init [%d0, %d1] : tensor<3x4xi64, #SparseMatrix>
+    %r = linalg.generic #trait
+        ins(%arga: tensor<3x4xi64, #SparseMatrix>)
+       outs(%init: tensor<3x4xi64, #SparseMatrix>) {
+        ^bb(%a: i64, %x: i64):
+          %i = linalg.index 0 : index
+          %j = linalg.index 1 : index
+          %ii = arith.index_cast %i : index to i64
+          %jj = arith.index_cast %j : index to i64
+          %m1 = arith.muli %ii, %a : i64
+          %m2 = arith.muli %jj, %m1 : i64
+          linalg.yield %m2 : i64
+    } -> tensor<3x4xi64, #SparseMatrix>
+    return %r : tensor<3x4xi64, #SparseMatrix>
+  }
+
+  //
+  // Main driver.
+  //
+  func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %du = arith.constant -1 : i64
+
+    // Setup input "sparse" matrix.
+    %d = arith.constant dense <[
+       [ 1,  1,  1,  1 ],
+       [ 1,  1,  1,  1 ],
+       [ 1,  1,  1,  1 ]
+    ]> : tensor<3x4xi64>
+    %a = sparse_tensor.convert %d : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
+
+    // Call the kernel.
+    %0 = call @sparse_index(%a) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64, #SparseMatrix>
+
+    //
+    // Verify result.
+    //
+    // CHECK: ( ( 0, 0, 0, 0 ), ( 0, 1, 2, 3 ), ( 0, 2, 4, 6 ) )
+    //
+    %x = sparse_tensor.convert %0 : tensor<3x4xi64, #SparseMatrix> to tensor<3x4xi64>
+    %m = bufferization.to_memref %x : memref<3x4xi64>
+    %v = vector.transfer_read %m[%c0, %c0], %du: memref<3x4xi64>, vector<3x4xi64>
+    vector.print %v : vector<3x4xi64>
+
+    // Release resources.
+    sparse_tensor.release %a : tensor<3x4xi64, #SparseMatrix>
+    sparse_tensor.release %0 : tensor<3x4xi64, #SparseMatrix>
+    memref.dealloc %m : memref<3x4xi64>
+
+    return
+  }
+}


        


More information about the Mlir-commits mailing list