[llvm-branch-commits] [mlir] b5c542d - [mlir][sparse] add narrower choices for pointers/indices

Aart Bik via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 19 20:24:56 PST 2021


Author: Aart Bik
Date: 2021-01-19T20:20:38-08:00
New Revision: b5c542d64b98b5a74d35dedad41051a0b00d7946

URL: https://github.com/llvm/llvm-project/commit/b5c542d64b98b5a74d35dedad41051a0b00d7946
DIFF: https://github.com/llvm/llvm-project/commit/b5c542d64b98b5a74d35dedad41051a0b00d7946.diff

LOG: [mlir][sparse] add narrower choices for pointers/indices

Use cases with 16- or even 8-bit pointer/index structures have been identified.

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D95015

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/test/Dialect/Linalg/sparse_storage.mlir
    mlir/test/lib/Transforms/TestSparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0effa2f45c20..611ab6867372 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -853,13 +853,13 @@ enum class SparseVectorizationStrategy {
 };
 
 /// Defines a type for "pointer" and "index" storage in the sparse storage
-/// scheme, with a choice between the native platform-dependent index width,
-/// 64-bit integers, or 32-bit integers. A narrow width obviously reduces
+/// scheme, with a choice between the native platform-dependent index width
+/// or any of 64-/32-/16-/8-bit integers. A narrow width obviously reduces
 /// the memory footprint of the sparse storage scheme, but the width should
 /// suffice to define the total required range (viz. the maximum number of
 /// stored entries per indirection level for the "pointers" and the maximum
 /// value of each tensor index over all dimensions for the "indices").
-enum class SparseIntType { kNative, kI64, kI32 };
+enum class SparseIntType { kNative, kI64, kI32, kI16, kI8 };
 
 /// Sparsification options.
 struct SparsificationOptions {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 898b15266072..cefcdcbed9ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -512,6 +512,10 @@ static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
     return rewriter.getIntegerType(64);
   case linalg::SparseIntType::kI32:
     return rewriter.getIntegerType(32);
+  case linalg::SparseIntType::kI16:
+    return rewriter.getIntegerType(16);
+  case linalg::SparseIntType::kI8:
+    return rewriter.getIntegerType(8);
   }
   llvm_unreachable("unexpected SparseIntType");
 }

diff  --git a/mlir/test/Dialect/Linalg/sparse_storage.mlir b/mlir/test/Dialect/Linalg/sparse_storage.mlir
index 69b8e1903d69..ef5dc0d766e3 100644
--- a/mlir/test/Dialect/Linalg/sparse_storage.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_storage.mlir
@@ -6,6 +6,10 @@
 // RUN:   FileCheck %s --check-prefix=CHECK-TYPE2
 // RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=2" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-TYPE3
+// RUN: mlir-opt %s -test-sparsification="ptr-type=3 ind-type=3" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE4
+// RUN: mlir-opt %s -test-sparsification="ptr-type=4 ind-type=4" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-TYPE5
 
 #trait_mul_1d = {
   indexing_maps = [
@@ -86,6 +90,38 @@
 // CHECK-TYPE3:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
 // CHECK-TYPE3: }
 
+// CHECK-TYPE4-LABEL: func @mul_dd(
+// CHECK-TYPE4: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE4: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE4: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi16>
+// CHECK-TYPE4: %[[B0:.*]] = index_cast %[[P0]] : i16 to index
+// CHECK-TYPE4: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi16>
+// CHECK-TYPE4: %[[B1:.*]] = index_cast %[[P1]] : i16 to index
+// CHECK-TYPE4: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE4:   %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi16>
+// CHECK-TYPE4:   %[[INDC:.*]] = index_cast %[[IND0]] : i16 to index
+// CHECK-TYPE4:   %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE4:   %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE4:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE4:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE4: }
+
+// CHECK-TYPE5-LABEL: func @mul_dd(
+// CHECK-TYPE5: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE5: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE5: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi8>
+// CHECK-TYPE5: %[[B0:.*]] = index_cast %[[P0]] : i8 to index
+// CHECK-TYPE5: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi8>
+// CHECK-TYPE5: %[[B1:.*]] = index_cast %[[P1]] : i8 to index
+// CHECK-TYPE5: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE5:   %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi8>
+// CHECK-TYPE5:   %[[INDC:.*]] = index_cast %[[IND0]] : i8 to index
+// CHECK-TYPE5:   %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE5:   %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE5:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE5:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE5: }
+
 func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>) -> tensor<32xf64> {
   %0 = linalg.generic #trait_mul_1d
      ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>)

diff  --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
index 441c101d5cde..f5c0ae67836c 100644
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -82,6 +82,10 @@ struct TestSparsification
       return linalg::SparseIntType::kI64;
     case 2:
       return linalg::SparseIntType::kI32;
+    case 3:
+      return linalg::SparseIntType::kI16;
+    case 4:
+      return linalg::SparseIntType::kI8;
     }
   }
 


        


More information about the llvm-branch-commits mailing list