[Mlir-commits] [mlir] 349bceb - [mlir][sparse] Refactor the conversion of the tensor reshape operators.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 3 11:06:57 PDT 2022


Author: bixia1
Date: 2022-10-03T11:06:49-07:00
New Revision: 349bceba65adb224ddeebdd9e10dea9bc1f33a25

URL: https://github.com/llvm/llvm-project/commit/349bceba65adb224ddeebdd9e10dea9bc1f33a25
DIFF: https://github.com/llvm/llvm-project/commit/349bceba65adb224ddeebdd9e10dea9bc1f33a25.diff

LOG: [mlir][sparse] Refactor the conversion of the tensor reshape operators.

Move genReshapeDstShape to codegen utils to support the rewriting of the tensor
reshape operators for the codegen path.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135074

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 89ae924cf8b27..c850044fa86f6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -425,6 +425,61 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
   llvm_unreachable("Non-numeric type");
 }
 
+void mlir::sparse_tensor::genReshapeDstShape(
+    Location loc, PatternRewriter &rewriter, SmallVector<Value, 4> &dstShape,
+    ArrayRef<Value> srcShape, ArrayRef<int64_t> staticDstShape,
+    ArrayRef<ReassociationIndices> reassociation) {
+  // Collapse shape: each dst dim is the product of its group's source dims.
+  if (reassociation.size() < srcShape.size()) {
+    unsigned start = 0; // First source dim of the current group.
+    for (const auto &map : llvm::enumerate(reassociation)) {
+      auto dstDim = constantIndex(rewriter, loc, 1);
+      for (unsigned i = start; i < start + map.value().size(); i++) {
+        dstDim = rewriter.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
+      }
+      dstShape.push_back(dstDim);
+      start = start + map.value().size();
+    }
+    assert(start == srcShape.size()); // Every source dim was consumed.
+    return;
+  }
+
+  // Expand shape: one reassociation group per source dimension.
+  assert(reassociation.size() == srcShape.size());
+  unsigned start = 0; // First destination dim of the current group.
+  // Expand the i-th dimension in srcShape into reassociation[i].size() dims.
+  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
+    auto map = reassociation[i];
+    auto srcDim = srcShape[i];
+    // Visit each destination dimension expanded from the i-th dimension.
+    for (unsigned j = start; j < start + map.size(); j++) {
+      // At most one dimension expanded from the i-th dimension in srcShape
+      // may be dynamically sized. For example, if srcDim = 8, then the
+      // expanded shape could be <2x?x2>, but not <2x?x?>.
+      if (staticDstShape[j] == ShapedType::kDynamicSize) {
+        // The expanded dimension has dynamic size. Its value is srcDim
+        // divided by the product of the group's static dimensions.
+        int64_t product = 1;
+        for (unsigned k = start; k < start + map.size(); k++) {
+          if (staticDstShape[k] != ShapedType::kDynamicSize) {
+            product *= staticDstShape[k];
+          }
+        }
+        // Emit the (unsigned) division that yields the dynamic size.
+        Value productVal = constantIndex(rewriter, loc, product);
+        Value dynamicSize =
+            rewriter.create<arith::DivUIOp>(loc, srcDim, productVal);
+        dstShape.push_back(dynamicSize);
+      } else {
+        // The expanded dimension is statically known: emit it as a constant.
+        dstShape.push_back(constantIndex(rewriter, loc, staticDstShape[j]));
+      }
+    }
+    start = start + map.size();
+  }
+  assert(start == staticDstShape.size()); // Every dst dim was produced.
+}
+
 void mlir::sparse_tensor::translateIndicesArray(
     OpBuilder &builder, Location loc,
     ArrayRef<ReassociationIndices> reassociation, ValueRange srcIndices,

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 63456a8bcc2cb..a9ea41771c0b1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -313,6 +313,15 @@ constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
                     static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
 }
 
+/// Computes the shape of destination tensor of a reshape operator (collapse
+/// or expand). This is only used when operands have dynamic shape. The shape
+/// of the destination is appended to dstShape, one Value per dimension.
+void genReshapeDstShape(Location loc, PatternRewriter &rewriter,
+                        SmallVector<Value, 4> &dstShape,
+                        ArrayRef<Value> srcShape,
+                        ArrayRef<int64_t> staticDstShape,
+                        ArrayRef<ReassociationIndices> reassociation);
+
 /// Helper method to translate indices during a reshaping operation.
 void translateIndicesArray(OpBuilder &builder, Location loc,
                            ArrayRef<ReassociationIndices> reassociation,

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4347db4e860ed..87304173d732c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -492,65 +492,6 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
                                      constantIndex(rewriter, loc, i));
 }
 
-/// Helper method to compute the shape of destination tensor of a reshape
-/// operator. This is only used when operands have dynamic shape. The shape of
-/// the destination is stored into dstShape.
-void genReshapeDstShape(Location loc, ConversionPatternRewriter &rewriter,
-                        SmallVector<Value, 4> &dstShape,
-                        ArrayRef<Value> srcShape,
-                        ArrayRef<int64_t> staticDstShape,
-                        ArrayRef<ReassociationIndices> reassociation) {
-  // Collapse shape.
-  if (reassociation.size() < srcShape.size()) {
-    unsigned start = 0;
-    for (const auto &map : llvm::enumerate(reassociation)) {
-      auto dstDim = constantIndex(rewriter, loc, 1);
-      for (unsigned i = start; i < start + map.value().size(); i++) {
-        dstDim = rewriter.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
-      }
-      dstShape.push_back(dstDim);
-      start = start + map.value().size();
-    }
-    assert(start == srcShape.size());
-    return;
-  }
-
-  // Expand shape.
-  assert(reassociation.size() == srcShape.size());
-  unsigned start = 0;
-  // Expand the i-th dimension in srcShape.
-  for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
-    auto map = reassociation[i];
-    auto srcDim = srcShape[i];
-    // Iterate through dimensions expanded from the i-th dimension.
-    for (unsigned j = start; j < start + map.size(); j++) {
-      // There can be only one dynamic sized dimension among dimensions expanded
-      // from the i-th dimension in srcShape. For example, if srcDim = 8, then
-      // the expanded shape could be <2x?x2>, but not <2x?x?>.
-      if (staticDstShape[j] == ShapedType::kDynamicSize) {
-        // The expanded dimension has dynamic size. We compute the dimension
-        // by dividing srcDim by the product of the static dimensions.
-        int64_t product = 1;
-        for (unsigned k = start; k < start + map.size(); k++) {
-          if (staticDstShape[k] != ShapedType::kDynamicSize) {
-            product *= staticDstShape[k];
-          }
-        }
-        // Compute the dynamic dimension size.
-        Value productVal = constantIndex(rewriter, loc, product);
-        Value dynamicSize =
-            rewriter.create<arith::DivUIOp>(loc, srcDim, productVal);
-        dstShape.push_back(dynamicSize);
-      } else {
-        // The expanded dimension is statically known.
-        dstShape.push_back(constantIndex(rewriter, loc, staticDstShape[j]));
-      }
-    }
-    start = start + map.size();
-  }
-  assert(start == staticDstShape.size());
-}
-
 /// Generate code for a general sparse to sparse reshaping operation.
 /// Note that unlike dense reshaping (which can be done with a "cheap"
 /// change of view), sparse reshaping is currently done with actual


        


More information about the Mlir-commits mailing list