[Mlir-commits] [mlir] [mlir][sparse] code cleanup using the assumption that dim2lvl maps are simplified. (PR #72894)

Peiming Liu llvmlistbot at llvm.org
Mon Nov 20 09:46:19 PST 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/72894

>From 7a8cdc5a3e16f653c7b7a89063be5e6e34618c62 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 20 Nov 2023 17:33:34 +0000
Subject: [PATCH 1/2] [mlir][sparse] code cleanup using the assumption that
 dim2lvl maps are simplified.

---
 .../Transforms/SparseTensorCodegen.cpp        | 35 +++++++++----------
 mlir/test/Dialect/SparseTensor/codegen.mlir   |  2 +-
 2 files changed, 17 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 9f41db73a5091120..1549ada203b4ae32 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -745,8 +745,13 @@ class SparseTensorAllocConverter
     const auto resType = getSparseTensorType(op);
     if (!resType.hasEncoding())
       return failure();
-    Location loc = op.getLoc();
 
+    if (!resType.isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          op, "try run --sparse-reinterpret-map before codegen");
+    }
+
+    Location loc = op.getLoc();
     // Deal with copy.
     if (op.getCopy()) {
       auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
@@ -768,16 +773,10 @@ class SparseTensorAllocConverter
       return success();
     }
 
-    // Construct the dim/lvl sizes and the (unused) dim2lvl/lvl2dim buffers.
-    SmallVector<Value> dimSizesValues;
+    // Level size equals to dimension size since lvl2dim map is an identity map.
     SmallVector<Value> lvlSizesValues;
-    Value dimSizesBuffer;
-    Value dim2lvlBuffer;
-    Value lvl2dimBuffer;
     createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
-                   dimSizesValues);
-    genMapBuffers(rewriter, loc, resType, dimSizesValues, dimSizesBuffer,
-                  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
+                   lvlSizesValues);
 
     // Construct allocation for each field.
     Value sizeHint = op.getSizeHint();
@@ -809,19 +808,17 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
     const auto resType = getSparseTensorType(op);
     if (!resType.hasEncoding())
       return failure();
-    Location loc = op.getLoc();
 
-    // Construct the dim/lvl sizes and the (unused) dim2lvl/lvl2dim buffers.
-    SmallVector<Value> dimSizesValues;
+    if (!resType.isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          op, "try run --sparse-reinterpret-map before codegen");
+    }
+
+    Location loc = op.getLoc();
+    // Level size equals to dimension size since lvl2dim map is an identity map.
     SmallVector<Value> lvlSizesValues;
-    Value dimSizesBuffer;
-    Value dim2lvlBuffer;
-    Value lvl2dimBuffer;
     createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
-                   dimSizesValues);
-    genMapBuffers(rewriter, loc, resType, dimSizesValues, dimSizesBuffer,
-                  lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
-
+                   lvlSizesValues);
     // Construct allocation for each field.
     Value sizeHint; // none
     SmallVector<Value> fields;
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index e63595bed53e5998..a3b26972d66ff5d8 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-codegen  --canonicalize -cse | FileCheck %s
+// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-reinterpret-map --sparse-tensor-codegen  --canonicalize -cse | FileCheck %s
 
 #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
 

>From a1b60bcadf49c03f78f05a3184643c836b240b86 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 20 Nov 2023 17:45:58 +0000
Subject: [PATCH 2/2] address comments

---
 .../SparseTensor/Transforms/SparseTensorCodegen.cpp | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 1549ada203b4ae32..1c0366c6476a3856 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -746,11 +746,6 @@ class SparseTensorAllocConverter
     if (!resType.hasEncoding())
       return failure();
 
-    if (!resType.isIdentity()) {
-      return rewriter.notifyMatchFailure(
-          op, "try run --sparse-reinterpret-map before codegen");
-    }
-
     Location loc = op.getLoc();
     // Deal with copy.
     if (op.getCopy()) {
@@ -773,10 +768,14 @@ class SparseTensorAllocConverter
       return success();
     }
 
+    if (!resType.isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          op, "try run --sparse-reinterpret-map before codegen");
+    }
     // Level size equals to dimension size since lvl2dim map is an identity map.
     SmallVector<Value> lvlSizesValues;
     createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
-                   lvlSizesValues);
+                   /*dimSizesValues=*/lvlSizesValues);
 
     // Construct allocation for each field.
     Value sizeHint = op.getSizeHint();
@@ -818,7 +817,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
     // Level size equals to dimension size since lvl2dim map is an identity map.
     SmallVector<Value> lvlSizesValues;
     createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
-                   lvlSizesValues);
+                   /*dimSizesValues=*/lvlSizesValues);
     // Construct allocation for each field.
     Value sizeHint; // none
     SmallVector<Value> fields;



More information about the Mlir-commits mailing list