[Mlir-commits] [mlir] [mlir][sparse] first proof-of-concept non-permutation rewriter (PR #70863)

Aart Bik llvmlistbot at llvm.org
Tue Oct 31 15:16:34 PDT 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/70863

>From 591c6f2feb7ce3fbb79dc0ef2d048a5712d1d7ed Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 14:25:15 -0700
Subject: [PATCH 1/4]  [mlir][sparse] first proof-of-concept non-permutation
 rewriter

Rather than extending sparsifier codegen with higher-order
non-permutations, we follow the path of rewriting linalg
generic ops into higher-order operations. That way,
code generation will simply work out of the box. This is a very
first proof-of-concept implementation of that idea.
---
 .../Transforms/SparseReinterpretMap.cpp       | 143 +++++++++++++++++-
 .../SparseTensor/sparse_reinterpret_map.mlir  |  49 +++++-
 2 files changed, 183 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 5880f2158b8cd05..14aaa39f3183e47 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -18,10 +20,135 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-// TODO:
-//   (1) insert the zero-cost sparse_tensor.reinterpret_map ops
-//   (2) rewrite linalg.generic ops traits on level crds
-//   (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+// Translates a "simple" map according to an identify lvl-map.
+static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
+                              AffineMap map) {
+  unsigned lvlRank = stt.getLvlRank();
+  // Every result of the lvl-to-dim map expresses one dimension coordinate
+  // in terms of the lvlRank level coordinates.
+  const AffineMap lvlToDim = stt.getLvlToDim();
+  assert(lvlToDim.getNumInputs() == lvlRank);
+  // Substitute each result of `map` (required to be a plain dimension
+  // expression; the cast asserts otherwise) by the matching lvl-to-dim
+  // expression, yielding a map over the level space.
+  SmallVector<AffineExpr> translated;
+  translated.reserve(map.getNumResults());
+  for (AffineExpr result : map.getResults()) {
+    const unsigned dim = result.cast<AffineDimExpr>().getPosition();
+    translated.push_back(lvlToDim.getResult(dim));
+  }
+  MLIRContext *ctx = builder.getContext();
+  return AffineMap::get(lvlRank, /*symbolCount=*/0, translated, ctx);
+}
+
+// Generates a "de"mapping reinterpretation of the map.
+static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                      Value val) {
+  // Build an encoding with the same level types and position/coordinate
+  // bit widths as `enc`, but with identity maps (over the level rank) in
+  // place of its dimension/level mappings.
+  unsigned lvlRank = enc.getLvlTypes().size();
+  AffineMap idMap =
+      AffineMap::getMultiDimIdentityMap(lvlRank, builder.getContext());
+  auto newEnc = SparseTensorEncodingAttr::get(
+      builder.getContext(), enc.getLvlTypes(), idMap, idMap, enc.getPosWidth(),
+      enc.getCrdWidth());
+  // Reinterpret `val` under that encoding; this only emits a
+  // sparse_tensor.reinterpret_map op.
+  return builder.create<ReinterpretMapOp>(val.getLoc(), newEnc, val);
+}
+
+// Generates a "re"mapping reinterpretation of the map.
+static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                      Value val) {
+  // Reinterpret `val` under the given encoding `enc`; the caller passes the
+  // original (mapped) encoding to undo an earlier demap.
+  Location loc = val.getLoc();
+  Value remapped = builder.create<ReinterpretMapOp>(loc, enc, val);
+  return remapped;
+}
+
+// Generates a clone of the given linalg generic operation, but with
+// remapped arguments, index maps, and iteration types.
+//
+// The clone iterates over the level space of the sparse output `stt`:
+// every input indexing map is rewritten through the output's lvl-to-dim
+// map (via translateMap), the output indexing map becomes the lvl-identity,
+// and (lvlRank - dimRank) "parallel" loops are prepended to the original
+// iterator types. Inputs are passed through unchanged, `out` becomes the
+// sole output operand, and the region is cloned as-is.
+//
+// TODO: As listed in the TODO of GenericOpReinterpretMap, this is
+//       proof-of-concept code which makes a lot of simplifying assumptions
+//       for now.
+//
+static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
+                                          linalg::GenericOp linalgOp,
+                                          SparseTensorType stt, Value out) {
+  const unsigned dimRank = stt.getDimRank();
+  const unsigned lvlRank = stt.getLvlRank();
+  // The number of extra loops is an unsigned difference; it must not wrap.
+  assert(dimRank <= lvlRank && "expected level rank >= dimension rank");
+  SmallVector<Value> inputOps = linalgOp.getInputs();
+  SmallVector<Value> outputOps = {out};
+  SmallVector<AffineMap> indexMaps;
+  SmallVector<utils::IteratorType> iterTypes;
+  // Translate the index maps, except output map, which is lvl-identity.
+  auto maps = linalgOp.getIndexingMapsArray();
+  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
+    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
+  indexMaps.push_back(
+      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
+  // Add (lvlRank - dimRank) additional "parallel" iteration types at the top,
+  // so that the total number of loops equals lvlRank, i.e. the number of
+  // inputs of every new indexing map.
+  iterTypes.append(lvlRank - dimRank, utils::IteratorType::parallel);
+  for (utils::IteratorType it : linalgOp.getIteratorTypesArray())
+    iterTypes.push_back(it);
+  assert(iterTypes.size() == lvlRank && "loop count must equal level rank");
+  // Generate the new linalg generic operation and clone body.
+  auto newOp = rewriter.create<linalg::GenericOp>(
+      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
+      iterTypes);
+  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
+                             newOp.getRegion().begin());
+  return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules for linalg generic ops.
+//===----------------------------------------------------------------------===//
+
+/// Sparse rewriting rule for the generic `linalg` operation.
+struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+  GenericOpReinterpretMap(MLIRContext *context)
+      : OpRewritePattern<linalg::GenericOp>(context) {}
+
+  /// Rewrites a single-output linalg.generic on tensors whose output is a
+  /// non-permutation sparse tensor (with an identity indexing map) into an
+  /// equivalent generic that iterates over the output's level space; the
+  /// output is demapped before, and remapped after, the new operation.
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    // Guard: exactly one init, and tensor (not buffer) semantics.
+    const bool hasSingleTensorInit =
+        linalgOp.getNumDpsInits() == 1 && linalgOp.hasTensorSemantics();
+    if (!hasSingleTensorInit)
+      return failure();
+    // Scan all operands, inspect sparse tensors.
+    //
+    // TODO: generalize this proof-of-concept algorithm, since the current
+    //       implementation accepts only simple indexing maps, and one
+    //       non-permutation sparse tensor, which must have an identify
+    //       indexing map and be the output.
+    //
+    OpOperand *tx = nullptr;
+    for (OpOperand &operand : linalgOp->getOpOperands()) {
+      // Every indexing map result must be a plain loop index ("simple").
+      const AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+      for (AffineExpr expr : map.getResults()) {
+        if (expr.getKind() != AffineExprKind::DimId)
+          return failure();
+      }
+      // Dense operands and sparse permutations need no special treatment.
+      auto operandType = getSparseTensorType(operand.get());
+      if (!operandType.hasEncoding() || operandType.isPermutation())
+        continue;
+      // The only non-permutation expected here adds levels (e.g. blocking).
+      assert(operandType.getDimRank() < operandType.getLvlRank());
+      if (tx != nullptr)
+        return failure(); // at most one non-permutation is supported
+      if (!map.isIdentity())
+        return failure(); // the non-permutation needs an identity index map
+      tx = &operand;
+    }
+    // Only a non-permutation that is the output operand is rewritten.
+    if (tx == nullptr || tx != linalgOp.getDpsInitOperand(0))
+      return failure();
+    auto outType = getSparseTensorType(tx->get());
+    SparseTensorEncodingAttr outEnc = outType.getEncoding();
+    Value demapped = genDemap(rewriter, outEnc, tx->get());
+    linalg::GenericOp lvlOp =
+        genGenericLinalg(rewriter, linalgOp, outType, demapped);
+    Value remapped = genRemap(rewriter, outEnc, lvlOp.getResult(0));
+    rewriter.replaceOp(linalgOp, remapped);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules for operations other than linalg generic ops.
+//===----------------------------------------------------------------------===//
 
 // CRTP to help implementing a rewriter that demaps all its inputs and remaps
 // all its outputs.
@@ -59,10 +186,6 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
   }
 };
 
-//===----------------------------------------------------------------------===//
-// Reinterpret Map Rewriters for operations other than linalg.generics
-//===----------------------------------------------------------------------===//
-
 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(CrdTranslateOp op,
@@ -110,6 +233,10 @@ struct TensorInsertRewriter
 
 void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                         ReinterpretMapScope scope) {
+  if (scope == ReinterpretMapScope::kAll ||
+      scope == ReinterpretMapScope::kGenericOnly) {
+    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+  }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
     patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 8517f2a27ae3fc8..e4f591f38cdbed7 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map | FileCheck %s
+// RUN: mlir-opt %s -split-input-file  --sparse-reinterpret-map | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -8,3 +8,50 @@
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
+
+// -----
+
+// Indexing maps for an elementwise product in which the second input is
+// read transposed; used by @mul below.
+#trait_mul = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (in)
+    affine_map<(i,j) -> (j,i)>,  // B (in, transposed)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) += A(j,i)"
+}
+
+#BSR = #sparse_tensor.encoding<{   // 2x4 blocks
+  // Levels 0-1 (i floordiv 2, j floordiv 4) select a block; levels 2-3
+  // (i mod 2, j mod 4) address an element inside that block.
+  map = (i, j) ->
+    ( i floordiv 2 : dense
+    , j floordiv 4 : compressed
+    , i mod 2 : dense
+    , j mod 4 : dense
+    )
+}>
+
+// The non-permutation BSR output forces the generic op into the 4-D level
+// space: both inputs are indexed through the lvl-to-dim map of #BSR, and the
+// output is demapped before and remapped after the new generic op.
+// CHECK: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
+// CHECK: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
+// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @mul(
+// CHECK-SAME:  %[[A0:.*0]]: tensor<32x32xf32>,
+// CHECK-SAME:  %[[A1:.*1]]: tensor<32x32xf32>,
+// CHECK-SAME:  %[[A2:.*2]]: tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK:       %[[T0:.*]] = sparse_tensor.reinterpret_map %[[A2]]
+// CHECK:       %[[T1:.*]] = linalg.generic {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK:       %[[T2:.*]] = sparse_tensor.reinterpret_map %[[T1]]
+// CHECK:       return %[[T2]] : tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @mul(%arg0: tensor<32x32xf32>,
+               %arg1: tensor<32x32xf32>,
+               %arg2: tensor<32x32xf32, #BSR>) -> tensor<32x32xf32, #BSR> {
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<32x32xf32>, tensor<32x32xf32>)
+    outs(%arg2: tensor<32x32xf32, #BSR>) {
+      ^bb(%x: f32, %y : f32, %z : f32):
+        %1 = arith.mulf %x, %y : f32
+        %2 = arith.mulf %1, %z : f32
+        linalg.yield %2 : f32
+  } -> tensor<32x32xf32, #BSR>
+  return %0 : tensor<32x32xf32, #BSR>
+}
+

>From ebe56ca08c2b8610050486114b9c92031101e1f6 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 14:56:13 -0700
Subject: [PATCH 2/4] typo

---
 .../Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp  | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 14aaa39f3183e47..c04eab18573e442 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -24,7 +24,7 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-// Translates a "simple" map according to an identify lvl-map.
+// Translates a "simple" map according to an identity lvl-map.
 static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
                               AffineMap map) {
   unsigned lvlRank = stt.getLvlRank();
@@ -110,7 +110,7 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
     //
     // TODO: generalize this proof-of-concept algorithm, since the current
     //       implementation accepts only simple indexing maps, and one
-    //       non-permutation sparse tensor, which must have an identify
+    //       non-permutation sparse tensor, which must have an identity
     //       indexing map and be the output.
     //
     OpOperand *tx = nullptr;

>From 6f71e0770777edc1347d3ecd3f581b9d52dcede6 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 15:08:17 -0700
Subject: [PATCH 3/4] use util

---
 .../SparseTensor/Transforms/SparseReinterpretMap.cpp     | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index c04eab18573e442..9018e56a34e92a2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -41,13 +41,9 @@ static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
 // Generates a "de"mapping reinterpretation of the map.
 static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                       Value val) {
-  unsigned lvlRank = enc.getLvlTypes().size();
-  AffineMap idMap =
-      AffineMap::getMultiDimIdentityMap(lvlRank, builder.getContext());
-  auto newEnc = SparseTensorEncodingAttr::get(
-      builder.getContext(), enc.getLvlTypes(), idMap, idMap, enc.getPosWidth(),
-      enc.getCrdWidth());
-  return builder.create<ReinterpretMapOp>(val.getLoc(), newEnc, val);
+  // Reinterpret `val` under `enc` with its dim-to-lvl mapping removed.
+  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
+                                          val);
 }
 
 // Generates a "re"mapping reinterpretation of the map.

>From 1a6ba6cc23c60f67237082c79697d52385e4b3b6 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 15:15:50 -0700
Subject: [PATCH 4/4] update DOC string

---
 mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index e4f591f38cdbed7..149c0bc46e25118 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -18,7 +18,7 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
     affine_map<(i,j) -> (i,j)>   // X (out)
   ],
   iterator_types = ["parallel", "parallel"],
-  doc = "X(i,j) += A(j,i)"
+  doc = "X(i,j) *= A(i,j) * B(j,i)"
 }
 
 #BSR = #sparse_tensor.encoding<{   // 2x4 blocks



More information about the Mlir-commits mailing list