[Mlir-commits] [mlir] [mlir][sparse] first proof-of-concept non-permutation rewriter (PR #70863)
Aart Bik
llvmlistbot at llvm.org
Tue Oct 31 14:52:42 PDT 2023
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/70863
Rather than extending sparsifier codegen with higher order non-permutations, we follow the path of rewriting linalg generic ops into higher order operations. That way, code generation will simply work out of the box. This is a very first proof-of-concept rewriting of that idea.
>From 591c6f2feb7ce3fbb79dc0ef2d048a5712d1d7ed Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 31 Oct 2023 14:25:15 -0700
Subject: [PATCH] [mlir][sparse] first proof-of-concept non-permutation
rewriter
Rather than extending sparsifier codegen with higher order
non-permutations, we follow the path of rewriting linalg
generic ops into higher order operations. That way,
code generation will simply work out of the box. This is a very
first proof-of-concept rewriting of that idea.
---
.../Transforms/SparseReinterpretMap.cpp | 143 +++++++++++++++++-
.../SparseTensor/sparse_reinterpret_map.mlir | 49 +++++-
2 files changed, 183 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 5880f2158b8cd05..14aaa39f3183e47 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -18,10 +20,135 @@ using namespace mlir::sparse_tensor;
namespace {
-// TODO:
-// (1) insert the zero-cost sparse_tensor.reinterpret_map ops
-// (2) rewrite linalg.generic ops traits on level crds
-// (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
+// Translates a "simple" indexing map (every result is a plain dimension)
+// into level space: each dimension that `map` refers to is replaced by the
+// matching result expression of the lvl-to-dim map of `stt`. The returned
+// map has one input per level and no symbols.
+static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
+                              AffineMap map) {
+  unsigned numLvls = stt.getLvlRank();
+  AffineMap lvlToDim = stt.getLvlToDim();
+  assert(lvlToDim.getNumInputs() == numLvls);
+  SmallVector<AffineExpr> lvlExprs;
+  for (AffineExpr result : map.getResults()) {
+    // Every result is expected to be a bare dimension here.
+    unsigned dim = result.cast<AffineDimExpr>().getPosition();
+    lvlExprs.push_back(lvlToDim.getResult(dim));
+  }
+  return AffineMap::get(numLvls, 0, lvlExprs, builder.getContext());
+}
+
+// Generates a "de"mapping reinterpretation of the given value. The new
+// encoding keeps the level types and the position/coordinate widths of
+// `enc`, but passes the same level-rank identity map for both of its maps.
+static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                      Value val) {
+  MLIRContext *ctx = builder.getContext();
+  unsigned numLvls = enc.getLvlTypes().size();
+  AffineMap lvlIdentity = AffineMap::getMultiDimIdentityMap(numLvls, ctx);
+  auto demappedEnc = SparseTensorEncodingAttr::get(
+      ctx, enc.getLvlTypes(), lvlIdentity, lvlIdentity, enc.getPosWidth(),
+      enc.getCrdWidth());
+  return builder.create<ReinterpretMapOp>(val.getLoc(), demappedEnc, val);
+}
+
+// Generates a "re"mapping reinterpretation of the given value, i.e. a
+// reinterpret_map operation on `val` built with the encoding `enc`.
+static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                      Value val) {
+  Location loc = val.getLoc();
+  return builder.create<ReinterpretMapOp>(loc, enc, val);
+}
+
+// Generates a clone of the given linalg generic operation, but with
+// remapped arguments, index maps, and iteration types:
+//   * the inputs are reused as is, and `out` becomes the single output;
+//   * every input indexing map is translated through the lvl-to-dim map of
+//     `stt`, while the output indexing map becomes the lvl-identity;
+//   * (lvlRank - dimRank) additional "parallel" iterators are prepended to
+//     the original iterator types.
+// The body region of the original operation is cloned into the new one.
+//
+// TODO: As described below, this is proof-of-concept code which makes a lot
+//       of simplifying assumptions for now.
+//
+static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
+                                          linalg::GenericOp linalgOp,
+                                          SparseTensorType stt, Value out) {
+  const unsigned dimRank = stt.getDimRank();
+  const unsigned lvlRank = stt.getLvlRank();
+  // Guards the unsigned subtraction below against wrap-around.
+  assert(dimRank <= lvlRank && "expected at least as many lvls as dims");
+  SmallVector<Value> inputOps = linalgOp.getInputs();
+  SmallVector<Value> outputOps = {out};
+  SmallVector<AffineMap> indexMaps;
+  SmallVector<utils::IteratorType> iterTypes;
+  // Translate the index maps, except output map, which is lvl-identity.
+  auto maps = linalgOp.getIndexingMapsArray();
+  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
+    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
+  indexMaps.push_back(
+      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
+  // Add additional "parallel" iteration types at the top: one extra loop
+  // for every level by which the level rank exceeds the dimension rank.
+  for (unsigned i = 0, diff = lvlRank - dimRank; i < diff; i++)
+    iterTypes.push_back(utils::IteratorType::parallel);
+  for (auto &i : linalgOp.getIteratorTypesArray())
+    iterTypes.push_back(i);
+  // Generate the new linalg generic operation and clone body.
+  auto newOp = rewriter.create<linalg::GenericOp>(
+      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
+      iterTypes);
+  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
+                             newOp.getRegion().begin());
+  return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules for linalg generic ops.
+//===----------------------------------------------------------------------===//
+
+/// Sparse rewriting rule for the generic `linalg` operation.
+struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+public:
+  GenericOpReinterpretMap(MLIRContext *context)
+      : OpRewritePattern<linalg::GenericOp>(context) {}
+
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the op has exactly one output and pure tensor semantics.
+    if (!(linalgOp.getNumDpsInits() == 1 && linalgOp.hasTensorSemantics()))
+      return failure();
+    // Walk over all operands and look for a non-permutation sparse tensor.
+    //
+    // TODO: generalize this proof-of-concept algorithm, since the current
+    //       implementation accepts only simple indexing maps, and one
+    //       non-permutation sparse tensor, which must have an identity
+    //       indexing map and be the output.
+    //
+    OpOperand *nonPerm = nullptr;
+    for (OpOperand &operand : linalgOp->getOpOperands()) {
+      // Every index map must be "simple": all results are plain dimensions.
+      const AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
+      for (AffineExpr result : map.getResults())
+        if (result.getKind() != AffineExprKind::DimId)
+          return failure();
+      // Dense operands and permutation-encoded operands need no rewriting.
+      auto stt = getSparseTensorType(operand.get());
+      if (!stt.hasEncoding() || stt.isPermutation())
+        continue;
+      assert(stt.getDimRank() < stt.getLvlRank()); // only allowed non-perm
+      // Reject a second non-permutation, or one without ID indexing map.
+      if (nonPerm || !map.isIdentity())
+        return failure();
+      nonPerm = &operand;
+    }
+    // Only rewrite when a non-permutation was found and it is the output.
+    if (!nonPerm || nonPerm != linalgOp.getDpsInitOperand(0))
+      return failure();
+    auto stt = getSparseTensorType(nonPerm->get());
+    Value demapped = genDemap(rewriter, stt.getEncoding(), nonPerm->get());
+    auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demapped);
+    Value remapped = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
+    rewriter.replaceOp(linalgOp, remapped);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Rewriting rules for operations other than linalg generic ops.
+//===----------------------------------------------------------------------===//
// CRTP to help implementing a rewriter that demaps all its inputs and remaps
// all its outputs.
@@ -59,10 +186,6 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
}
};
-//===----------------------------------------------------------------------===//
-// Reinterpret Map Rewriters for operations other than linalg.generics
-//===----------------------------------------------------------------------===//
-
struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CrdTranslateOp op,
@@ -110,6 +233,10 @@ struct TensorInsertRewriter
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
ReinterpretMapScope scope) {
+ if (scope == ReinterpretMapScope::kAll ||
+ scope == ReinterpretMapScope::kGenericOnly) {
+ patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+ }
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kExceptGeneric) {
patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 8517f2a27ae3fc8..e4f591f38cdbed7 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map | FileCheck %s
+// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
@@ -8,3 +8,50 @@
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
return %arg0 : tensor<?xf64, #SparseVector>
}
+
+// -----
+
+#trait_mul = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A (in)
+ affine_map<(i,j) -> (j,i)>, // B (in, transposed)
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) += A(j,i)" // NOTE(review): the region in @mul computes X(i,j) = A(i,j) * B(j,i) * X(i,j)
+}
+
+#BSR = #sparse_tensor.encoding<{ // 2x4 blocks
+ map = (i, j) ->
+ ( i floordiv 2 : dense // index of the block along i
+ , j floordiv 4 : compressed // index of the block along j
+ , i mod 2 : dense // offset along i within a block
+ , j mod 4 : dense // offset along j within a block
+ )
+}>
+
+// CHECK: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
+// CHECK: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
+// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @mul(
+// CHECK-SAME: %[[A0:.*0]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[A1:.*1]]: tensor<32x32xf32>,
+// CHECK-SAME: %[[A2:.*2]]: tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: %[[T0:.*]] = sparse_tensor.reinterpret_map %[[A2]]
+// CHECK: %[[T1:.*]] = linalg.generic {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK: %[[T2:.*]] = sparse_tensor.reinterpret_map %[[T1]]
+// CHECK: return %[[T2]] : tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @mul(%arg0: tensor<32x32xf32>,
+ %arg1: tensor<32x32xf32>,
+ %arg2: tensor<32x32xf32, #BSR>) -> tensor<32x32xf32, #BSR> {
+ %0 = linalg.generic #trait_mul
+ ins(%arg0, %arg1: tensor<32x32xf32>, tensor<32x32xf32>)
+ outs(%arg2: tensor<32x32xf32, #BSR>) {
+ ^bb(%x: f32, %y : f32, %z : f32): // elements of the two inputs and the output
+ %1 = arith.mulf %x, %y : f32 // product of the two input elements
+ %2 = arith.mulf %1, %z : f32 // scaled by the current output element
+ linalg.yield %2 : f32
+ } -> tensor<32x32xf32, #BSR>
+ return %0 : tensor<32x32xf32, #BSR>
+}
+
More information about the Mlir-commits
mailing list