[Mlir-commits] [mlir] e599978 - [mlir][sparse] first proof-of-concept non-permutation rewriter (#70863)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 31 16:19:31 PDT 2023


Author: Aart Bik
Date: 2023-10-31T16:19:27-07:00
New Revision: e599978760e88231b834772eca080c42b71716c3

URL: https://github.com/llvm/llvm-project/commit/e599978760e88231b834772eca080c42b71716c3
DIFF: https://github.com/llvm/llvm-project/commit/e599978760e88231b834772eca080c42b71716c3.diff

LOG:  [mlir][sparse] first proof-of-concept non-permutation rewriter (#70863)

Rather than extending sparsifier codegen with higher order
non-permutations, we follow the path of rewriting linalg generic ops
into higher order operations. That way, code generation will simply work
out of the box. This is a very first proof-of-concept rewriting of that
idea.

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
    mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 5880f2158b8cd05..31cc8525725d43d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -18,10 +20,130 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-// TODO:
-//   (1) insert the zero-cost sparse_tensor.reinterpret_map ops
-//   (2) rewrite linalg.generic ops traits on level crds
-//   (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+// Translates a "simple" map according to an identity lvl-map.
+static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
+                              AffineMap map) {
+  unsigned lvlRank = stt.getLvlRank();
+  AffineMap lvl2dim = stt.getLvlToDim();
+  assert(lvl2dim.getNumInputs() == lvlRank);
+  SmallVector<AffineExpr> exps;
+  for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
+    unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
+    exps.push_back(lvl2dim.getResult(pos));
+  }
+  return AffineMap::get(lvlRank, 0, exps, builder.getContext());
+}
+
+// Generates a "de"mapping reinterpretation of the map.
+static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                      Value val) {
+  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
+                                          val);
+}
+
+// Generates a "re"mapping reinterpretation of the map.
+static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                      Value val) {
+  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
+}
+
+// Generates a clone of the given linalg generic operation, but with
+// remapped arguments, index maps, and iteration types. The value `out`
+// becomes the sole output operand and provides the result type.
+//
+// TODO: As described below, this is proof-of-concept code which makes a lot
+//       of simplifying assumptions for now.
+static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
+                                          linalg::GenericOp linalgOp,
+                                          SparseTensorType stt, Value out) {
+  unsigned dimRank = stt.getDimRank();
+  unsigned lvlRank = stt.getLvlRank();
+  SmallVector<Value> inputOps = linalgOp.getInputs();
+  SmallVector<Value> outputOps = {out};
+  SmallVector<AffineMap> indexMaps;
+  SmallVector<utils::IteratorType> iterTypes;
+  // Translate the index maps, except output map, which is lvl-identity.
+  auto maps = linalgOp.getIndexingMapsArray();
+  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
+    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
+  indexMaps.push_back(
+      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
+  // Prepend "parallel" iterators for the extra levels (dimRank < lvlRank).
+  for (unsigned i = 0, diff = lvlRank - dimRank; i < diff; i++)
+    iterTypes.push_back(utils::IteratorType::parallel);
+  for (auto &i : linalgOp.getIteratorTypesArray())
+    iterTypes.push_back(i);
+  // Generate the new linalg generic operation and clone body.
+  auto newOp = rewriter.create<linalg::GenericOp>(
+      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
+      iterTypes);
+  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
+                             newOp.getRegion().begin());
+  return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules for linalg generic ops.
+//===----------------------------------------------------------------------===//
+
+/// Sparse rewriting rule for the generic `linalg` operation.
+struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+public:
+  GenericOpReinterpretMap(MLIRContext *context)
+      : OpRewritePattern<linalg::GenericOp>(context) {}
+
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    // Only rewrite single output operations with pure tensor semantics.
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
+      return failure();
+    // Scan all operands, inspect sparse tensors.
+    //
+    // TODO: generalize this proof-of-concept algorithm, since the current
+    //       implementation accepts only simple indexing maps, and one
+    //       non-permutation sparse tensor, which must have an identity
+    //       indexing map and be the output.
+    //
+    OpOperand *tx = nullptr;
+    for (OpOperand &t : linalgOp->getOpOperands()) {
+      // Ensure every index map is "simple".
+      const auto map = linalgOp.getMatchingIndexingMap(&t);
+      for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
+        if (map.getResult(i).getKind() != AffineExprKind::DimId)
+          return failure();
+      // Inspect sparse operands.
+      auto stt = getSparseTensorType(t.get());
+      if (stt.hasEncoding()) {
+        if (stt.isPermutation())
+          continue;
+        assert(stt.getDimRank() < stt.getLvlRank()); // only allowed non-perm
+        if (tx)
+          return failure(); // more than one non-perm
+        if (!map.isIdentity())
+          return failure(); // no ID indexing map on the non-perm
+        tx = &t;
+      }
+    }
+    // Found a non-permutation, rewrite when this is the output.
+    if (tx && tx == linalgOp.getDpsInitOperand(0)) {
+      auto stt = getSparseTensorType(tx->get());
+      auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
+      auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
+      auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
+      rewriter.replaceOp(linalgOp, remap);
+      return success();
+    }
+    return failure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules for operations other than linalg generic ops.
+//===----------------------------------------------------------------------===//
 
 // CRTP to help implementing a rewriter that demaps all its inputs and remaps
 // all its outputs.
@@ -59,10 +181,6 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
   }
 };
 
-//===----------------------------------------------------------------------===//
-// Reinterpret Map Rewriters for operations other than linalg.generics
-//===----------------------------------------------------------------------===//
-
 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(CrdTranslateOp op,
@@ -110,6 +228,10 @@ struct TensorInsertRewriter
 
 void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                         ReinterpretMapScope scope) {
+  if (scope == ReinterpretMapScope::kAll ||
+      scope == ReinterpretMapScope::kGenericOnly) {
+    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+  }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
     patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 8517f2a27ae3fc8..149c0bc46e25118 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map | FileCheck %s
+// RUN: mlir-opt %s -split-input-file  --sparse-reinterpret-map | FileCheck %s
 
 #SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 
@@ -8,3 +8,50 @@
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
+
+// -----
+
+#trait_mul = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (in)
+    affine_map<(i,j) -> (j,i)>,  // B (in, transposed)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) *= A(i,j) * B(j,i)"
+}
+
+#BSR = #sparse_tensor.encoding<{   // 2x4 blocks
+  map = (i, j) ->
+    ( i floordiv 2 : dense
+    , j floordiv 4 : compressed
+    , i mod 2 : dense
+    , j mod 4 : dense
+    )
+}>
+
+// CHECK: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
+// CHECK: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
+// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @mul(
+// CHECK-SAME:  %[[A0:.*0]]: tensor<32x32xf32>,
+// CHECK-SAME:  %[[A1:.*1]]: tensor<32x32xf32>,
+// CHECK-SAME:  %[[A2:.*2]]: tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK:       %[[T0:.*]] = sparse_tensor.reinterpret_map %[[A2]]
+// CHECK:       %[[T1:.*]] = linalg.generic {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK:       %[[T2:.*]] = sparse_tensor.reinterpret_map %[[T1]]
+// CHECK:       return %[[T2]] : tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @mul(%arg0: tensor<32x32xf32>,
+               %arg1: tensor<32x32xf32>,
+               %arg2: tensor<32x32xf32, #BSR>) -> tensor<32x32xf32, #BSR> {
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<32x32xf32>, tensor<32x32xf32>)
+    outs(%arg2: tensor<32x32xf32, #BSR>) {
+      ^bb(%x: f32, %y : f32, %z : f32):
+        %1 = arith.mulf %x, %y : f32
+        %2 = arith.mulf %1, %z : f32
+        linalg.yield %2 : f32
+  } -> tensor<32x32xf32, #BSR>
+  return %0 : tensor<32x32xf32, #BSR>
+}
+


        


More information about the Mlir-commits mailing list