[Mlir-commits] [mlir] 11069cb - [mlir][sparse] refactoring: split translateIndices.

Peiming Liu llvmlistbot at llvm.org
Thu Sep 29 16:59:48 PDT 2022

Author: Peiming Liu
Date: 2022-09-29T23:59:39Z
New Revision: 11069cbcb47845074d526490fff8daff8afda11d

URL: https://github.com/llvm/llvm-project/commit/11069cbcb47845074d526490fff8daff8afda11d
DIFF: https://github.com/llvm/llvm-project/commit/11069cbcb47845074d526490fff8daff8afda11d.diff

LOG: [mlir][sparse] refactoring: split translateIndices.

The index-translation logic of `translateIndices` is split out into a new helper, `translateIndicesArray`, which takes an array of SSA values and converts it into another array of SSA values according to the reassociation map. This makes the logic easier to reuse from the `foreach` operator, whose indices are given as an array of SSA values.

Reviewed By: aartbik, bixia

Differential Revision: https://reviews.llvm.org/D134918




diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 9f9bd918c9c8..62c73998d136 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -199,3 +199,52 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
     return builder.create<complex::NotEqualOp>(loc, v, zero);
   llvm_unreachable("Non-numeric type");
+void mlir::sparse_tensor::translateIndicesArray(
+    OpBuilder &builder, Location loc,
+    ArrayRef<ReassociationIndices> reassociation, ValueRange srcIndices,
+    ArrayRef<Value> srcShape, ArrayRef<Value> dstShape,
+    SmallVectorImpl<Value> &dstIndices) {
+  unsigned i = 0;
+  unsigned start = 0;
+  unsigned dstRank = dstShape.size();
+  unsigned srcRank = srcShape.size();
+  assert(srcRank == srcIndices.size());
+  bool isCollapse = srcRank > dstRank;
+  ArrayRef<Value> shape = isCollapse ? srcShape : dstShape;
+  // Iterate over reassociation map.
+  for (const auto &map : llvm::enumerate(reassociation)) {
+    // Prepare strides information in dimension slice.
+    Value linear = constantIndex(builder, loc, 1);
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      linear = builder.create<arith::MulIOp>(loc, linear, shape[j]);
+    }
+    // Start expansion.
+    Value val;
+    if (!isCollapse)
+      val = srcIndices[i];
+    // Iterate over dimension slice.
+    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
+      linear = builder.create<arith::DivUIOp>(loc, linear, shape[j]);
+      if (isCollapse) {
+        Value old = srcIndices[j];
+        Value mul = builder.create<arith::MulIOp>(loc, old, linear);
+        val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
+      } else {
+        Value old = val;
+        val = builder.create<arith::DivUIOp>(loc, val, linear);
+        assert(dstIndices.size() == j);
+        dstIndices.push_back(val);
+        val = builder.create<arith::RemUIOp>(loc, old, linear);
+      }
+    }
+    // Finalize collapse.
+    if (isCollapse) {
+      assert(dstIndices.size() == i);
+      dstIndices.push_back(val);
+    }
+    start += map.value().size();
+    i++;
+  }
+  assert(dstIndices.size() == dstRank);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 7257ca5af078..d074f43e737f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/ExecutionEngine/SparseTensor/Enums.h"
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
 #include "mlir/IR/Builders.h"
 namespace mlir {
@@ -193,6 +195,12 @@ constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
+/// Helper method to translate indices during a reshaping operation.
+void translateIndicesArray(OpBuilder &builder, Location loc,
+                           ArrayRef<ReassociationIndices> reassociation,
+                           ValueRange srcIndices, ArrayRef<Value> srcShape,
+                           ArrayRef<Value> dstShape,
+                           SmallVectorImpl<Value> &dstIndices);
 } // namespace sparse_tensor
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 6b2c4611d454..4347db4e860e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -475,44 +475,21 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                              ArrayRef<Value> srcShape) {
   unsigned dstRank = dstTp.getRank();
   unsigned srcRank = srcTp.getRank();
-  unsigned start = 0;
-  unsigned i = 0;
-  bool isExpand = srcRank > dstRank;
-  ArrayRef<Value> shape = isExpand ? srcShape : dstShape;
-  // Iterate over reassociation map.
-  for (const auto &map : llvm::enumerate(reassociation)) {
-    // Prepare strides information in dimension slice.
-    Value linear = constantIndex(rewriter, loc, 1);
-    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
-      linear = rewriter.create<arith::MulIOp>(loc, linear, shape[j]);
-    }
-    // Start collapse.
-    Value idx = constantIndex(rewriter, loc, i++);
-    Value val;
-    if (!isExpand)
-      val = rewriter.create<memref::LoadOp>(loc, srcIdx, idx);
-    // Iterate over dimension slice.
-    for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
-      linear = rewriter.create<arith::DivUIOp>(loc, linear, shape[j]);
-      Value jdx = constantIndex(rewriter, loc, j);
-      if (isExpand) {
-        Value old = rewriter.create<memref::LoadOp>(loc, srcIdx, jdx);
-        Value mul = rewriter.create<arith::MulIOp>(loc, old, linear);
-        val = val ? rewriter.create<arith::AddIOp>(loc, val, mul) : mul;
-      } else {
-        Value old = val;
-        val = rewriter.create<arith::DivUIOp>(loc, val, linear);
-        rewriter.create<memref::StoreOp>(loc, val, dstIdx, jdx);
-        val = rewriter.create<arith::RemUIOp>(loc, old, linear);
-      }
-    }
-    // Finalize expansion.
-    if (isExpand)
-      rewriter.create<memref::StoreOp>(loc, val, dstIdx, idx);
-    start += map.value().size();
+  SmallVector<Value, 4> srcIndices;
+  for (unsigned i = 0; i < srcRank; i++) {
+    Value idx = rewriter.create<memref::LoadOp>(
+        loc, srcIdx, constantIndex(rewriter, loc, i));
+    srcIndices.push_back(idx);
-  // Sanity.
-  assert((isExpand && i == dstRank) || (!isExpand && i == srcRank));
+  SmallVector<Value, 4> dstIndices;
+  translateIndicesArray(rewriter, loc, reassociation, srcIndices, srcShape,
+                        dstShape, dstIndices);
+  for (unsigned i = 0; i < dstRank; i++)
+    rewriter.create<memref::StoreOp>(loc, dstIndices[i], dstIdx,
+                                     constantIndex(rewriter, loc, i));
 /// Helper method to compute the shape of destination tensor of a reshape

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index c58e34be6583..5848cbb44a01 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -26,8 +26,8 @@
 // CHECK-CONV:      } do {
 // CHECK-CONV:        %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex>
 // CHECK-CONV:        %[[D:.*]] = arith.divui %[[X]], %[[C10]] : index
-// CHECK-CONV:        memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV:        %[[R:.*]] = arith.remui %[[X]], %[[C10]] : index
+// CHECK-CONV:        memref.store %[[D]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV:        memref.store %[[R]], %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV:        call @addEltF64
 // CHECK-CONV:        scf.yield
@@ -64,8 +64,8 @@ func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10x
 // CHECK-CONV:        scf.condition
 // CHECK-CONV:      } do {
 // CHECK-CONV:        %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex>
-// CHECK-CONV:        %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
 // CHECK-CONV:        %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
+// CHECK-CONV:        %[[M:.*]] = arith.muli %[[X]], %[[C10]] : index
 // CHECK-CONV:        %[[A:.*]] = arith.addi %[[M]], %[[Y]] : index
 // CHECK-CONV:        memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex>
 // CHECK-CONV:        call @addEltF64
@@ -103,14 +103,14 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
 // CHECK-CONV:        call @getNextF64
 // CHECK-CONV:        scf.condition
 // CHECK-CONV:      } do {
-// CHECK-CONV:        %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index
 // CHECK-CONV:        %[[L:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<1xindex>
+// CHECK-CONV:        %[[M:.*]] = arith.muli %[[D1]], %[[C10]] : index
 // CHECK-CONV:        %[[D2:.*]] = arith.divui %[[M]], %[[D1]] : index
 // CHECK-CONV:        %[[D3:.*]] = arith.divui %[[L]], %[[D2]] : index
-// CHECK-CONV:        memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV:        %[[R:.*]] = arith.remui %[[L]], %[[D2]] : index
 // CHECK-CONV:        %[[D4:.*]] = arith.divui %[[D2]], %[[C10]] : index
 // CHECK-CONV:        %[[D5:.*]] = arith.divui %[[R]], %[[D4]] : index
+// CHECK-CONV:        memref.store %[[D3]], %{{.*}}[%[[C0]]] : memref<2xindex>
 // CHECK-CONV:        memref.store %[[D5]], %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV:        call @addEltF64
 // CHECK-CONV:        scf.yield
@@ -147,11 +147,11 @@ func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<
 // CHECK-CONV:        call @getNextF64
 // CHECK-CONV:        scf.condition
 // CHECK-CONV:      } do {
-// CHECK-CONV:        %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index
 // CHECK-CONV:        %[[X:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<2xindex>
+// CHECK-CONV:        %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
+// CHECK-CONV:        %[[D1:.*]] = arith.divui %[[M1]], %[[C10]] : index
 // CHECK-CONV:        %[[M2:.*]] = arith.muli %[[X]], %[[D1]] : index
 // CHECK-CONV:        %[[D2:.*]] = arith.divui %[[D1]], %{{.*}} : index
-// CHECK-CONV:        %[[Y:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<2xindex>
 // CHECK-CONV:        %[[M3:.*]] = arith.muli %[[Y]], %[[D2]] : index
 // CHECK-CONV:        %[[A:.*]] = arith.addi %[[M2]], %[[M3]] : index
 // CHECK-CONV:        memref.store %[[A]], %{{.*}}[%[[C0]]] : memref<1xindex>


More information about the Mlir-commits mailing list