[Mlir-commits] [mlir] [mlir][sparse] Implement rewriters to reinterpret maps on alloc_tenso… (PR #70993)
Peiming Liu
llvmlistbot at llvm.org
Wed Nov 1 14:49:10 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/70993
…r operation
>From cf7569dd10dbfc497ae254ee678326428d898688 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 1 Nov 2023 21:48:33 +0000
Subject: [PATCH] [mlir][sparse] Implement rewriters to reinterpret maps on
alloc_tensor operation
---
.../Transforms/SparseReinterpretMap.cpp | 55 ++++++++-
.../SparseTensor/sparse_reinterpret_map.mlir | 5 +-
.../CPU/sparse_conversion_block.mlir | 105 ++++++++++++++++++
3 files changed, 161 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index d14df6db8ee6b3f..307a609fd1b7746 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -6,7 +6,10 @@
//
//===----------------------------------------------------------------------===//
+#include "CodegenUtils.h"
+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -188,6 +191,56 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
}
};
+struct TensorAllocDemapper // Retypes an AllocTensorOp to stt.getDemappedType() and routes its users through genRemap.
+ : public OpRewritePattern<bufferization::AllocTensorOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
+ PatternRewriter &rewriter) const override {
+ if (!hasNonIdentityOperandsOrResults(op)) // helper defined outside this hunk
+ return failure();
+ // From here on the op is kept and updated in place; only its operands and result type change.
+ Location loc = op.getLoc();
+ auto stt = getSparseTensorType(op.getResult());
+ // Largest coordinate per dimension: size - 1, taking dynamic sizes from the op's operands in order.
+ SmallVector<Value> maxDimCrds;
+ maxDimCrds.reserve(stt.getDimRank());
+ ValueRange dynSz = op.getDynamicSizes();
+ for (int64_t dimSz : stt.getDimShape()) {
+ if (ShapedType::isDynamic(dimSz)) {
+ Value maxCrd = rewriter.create<arith::SubIOp>(
+ loc, dynSz.front(), constantIndex(rewriter, loc, 1));
+ maxDimCrds.push_back(maxCrd);
+ dynSz = dynSz.drop_front(); // advance to the next dynamic size operand
+ } else {
+ maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
+ }
+ }
+ // Translate the max dimension coordinates to level coordinates; each dynamic level size is max + 1.
+ ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
+ CrdTransDirectionKind::dim2lvl);
+ auto lvlShape = stt.getLvlShape();
+ SmallVector<Value> dynLvlSzs;
+ for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
+ if (ShapedType::isDynamic(lvlShape[i])) {
+ Value sz = rewriter.create<arith::AddIOp>(
+ loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
+ dynLvlSzs.push_back(sz);
+ }
+ }
+ // Swap in the level sizes and the demapped result type as a single in-place root update.
+ assert(dynSz.empty()); // every dynamic size operand must have been consumed by the dimension loop.
+ rewriter.startRootUpdate(op);
+ op->setOperands(dynLvlSzs);
+ op.getResult().setType(stt.getDemappedType());
+ rewriter.finalizeRootUpdate(op);
+ rewriter.setInsertionPointAfter(op);
+ // Remap the retyped result (genRemap is defined outside this hunk) and redirect every other user to it.
+ Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
+ rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp()); // the remap op itself keeps using the raw result
+ return success();
+ }
+};
+
struct TensorInsertDemapper
: public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
using DemapInsRewriter::DemapInsRewriter;
@@ -309,7 +362,7 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
}
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kExceptGeneric) {
- patterns.add<TensorInsertDemapper, ForeachOpDemapper>(
+ patterns.add<TensorAllocDemapper, TensorInsertDemapper, ForeachOpDemapper>(
patterns.getContext());
}
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index be3ab37e9cbd182..972364289ac2e2a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -57,10 +57,9 @@ func.func @mul(%arg0: tensor<32x32xf32>,
// CHECK-LABEL: func.func @sparse_foreach_reinterpret_map(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64
-// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<2x4xf64
+// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<1x2x2x2xf64
// CHECK: %[[VAL_2:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64
-// CHECK: %[[VAL_3:.*]] = sparse_tensor.reinterpret_map %[[VAL_1]] : tensor<2x4xf64
-// CHECK: %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_3]])
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_1]])
// CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: f64, %[[VAL_10:.*]]: tensor<1x2x2x2xf64
// CHECK: %[[VAL_11:.*]] = sparse_tensor.insert %[[VAL_9]] into %[[VAL_10]]{{\[}}%[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]]] : tensor<1x2x2x2xf64
// CHECK: sparse_tensor.yield %[[VAL_11]] : tensor<1x2x2x2xf64
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
new file mode 100644
index 000000000000000..34a11e748ebd68b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
@@ -0,0 +1,105 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
+// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
+// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+
+#CSR = #sparse_tensor.encoding<{ // levels follow dimension order: d0 dense, d1 compressed
+ map = (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+ // Same level types as #CSR, but with the two dimensions swapped (d1 is the first level).
+#CSC = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d1 : dense, d0 : compressed)
+}>
+ // Four levels over 2x2 blocks: block row, block column (compressed), then row and column inside the block.
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 2 : compressed,
+ i mod 2 : dense,
+ j mod 2 : dense
+ )
+}>
+
+
+//
+// Integration test that tests conversions between sparse tensors.
+//
+module {
+ //
+ // Output utility: prints the first 8 values of a 1-D f64 memref.
+ //
+ func.func @dumpf64(%arg0: memref<?xf64>) {
+ %c0 = arith.constant 0 : index
+ %d0 = arith.constant -1.0 : f64
+ %0 = vector.transfer_read %arg0[%c0], %d0: memref<?xf64>, vector<8xf64>
+ vector.print %0 : vector<8xf64>
+ return
+ }
+
+ //
+ // Main driver.
+ //
+ func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+
+ //
+ // Initialize a 2-dim (2x4) dense tensor.
+ //
+ %t = arith.constant dense<[
+ [ 1.0, 2.0, 3.0, 4.0 ],
+ [ 5.0, 6.0, 7.0, 8.0 ]
+ ]> : tensor<2x4xf64>
+
+ //
+ // Convert the dense tensor to CSR, then CSR to BSR, then BSR to CSC.
+ // tensor1 (%1): stored as #CSR
+ // tensor2 (%2): stored as #BSR (2x2 blocks)
+ // tensor3 (%3): stored as #CSC
+ //
+ %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
+ %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
+ %3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
+
+ %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
+ %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
+ %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
+
+ // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 )
+ // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+ // CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
+ call @dumpf64(%v1) : (memref<?xf64>) -> ()
+ call @dumpf64(%v2) : (memref<?xf64>) -> ()
+ call @dumpf64(%v3) : (memref<?xf64>) -> ()
+
+ return
+ }
+}
More information about the Mlir-commits
mailing list