[Mlir-commits] [mlir] e7d3ba1 - [mlir][sparse] accept sparse reshape (expand/collapse)

Aart Bik llvmlistbot at llvm.org
Wed Jun 22 09:40:45 PDT 2022


Author: Aart Bik
Date: 2022-06-22T09:40:38-07:00
New Revision: e7d3ba1066c825c07b71deae512a40731769a963

URL: https://github.com/llvm/llvm-project/commit/e7d3ba1066c825c07b71deae512a40731769a963
DIFF: https://github.com/llvm/llvm-project/commit/e7d3ba1066c825c07b71deae512a40731769a963.diff

LOG: [mlir][sparse] accept sparse reshape (expand/collapse)

This revision makes sure the expand/collapse reshaping operations
in the tensor dialect accept sparse tensors as arguments.
Note that the actual lowering to executable IR is still TBD.
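
For illustration, before this change the verifier's strict type equality
rejected sparse operands of these ops; on the new test cases below it would
emit a diagnostic roughly of the form (reconstructed from the verifier's
format string in the diff, with the encoding body elided):

    error: 'tensor.expand_shape' op expected collapsed type to be
    'tensor<100xf64>', but got 'tensor<100xf64, #sparse_tensor.encoding<...>>'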

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D128311

Added: 
    mlir/test/Dialect/SparseTensor/sparse_reshape.mlir

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6352db7e31be..d9f2145d025c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -844,6 +844,18 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                       getReassociationIndicesAttribute(b, reassociation));
 }
 
+// Checks whether two types are the same, ignoring encoding on ranked tensors.
+static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
+  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
+    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
+      return rtp1.getShape() == rtp2.getShape() &&
+             rtp1.getElementType() == rtp2.getElementType();
+    return false;
+  }
+  // Fall back to strict equality for all other types.
+  return tp1 == tp2;
+}
+
 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                         TensorReshapeOp, ExpandShapeOp>::value>
 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
@@ -856,7 +868,7 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
   auto maps = op.getReassociationMaps();
   RankedTensorType expectedType =
       computeTensorReshapeCollapsedType(expandedType, maps);
-  if (collapsedType != expectedType)
+  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
     return op.emitOpError("expected collapsed type to be ")
            << expectedType << ", but got " << collapsedType;
   return success();

diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
new file mode 100644
index 000000000000..c791536e1519
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// TODO: check lowering to an actual implementation
+
+#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// CHECK-LABEL: func.func @sparse_expand(
+// CHECK-SAME:  %[[A:.*]]: tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//      CHECK:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//      CHECK:  return %[[E]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] :
+    tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
+  return %0 : tensor<10x10xf64, #SparseMatrix>
+}
+
+// CHECK-LABEL: func.func @sparse_collapse(
+// CHECK-SAME:  %[[A:.*]]: tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//      CHECK:  %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse_tensor.encoding<{{{.*}}}>> into tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+//      CHECK:  return %[[C]] : tensor<100xf64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] :
+    tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>
+  return %0 : tensor<100xf64, #SparseVector>
+}
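
For readers outside the patch context: the relaxed check compares only the
shape and element type of ranked tensors, so two types that differ only in
encoding are treated as equal. Below is a minimal standalone sketch of that
behavior (not part of this commit; it uses a StringAttr as a stand-in for a
real #sparse_tensor.encoding, and the call to the file-local helper appears
in a comment only):

    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include <cassert>

    using namespace mlir;

    int main() {
      MLIRContext ctx;
      auto f64 = Float64Type::get(&ctx);
      // Any Attribute can serve as a tensor encoding; a StringAttr stands
      // in for a real sparse encoding here.
      auto enc = StringAttr::get(&ctx, "demo_encoding");
      auto plain = RankedTensorType::get({10, 10}, f64);
      auto sparse = RankedTensorType::get({10, 10}, f64, enc);
      assert(plain != sparse); // strict equality fails: encodings differ
      // isSameTypesWithoutEncoding(plain, sparse) would return true, since
      // shape and element type match (the helper is static to
      // TensorOps.cpp, so the call here is illustrative only).
      return 0;
    }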