[Mlir-commits] [mlir] [mlir][sparse] use sparse_tensor.insert to update values in all-dense… (PR #76544)
Peiming Liu
llvmlistbot at llvm.org
Thu Dec 28 16:04:50 PST 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/76544
… annotated output tensor.
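For reference, here is a minimal sketch of the loop nest this change produces for an all-dense annotated output, adapted from the updated dense2 test case in dense.mlir below. The #DenseMatrix alias, the function name, and the SSA value names are illustrative only, not part of the patch. The point is that the output tensor is now threaded through the loops as an iter_args value, updated with sparse_tensor.insert, and finally materialized with sparse_tensor.load ... hasInserts, rather than written directly into the buffer obtained from sparse_tensor.values.

#DenseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : dense) }>

// Sketch (assumed shape): mirrors the new CHECK lines for @dense2 below.
func.func @dense2_sketch(%a: memref<32x16xf32>,
                         %argx: tensor<32x16xf32, #DenseMatrix>)
    -> tensor<32x16xf32, #DenseMatrix> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c32 = arith.constant 32 : index
  %f1 = arith.constant 1.0 : f32
  // The all-dense annotated output is carried through the loop nest via
  // iter_args instead of being addressed through sparse_tensor.values.
  %0 = scf.for %i = %c0 to %c32 step %c1 iter_args(%t0 = %argx)
      -> (tensor<32x16xf32, #DenseMatrix>) {
    %1 = scf.for %j = %c0 to %c16 step %c1 iter_args(%t1 = %t0)
        -> (tensor<32x16xf32, #DenseMatrix>) {
      %v = memref.load %a[%i, %j] : memref<32x16xf32>
      %s = arith.addf %v, %f1 : f32
      // Each update goes through sparse_tensor.insert on the current tensor.
      %t2 = sparse_tensor.insert %s into %t1[%i, %j] : tensor<32x16xf32, #DenseMatrix>
      scf.yield %t2 : tensor<32x16xf32, #DenseMatrix>
    }
    scf.yield %1 : tensor<32x16xf32, #DenseMatrix>
  }
  // hasInserts signals that the insertion chain must be finalized.
  %r = sparse_tensor.load %0 hasInserts : tensor<32x16xf32, #DenseMatrix>
  return %r : tensor<32x16xf32, #DenseMatrix>
}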
>From 2440d5067f435630c4d7e1c2bf70f599a8a0c22f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 28 Dec 2023 23:35:05 +0000
Subject: [PATCH] [mlir][sparse] use sparse_tensor.insert to update values in
all-dense annotated output tensor.
---
.../Transforms/Sparsification.cpp | 15 +---
.../Transforms/Utils/CodegenEnv.cpp | 10 ++-
.../Transforms/Utils/CodegenEnv.h | 5 +-
.../Transforms/Utils/LoopEmitter.cpp | 8 +-
mlir/test/Dialect/SparseTensor/dense.mlir | 85 +++++++++----------
mlir/test/Dialect/SparseTensor/one_trip.mlir | 18 ++--
.../Dialect/SparseTensor/sparse_affine.mlir | 28 +++---
.../Dialect/SparseTensor/sparse_index.mlir | 42 +++++----
8 files changed, 101 insertions(+), 110 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 934e1e559f44d6..6b8b0bed7521f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -688,13 +688,13 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
} else {
env.startReduc(exp, genTensorLoad(env, builder, exp));
}
- if (env.hasSparseOutput())
+ if (env.hasTrulySparseOutput())
env.startValidLexInsert(
constantI1(builder, env.op().getLoc(), false));
} else {
if (!env.isCustomReduc() || env.isReduc())
genTensorStore(env, builder, exp, env.endReduc());
- if (env.hasSparseOutput())
+ if (env.hasTrulySparseOutput())
env.endValidLexInsert();
}
} else {
@@ -769,7 +769,7 @@ static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
/// converted to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
// Reject parallelization of sparse output.
- if (env.hasSparseOutput())
+ if (env.hasTrulySparseOutput())
return false;
// Parallel loops on tensor expansion can cause data races.
if (env.isExpand())
@@ -1038,8 +1038,6 @@ static bool translateBitsToTidLvlPairs(
SmallVectorImpl<TensorLevel> &tidLvls,
SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
const BitVector &simple = env.lat(li).simple;
- const TensorId outTid = env.merger().getOutTensorID();
- const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
unsigned numloopCond = 0;
bool hasNonUnique = false;
@@ -1116,13 +1114,6 @@ static bool translateBitsToTidLvlPairs(
}
});
- if (isDenseLT(env.lt(outTid, curr))) {
- // Note that we generate dense indices of the output tensor
- // unconditionally, since they may not appear in the lattice, but may be
- // needed for linearized env.
- tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
- }
-
if (numloopCond == 0) {
// Corner cases where the loop bound is defined by an *unused* operand, in
// this case, we just generate a dense "fake" loop by iterating over the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index d3de55e4d59bd8..3beed08a738c46 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -147,10 +147,18 @@ bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
OpOperand *lhs = linalgOp.getDpsInitOperand(0);
const TensorId tensor = makeTensorId(lhs->getOperandNumber());
+ auto outStt = getSparseTensorType(lhs->get());
// A non-annotated output tensor is assumed dense, and becomes a random
// access n-dim memref. Admissible since insertions cannot occur.
- if (getSparseTensorType(lhs->get()).isAllDense())
+ if (outStt.isAllDense()) {
+ // We treat an "all dense" annotated tensor as a "sparse" tensor and handle
+ // it in the same unified way as a "truly sparse" tensor, which avoids extra
+ // code for the corner cases introduced by the use of "all dense" annotated
+ // tensors.
+ if (outStt.hasEncoding())
+ sparseOut = lhs;
return true;
+ }
// A tensor expression with a sparse output tensor that changes its values
// but not its nonzero structure, an operation called "simply dynamic" in
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 728af841cc7b17..f1bd45c7a8f3de 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -125,7 +125,10 @@ class CodegenEnv {
// Sparse tensor output and expansion methods.
//
- bool hasSparseOutput() const { return sparseOut != nullptr; }
+ bool hasTrulySparseOutput() const {
+ return sparseOut != nullptr &&
+ !getSparseTensorType(sparseOut->get()).isAllDense();
+ }
bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }
Value getInsertionChain() const { return insChain; }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 80dad064676220..ccbbacc2ac379a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -1339,13 +1339,7 @@ void LoopEmitter::enterTensorsAtDenseLvls(
assert(dependentLvlMap[tid][lvl].empty());
auto enc = getSparseTensorEncoding(tensors[tid].getType());
if (enc && !isSparseOutput(tid)) {
- bool validPos = lvl == 0 || posits[tid][lvl - 1];
- if (!validPos) {
- // We might not find the pos for the sparse output tensor as it is
- // unconditionally required by the sparsification.
- assert(isOutputTensor(tid));
- continue;
- }
+ assert(lvl == 0 || posits[tid][lvl - 1]);
posits[tid][lvl] = genAddress(builder, loc, tid, lvl, iv);
// NOTE: we can also prepare for next lvl here in advance
}
diff --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir
index 2d8dcfea9adc19..a254161dd08984 100644
--- a/mlir/test/Dialect/SparseTensor/dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/dense.mlir
@@ -71,27 +71,27 @@ func.func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>,
// Test with a non-annotated dense matrix as input and
// an all-dense annotated "sparse" matrix as output.
//
-// CHECK-LABEL: func @dense2(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32, #sparse{{[0-9]*}}>) -> tensor<32x16xf32, #sparse{{[0-9]*}}> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 16 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
+
+// CHECK-LABEL: func.func @dense2(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32, #sparse{{[0-9]*}}>) -> tensor<32x16xf32, #sparse{{[0-9]*}}> {
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16xf32>
-// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32>
-// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_13]], %[[VAL_2]] : f32
-// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
-// CHECK: return %[[VAL_15]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] iter_args(%[[VAL_10:.*]] = %[[VAL_1]]) -> (tensor<32x16xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_13:.*]] = %[[VAL_10]]) -> (tensor<32x16xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_12]]] : memref<32x16xf32>
+// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_14]], %[[VAL_6]] : f32
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.insert %[[VAL_15]] into %[[VAL_13]]{{\[}}%[[VAL_9]], %[[VAL_12]]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_16]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: scf.yield %[[VAL_11]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: return %[[VAL_17]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
// CHECK: }
func.func @dense2(%arga: tensor<32x16xf32>,
%argx: tensor<32x16xf32, #DenseMatrix>)
@@ -114,31 +114,30 @@ func.func @dense2(%arga: tensor<32x16xf32>,
// The missing innermost "k" index (due to a reduction) is accounted
// for by scalarizing the reduction operation for the output tensor.
//
-// CHECK-LABEL: func @dense3(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32, #sparse{{[0-9]*}}>) -> tensor<32x16xf32, #sparse{{[0-9]*}}> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 32 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK-LABEL: func.func @dense3(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32, #sparse{{[0-9]*}}>) -> tensor<32x16xf32, #sparse{{[0-9]*}}> {
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16x8xf32>
-// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_9]], %[[VAL_4]] : index
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_10]] : index
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
-// CHECK: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]], %[[VAL_15]]] : memref<32x16x8xf32>
-// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : f32
-// CHECK: scf.yield %[[VAL_18]] : f32
-// CHECK: }
-// CHECK: memref.store %[[VAL_19:.*]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_1]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
-// CHECK: return %[[VAL_20]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<32x16x8xf32>
+// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] iter_args(%[[VAL_11:.*]] = %[[VAL_1]]) -> (tensor<32x16xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (tensor<32x16xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] iter_args(%[[VAL_17:.*]] = %[[VAL_7]]) -> (f32) {
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
+// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32
+// CHECK: scf.yield %[[VAL_19]] : f32
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.insert %[[VAL_15]] into %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_13]]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_20]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: scf.yield %[[VAL_12]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.load %[[VAL_9]] hasInserts : tensor<32x16xf32, #sparse{{[0-9]*}}>
+// CHECK: return %[[VAL_21]] : tensor<32x16xf32, #sparse{{[0-9]*}}>
// CHECK: }
func.func @dense3(%arga: tensor<32x16x8xf32>,
%argx: tensor<32x16xf32, #DenseMatrix>)
diff --git a/mlir/test/Dialect/SparseTensor/one_trip.mlir b/mlir/test/Dialect/SparseTensor/one_trip.mlir
index 9d2a125dbef171..ab10222d733726 100644
--- a/mlir/test/Dialect/SparseTensor/one_trip.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_trip.mlir
@@ -12,16 +12,14 @@
doc = "X(i,j) = X(i,j) * 2.0"
}
-// CHECK-LABEL: func.func @sparse_scale(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1xf32, #sparse{{[0-9]*}}>)
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[VAL_3:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<1x1xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK: %[[VAL_4:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
-// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_4]], %[[VAL_2]] : f32
-// CHECK: memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
-// CHECK: %[[VAL_6:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<1x1xf32, #sparse{{[0-9]*}}>
-// CHECK: return %[[VAL_6]] : tensor<1x1xf32, #sparse{{[0-9]*}}>
+// CHECK-LABEL: func.func @sparse_scale(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1xf32, #sparse{{[0-9]*}}>) -> tensor<1x1xf32, #sparse{{[0-9]*}}> {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.insert %[[VAL_2]] into %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : tensor<1x1xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts : tensor<1x1xf32, #sparse{{[0-9]*}}>
+// CHECK: return %[[VAL_4]] : tensor<1x1xf32, #sparse{{[0-9]*}}>
+// CHECK: }
func.func @sparse_scale(%argx: tensor<1x1xf32, #Dense>) -> tensor<1x1xf32, #Dense> {
%c = arith.constant 2.0 : f32
%0 = linalg.generic #trait_scale
diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
index aa75e20460f51f..05fa4ca22c1b83 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir
@@ -62,20 +62,20 @@ func.func @mul_inv_dense1d(%arga: tensor<32xf32, #SpVec>,
// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<32xf32, #sparse{{[0-9]*}}>
-// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<4xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_6]] : tensor<32xf32, #sparse{{[0-9]*}}> to memref<?xf32>
-// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xf32>
-// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] {
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_11]]] : memref<?xf32>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf32>
-// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VAL_13]], %[[VAL_10]] : f32
-// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_12]], %[[VAL_14]] : f32
-// CHECK: memref.store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_11]]] : memref<?xf32>
-// CHECK: }
-// CHECK: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_6]] : tensor<32xf32, #sparse{{[0-9]*}}>
-// CHECK: return %[[VAL_16]] : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<4xf32, #sparse{{[0-9]*}}> to memref<?xf32>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xf32>
+// CHECK: %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] iter_args(%[[VAL_13:.*]] = %[[VAL_7]]) -> (tensor<32xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
+// CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_14]], %[[VAL_10]] : f32
+// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_15]], %[[VAL_6]] : f32
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.insert %[[VAL_16]] into %[[VAL_13]]{{\[}}%[[VAL_12]]] : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_17]] : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.load %[[VAL_11]] hasInserts : tensor<32xf32, #sparse{{[0-9]*}}>
+// CHECK: return %[[VAL_18]] : tensor<32xf32, #sparse{{[0-9]*}}>
// CHECK: }
func.func @mul_inv_enc_dense1d(%arga: tensor<32xf32, #EncDenseVec>,
%argb: tensor<4xf32, #EncDenseVec>) -> tensor<32xf32, #EncDenseVec> {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index b09bd0a7400941..0d8e10a8541e20 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -18,33 +18,31 @@
}
// CHECK-LABEL: func.func @dense_index(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi64, #sparse{{[0-9]*}}>) -> tensor<?x?xi64, #sparse{{[0-9]*}}> {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_3:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG: %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG: %[[VAL_24:.*]] = sparse_tensor.lvl %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] {
-// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] {
-// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_8]], %[[VAL_10]] : index
-// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : index
-// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_24]], %[[VAL_10]] : index
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_11]] : index
-// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_11]] : index to i64
-// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_10]] : index to i64
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xi64>
-// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_18]] : i64
-// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_16]], %[[VAL_19]] : i64
-// CHECK: memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xi64>
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_21:.*]] = sparse_tensor.load %[[VAL_5]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
-// CHECK: return %[[VAL_21]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse{{[0-9]*}}> to memref<?xi64>
+// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_5]]) -> (tensor<?x?xi64, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (tensor<?x?xi64, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_7]], %[[VAL_10]] : index
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_15]], %[[VAL_13]] : index
+// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_13]] : index to i64
+// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_10]] : index to i64
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xi64>
+// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_18]], %[[VAL_19]] : i64
+// CHECK: %[[VAL_21:.*]] = arith.muli %[[VAL_17]], %[[VAL_20]] : i64
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.insert %[[VAL_21]] into %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_13]]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_22]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: scf.yield %[[VAL_12]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_23:.*]] = sparse_tensor.load %[[VAL_9]] hasInserts : tensor<?x?xi64, #sparse{{[0-9]*}}>
+// CHECK: return %[[VAL_23]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK: }
func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
-> tensor<?x?xi64, #DenseMatrix> {