[Mlir-commits] [mlir] 11069cb - [mlir][sparse] refactoring: split translateIndices.
Peiming Liu
llvmlistbot at llvm.org
Thu Sep 29 16:59:48 PDT 2022
Author: Peiming Liu
Date: 2022-09-29T23:59:39Z
New Revision: 11069cbcb47845074d526490fff8daff8afda11d
URL: https://github.com/llvm/llvm-project/commit/11069cbcb47845074d526490fff8daff8afda11d
DIFF: https://github.com/llvm/llvm-project/commit/11069cbcb47845074d526490fff8daff8afda11d.diff
LOG: [mlir][sparse] refactoring: split translateIndices.
TranslateIndicesArray takes an array of SSA values and converts them into another array of SSA values based on the reassociation, which makes it easier to reuse in the `foreach` operator (as the indices are given as an array of SSA values).
Reviewed By: aartbik, bixia
Differential Revision: https://reviews.llvm.org/D134918
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 9f9bd918c9c8..62c73998d136 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -199,3 +199,52 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
return builder.create<complex::NotEqualOp>(loc, v, zero);
llvm_unreachable("Non-numeric type");
}
+
+void mlir::sparse_tensor::translateIndicesArray(
+ OpBuilder &builder, Location loc,
+ ArrayRef<ReassociationIndices> reassociation, ValueRange srcIndices,
+ ArrayRef<Value> srcShape, ArrayRef<Value> dstShape,
+ SmallVectorImpl<Value> &dstIndices) {
+ unsigned i = 0;
+ unsigned start = 0;
+ unsigned dstRank = dstShape.size();
+ unsigned srcRank = srcShape.size();
+ assert(srcRank == srcIndices.size());
+ bool isCollapse = srcRank > dstRank;
+ ArrayRef<Value> shape = isCollapse ? srcShape : dstShape;
+ // Iterate over reassociation map.
+ for (const auto &map : llvm::enumerate(reassociation)) {
+ // Prepare strides information in dimension slice.
+ Value linear = constantIndex(builder, loc, 1);
+ for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+ linear = builder.create<arith::MulIOp>(loc, linear, shape[j]);
+ }
+ // Start expansion.
+ Value val;
+ if (!isCollapse)
+ val = srcIndices[i];
+ // Iterate over dimension slice.
+ for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+ linear = builder.create<arith::DivUIOp>(loc, linear, shape[j]);
+ if (isCollapse) {
+ Value old = srcIndices[j];
+ Value mul = builder.create<arith::MulIOp>(loc, old, linear);
+ val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
+ } else {
+ Value old = val;
+ val = builder.create<arith::DivUIOp>(loc, val, linear);
+ assert(dstIndices.size() == j);
+ dstIndices.push_back(val);
+ val = builder.create<arith::RemUIOp>(loc, old, linear);
+ }
+ }
+ // Finalize collapse.
+ if (isCollapse) {
+ assert(dstIndices.size() == i);
+ dstIndices.push_back(val);
+ }
+ start += map.value().size();
+ i++;
+ }
+ assert(dstIndices.size() == dstRank);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 7257ca5af078..d074f43e737f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -16,7 +16,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/ExecutionEngine/SparseTensor/Enums.h"
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/IR/Builders.h"
namespace mlir {
@@ -193,6 +195,12 @@ constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
}
+/// Helper method to translate indices during a reshaping operation.
+void translateIndicesArray(OpBuilder &builder, Location loc,
+ ArrayRef<ReassociationIndices> reassociation,
+ ValueRange srcIndices, ArrayRef<Value> srcShape,
+ ArrayRef<Value> dstShape,
+ SmallVectorImpl<Value> &dstIndices);
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 6b2c4611d454..4347db4e860e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -475,44 +475,21 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
ArrayRef<Value> srcShape) {
unsigned dstRank = dstTp.getRank();
unsigned srcRank = srcTp.getRank();
- unsigned start = 0;
- unsigned i = 0;
- bool isExpand = srcRank > dstRank;
- ArrayRef<Value> shape = isExpand ? srcShape : dstShape;
- // Iterate over reassociation map.
- for (const auto &map : llvm::enumerate(reassociation)) {
- // Prepare strides information in dimension slice.
- Value linear = constantIndex(rewriter, loc, 1);
- for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
- linear = rewriter.create<arith::MulIOp>(loc, linear, shape[j]);
- }
- // Start collapse.
- Value idx = constantIndex(rewriter, loc, i++);
- Value val;
- if (!isExpand)
- val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
- // Iterate over dimension slice.
- for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
- linear = rewriter.create<arith::DivUIOp>(loc, linear, shape[j]);
- Value jdx = constantIndex(rewriter, loc, j);
- if (isExpand) {
- Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
- Value mul = rewriter.create<arith::MulIOp>(loc, old, linear);
- val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
- } else {
- Value old = val;
- val = rewriter.create<arith::DivUIOp>(loc, val, linear);
- rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
- val = rewriter.create<arith::RemUIOp>(loc, old, linear);
- }
- }
- // Finalize expansion.
- if (isExpand)
- rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
- start += map.value().size();
+
+ SmallVector<Value, 4> srcIndices;
+ for (unsigned i = 0; i < srcRank; i++) {
+ Value idx = rewriter.create<memref::LoadOp>(
+ loc, srcIdx, constantIndex(rewriter, loc, i));
+ srcIndices.push_back(idx);
}
- // Sanity.
- assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
+
+ SmallVector<Value, 4> dstIndices;
+ translateIndicesArray(rewriter, loc, reassociation, srcIndices, srcShape,
+ dstShape, dstIndices);
+
+ for (unsigned i = 0; i < dstRank; i++)
+ rewriter.create<memref::StoreOp>(loc, dstIndices[i], dstIdx,
+ constantIndex(rewriter, loc, i));
}
/// Helper method to compute the shape of destination tensor of a reshape
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index c58e34be6583..5848cbb44a01 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -26,8 +26,8 @@
// CHECK-CONV: } do {
// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex>
// CHECK-CONV: %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index
-// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex>
// CHECK-CONV: %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index
+// CHECK-CONV: memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex>
// CHECK-CONV: memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<2xindex>
// CHECK-CONV: call @addEltF64
// CHECK-CONV: scf.yield
@@ -64,8 +64,8 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
// CHECK-CONV: scf.condition
// CHECK-CONV: } do {
// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex>
-// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
+// CHECK-CONV: %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
// CHECK-CONV: %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index
// CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex>
// CHECK-CONV: call @addEltF64
@@ -103,14 +103,14 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-CONV: call @getNextF64
// CHECK-CONV: scf.condition
// CHECK-CONV: } do {
-// CHECK-CONV: %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index
// CHECK-CONV: %[[L:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex>
+// CHECK-CONV: %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index
// CHECK-CONV: %[[D2:.*]] = arith.divui %[[M]], %[[D1]] : index
// CHECK-CONV: %[[D3:.*]] = arith.divui %[[L]], %[[D2]] : index
-// CHECK-CONV: memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex>
// CHECK-CONV: %[[R:.*]] = arith.remui %[[L]], %[[D2]] : index
// CHECK-CONV: %[[D4:.*]] = arith.divui %[[D2]], %[[C10]] : index
// CHECK-CONV: %[[D5:.*]] = arith.divui %[[R]], %[[D4]] : index
+// CHECK-CONV: memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex>
// CHECK-CONV: memref.store %[[D5]], %{{.*}}[%[[C1]]] : memref<2xindex>
// CHECK-CONV: call @addEltF64
// CHECK-CONV: scf.yield
@@ -147,11 +147,11 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
// CHECK-CONV: call @getNextF64
// CHECK-CONV: scf.condition
// CHECK-CONV: } do {
-// CHECK-CONV: %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index
// CHECK-CONV: %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex>
+// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
+// CHECK-CONV: %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index
// CHECK-CONV: %[[M2:.*]] = arith.muli %[[X]], %[[D1]] : index
// CHECK-CONV: %[[D2:.*]] = arith.divui %[[D1]], %{{.*}} : index
-// CHECK-CONV: %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
// CHECK-CONV: %[[M3:.*]] = arith.muli %[[Y]], %[[D2]] : index
// CHECK-CONV: %[[A:.*]] = arith.addi %[[M2]], %[[M3]] : index
// CHECK-CONV: memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex>
More information about the Mlir-commits
mailing list