[Mlir-commits] [mlir] [mlir][sparse] support non-id map for [Dis]assembleOp (PR #80355)

Peiming Liu llvmlistbot at llvm.org
Thu Feb 1 14:19:33 PST 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/80355

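This patch lets sparse_tensor.assemble and sparse_tensor.disassemble operate on sparse tensors with a non-identity dimension-to-level map. The pack/unpack verifier no longer rejects such tensors. The sparse-reinterpret-map pass now rewrites both ops to work on the demapped tensor type, inserting sparse_tensor.reinterpret_map to convert between the two forms.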

>From fea98ccf7dd9492b736415aa0deb8ac753cd3908 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 1 Feb 2024 22:18:53 +0000
Subject: [PATCH] [mlir][sparse] support non-id map for [Dis]assembleOp

---
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  2 -
 .../Transforms/SparseReinterpretMap.cpp       | 37 +++++++++++++-
 .../SparseTensor/sparse_reinterpret_map.mlir  | 48 +++++++++++++++++++
 3 files changed, 84 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6033ebf6897ce..27125bc7ed45e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1016,8 +1016,6 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
     return op->emitError("the sparse-tensor must have static shape");
   if (!stt.hasEncoding())
     return op->emitError("the sparse-tensor must have an encoding attribute");
-  if (!stt.isIdentity())
-    return op->emitError("the sparse-tensor must have the identity mapping");
 
   // Verifies the trailing COO.
   Level cooStartLvl = stt.getCOOStart();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a0f7b55ce4446..fbe2fc31ab8b1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -656,6 +656,40 @@ struct TensorInsertDemapper
   }
 };
 
+struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AssembleOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!hasAnyNonIdentityOperandsOrResults(op))
+      return failure(); // Nothing to demap.
+    assert(hasAnySparseResult(op));
+    // Assemble into the demapped type, then reinterpret back for all users.
+    auto srcTp = getSparseTensorType(op.getResult());
+    auto demapTp = srcTp.getDemappedType();
+    rewriter.modifyOpInPlace(op, [&]() { op.getResult().setType(demapTp); });
+    rewriter.setInsertionPointAfter(op);
+    Value remapped = genRemap(rewriter, srcTp.getEncoding(), op.getResult());
+    rewriter.replaceAllUsesExcept(op, remapped, remapped.getDefiningOp());
+    return success();
+  }
+};
+
+struct SparseDisassembleDemapper
+    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
+  using DemapInsRewriter::DemapInsRewriter;
+  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    if (!hasAnyNonIdentityOperandsOrResults(op))
+      return failure(); // Nothing to demap.
+    assert(hasAnySparseOperandOrResult(op));
+    // Use the adaptor's tensor (assumed demapped by DemapInsRewriter).
+    Value demapped = adaptor.getTensor();
+    rewriter.modifyOpInPlace(
+        op, [&]() { op.getTensorMutable().assign(demapped); });
+    return success();
+  }
+};
+
 struct ForeachOpDemapper
     : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
   using DemapInsRewriter::DemapInsRewriter;
@@ -758,7 +792,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
     patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
-                 TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
+                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
+                 SparseDisassembleDemapper, TensorInsertDemapper,
                  ForeachOpDemapper>(patterns.getContext());
   }
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 46f04cca03ed7..54de1024323b5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -80,3 +80,51 @@ func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<
   %9 = sparse_tensor.load %8 hasInserts : tensor<2x4xf64, #BSR>
   return %9 : tensor<2x4xf64, #BSR>
 }
+
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = (d0, d1) ->
+    ( d0 floordiv 2 : dense,
+      d1 floordiv 2 : compressed,
+      d0 mod 2      : dense,
+      d1 mod 2      : dense
+    )
+}>
+// CHECK-DAG: #[[$remap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense) }>
+// CHECK-DAG: #[[$demap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : dense, d3 : dense) }>
+
+// CHECK-LABEL:   func.func @sparse_assemble_reinterpret_map(
+// CHECK-SAME:        %[[VAL_0:.*]]: tensor<?xf64>,
+// CHECK-SAME:        %[[VAL_1:.*]]: tensor<?xindex>,
+// CHECK-SAME:        %[[VAL_2:.*]]: tensor<?xindex>) -> tensor<2x4xf64, #[[$remap]]> {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.assemble %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_3]] : tensor<1x2x2x2xf64, #[[$demap]]> to tensor<2x4xf64, #[[$remap]]>
+// CHECK:           return %[[VAL_4]] : tensor<2x4xf64, #[[$remap]]>
+// CHECK:         }
+func.func @sparse_assemble_reinterpret_map(%values: tensor<?xf64>, %positions: tensor<?xindex>, %coordinates: tensor<?xindex>) -> tensor<2x4xf64, #BSR> {
+  %bsr = sparse_tensor.assemble %values, %positions, %coordinates
+       : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<2x4xf64, #BSR>
+  return %bsr : tensor<2x4xf64, #BSR>
+}
+
+// CHECK-LABEL:   func.func @sparse_disassemble_reinterpret_map(
+// CHECK-SAME:         %[[VAL_0:.*]]: tensor<2x4xf64, #[[$remap]]>,
+// CHECK-SAME:         %[[VAL_1:.*]]: tensor<?xf64>,
+// CHECK-SAME:         %[[VAL_2:.*]]: tensor<?xindex>,
+// CHECK-SAME:         %[[VAL_3:.*]]: tensor<?xindex>) -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64, #[[$remap]]> to tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK:           %[[VAL_5:.*]], %[[VAL_6:.*]]:2, %[[VAL_7:.*]], %[[VAL_8:.*]]:2 = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK:           return
+// CHECK:         }
+func.func @sparse_disassemble_reinterpret_map(
+    %bsr: tensor<2x4xf64, #BSR>, %outValues: tensor<?xf64>,
+    %outPositions: tensor<?xindex>, %outCoords: tensor<?xindex>)
+    -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
+  %values, %positions, %coords, %valLen, %posLen, %crdLen =
+      sparse_tensor.disassemble %bsr : tensor<2x4xf64, #BSR>
+      outs(%outValues, %outPositions, %outCoords : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>)
+      -> tensor<?xf64>, (tensor<?xindex>, tensor<?xindex>), index, (index, index)
+  return %values, %positions, %coords : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>
+}



More information about the Mlir-commits mailing list