[Mlir-commits] [mlir] [mlir][sparse] Implement rewriters to reinterpret maps on alloc_tensor operation (PR #70993)

Peiming Liu llvmlistbot at llvm.org
Wed Nov 1 14:49:10 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/70993


>From cf7569dd10dbfc497ae254ee678326428d898688 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 1 Nov 2023 21:48:33 +0000
Subject: [PATCH] [mlir][sparse] Implement rewriters to reinterpret maps on
 alloc_tensor operation

---
 .../Transforms/SparseReinterpretMap.cpp       |  55 ++++++++-
 .../SparseTensor/sparse_reinterpret_map.mlir  |   5 +-
 .../CPU/sparse_conversion_block.mlir          | 105 ++++++++++++++++++
 3 files changed, 161 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index d14df6db8ee6b3f..307a609fd1b7746 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -6,7 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "CodegenUtils.h"
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -188,6 +191,56 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   }
 };
 
+struct TensorAllocDemapper
+    : public OpRewritePattern<bufferization::AllocTensorOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!hasNonIdentityOperandsOrResults(op))
+      return failure();
+
+    Location loc = op.getLoc();
+    auto stt = getSparseTensorType(op.getResult());
+
+    SmallVector<Value> maxDimCrds;
+    maxDimCrds.reserve(stt.getDimRank());
+    ValueRange dynSz = op.getDynamicSizes();
+    for (int64_t dimSz : stt.getDimShape()) {
+      if (ShapedType::isDynamic(dimSz)) {
+        Value maxCrd = rewriter.create<arith::SubIOp>(
+            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
+        maxDimCrds.push_back(maxCrd);
+        dynSz = dynSz.drop_front();
+      } else {
+        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
+      }
+    }
+
+    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
+                                              CrdTransDirectionKind::dim2lvl);
+    auto lvlShape = stt.getLvlShape();
+    SmallVector<Value> dynLvlSzs;
+    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
+      if (ShapedType::isDynamic(lvlShape[i])) {
+        Value sz = rewriter.create<arith::AddIOp>(
+            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
+        dynLvlSzs.push_back(sz);
+      }
+    }
+
+    assert(dynSz.empty()); // should have consumed all.
+    rewriter.startRootUpdate(op);
+    op->setOperands(dynLvlSzs);
+    op.getResult().setType(stt.getDemappedType());
+    rewriter.finalizeRootUpdate(op);
+    rewriter.setInsertionPointAfter(op);
+
+    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
+    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
+    return success();
+  }
+};
+
 struct TensorInsertDemapper
     : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
   using DemapInsRewriter::DemapInsRewriter;
@@ -309,7 +362,7 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
-    patterns.add<TensorInsertDemapper, ForeachOpDemapper>(
+    patterns.add<TensorAllocDemapper, TensorInsertDemapper, ForeachOpDemapper>(
         patterns.getContext());
   }
 }
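For orientation, here is a rough sketch of the IR shape this new pattern produces for an allocation with dynamic outer dimensions under a 2x2 block (BSR-style) map. The #BSR_demapped attribute name and the %ml0/%ml1 values are illustrative only; the actual max-coordinate translation is emitted by stt.translateCrds and is elided here. As the test below shows, follow-up demapping patterns fold the trailing reinterpret_map away when the consumer is demapped as well:

  // Before: allocation sized in dimension space.
  %t = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #BSR>

  // After: dynamic level sizes are derived from the maximum dimension
  // coordinates (size - 1), translated dim->lvl, plus one.
  %c1 = arith.constant 1 : index
  %m0 = arith.subi %d0, %c1 : index    // max coordinate along dim 0
  %m1 = arith.subi %d1, %c1 : index    // max coordinate along dim 1
  // ... (%ml0, %ml1) = dim2lvl translation of (%m0, %m1) ...
  %l0 = arith.addi %ml0, %c1 : index   // dynamic size of level 0
  %l1 = arith.addi %ml1, %c1 : index   // dynamic size of level 1
  %a  = bufferization.alloc_tensor(%l0, %l1)
          : tensor<?x?x2x2xf64, #BSR_demapped>
  %t2 = sparse_tensor.reinterpret_map %a
          : tensor<?x?x2x2xf64, #BSR_demapped> to tensor<?x?xf64, #BSR>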
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index be3ab37e9cbd182..972364289ac2e2a 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -57,10 +57,9 @@ func.func @mul(%arg0: tensor<32x32xf32>,
 
 // CHECK-LABEL:   func.func @sparse_foreach_reinterpret_map(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x4xf64
-// CHECK:           %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<2x4xf64
+// CHECK:           %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<1x2x2x2xf64
 // CHECK:           %[[VAL_2:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64
-// CHECK:           %[[VAL_3:.*]] = sparse_tensor.reinterpret_map %[[VAL_1]] : tensor<2x4xf64
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_3]])
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_1]])
 // CHECK:           ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: f64, %[[VAL_10:.*]]: tensor<1x2x2x2xf64
 // CHECK:             %[[VAL_11:.*]] = sparse_tensor.insert %[[VAL_9]] into %[[VAL_10]]{{\[}}%[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]]] : tensor<1x2x2x2xf64
 // CHECK:             sparse_tensor.yield %[[VAL_11]] : tensor<1x2x2x2xf64
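The updated type in the first CHECK line follows directly from the encoding's dim-to-lvl map. Assuming the test's encoding uses the same 2x2 block map as the #BSR example in the new integration test below, the 2x4 dimension shape demaps to a 1x2x2x2 level shape:

  map = (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2)

  lvl0 = 2 floordiv 2 = 1   (block rows)
  lvl1 = 4 floordiv 2 = 2   (block columns)
  lvl2 = 2                  (i mod 2, block height)
  lvl3 = 2                  (j mod 2, block width)
  =>   tensor<1x2x2x2xf64>

This is also why the second reinterpret_map disappears: the alloc_tensor result is already in level space, so the foreach can consume it directly.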
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
new file mode 100644
index 000000000000000..34a11e748ebd68b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
@@ -0,0 +1,105 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users
+// that do not use these LIT config files, which is why this is kept inline.
+//
+// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
+// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
+// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+
+#CSR  = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+
+#CSC  = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d1 : dense, d0 : compressed)
+}>
+
+#BSR = #sparse_tensor.encoding<{
+   map = ( i, j ) ->
+      ( i floordiv 2 : dense,
+        j floordiv 2 : compressed,
+        i mod 2      : dense,
+        j mod 2      : dense
+      )
+}>
+
+
+//
+// Integration test for conversions between differently encoded sparse tensors.
+//
+module {
+  //
+  // Output utilities.
+  //
+  func.func @dumpf64(%arg0: memref<?xf64>) {
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = vector.transfer_read %arg0[%c0], %d0: memref<?xf64>, vector<8xf64>
+    vector.print %0 : vector<8xf64>
+    return
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+
+    //
+    // Initialize a 2-dim dense tensor.
+    //
+    %t = arith.constant dense<[
+       [  1.0,  2.0,  3.0,  4.0 ],
+       [  5.0,  6.0,  7.0,  8.0 ]
+    ]> : tensor<2x4xf64>
+
+    //
+    // Convert dense tensor directly to various sparse tensors.
+    //    %1: CSR,  values in row-major order
+    //    %2: BSR,  values grouped by 2x2 block (level shape 1x2x2x2)
+    //    %3: CSC,  values in column-major order
+    //
+    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
+    %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
+    %3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
+
+    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
+    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
+
+    // CHECK:      ( 1, 2, 3, 4, 5, 6, 7, 8 )
+    // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+    // CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
+    call @dumpf64(%v1) : (memref<?xf64>) -> ()
+    call @dumpf64(%v2) : (memref<?xf64>) -> ()
+    call @dumpf64(%v3) : (memref<?xf64>) -> ()
+
+    return
+  }
+}
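The three CHECK vectors can be derived by hand from the storage order of each encoding applied to the 2x4 input:

  input:   [ 1 2 3 4 ]
           [ 5 6 7 8 ]

  CSR (row-major values):     1, 2, 3, 4, 5, 6, 7, 8
  BSR (2x2 blocks; row-major across blocks, then within each block):
      block(0,0) = [1 2; 5 6], block(0,1) = [3 4; 7 8]
      =>                      1, 2, 5, 6, 3, 4, 7, 8
  CSC (column-major values):  1, 5, 2, 6, 3, 7, 4, 8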


