[llvm-branch-commits] [mlir] b5c542d - [mlir][sparse] add narrower choices for pointers/indices
Aart Bik via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 19 20:24:56 PST 2021
Author: Aart Bik
Date: 2021-01-19T20:20:38-08:00
New Revision: b5c542d64b98b5a74d35dedad41051a0b00d7946
URL: https://github.com/llvm/llvm-project/commit/b5c542d64b98b5a74d35dedad41051a0b00d7946
DIFF: https://github.com/llvm/llvm-project/commit/b5c542d64b98b5a74d35dedad41051a0b00d7946.diff
LOG: [mlir][sparse] add narrower choices for pointers/indices
Use cases that need only 16-bit or even 8-bit pointer/index structures have been identified, so kI16 and kI8 are added as SparseIntType choices.
Reviewed By: penpornk
Differential Revision: https://reviews.llvm.org/D95015
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/sparse_storage.mlir
mlir/test/lib/Transforms/TestSparsification.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0effa2f45c20..611ab6867372 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -853,13 +853,13 @@ enum class SparseVectorizationStrategy {
};
/// Defines a type for "pointer" and "index" storage in the sparse storage
-/// scheme, with a choice between the native platform-dependent index width,
-/// 64-bit integers, or 32-bit integers. A narrow width obviously reduces
+/// scheme, with a choice between the native platform-dependent index width
+/// or any of 64-/32-/16-/8-bit integers. A narrow width obviously reduces
/// the memory footprint of the sparse storage scheme, but the width should
/// suffice to define the total required range (viz. the maximum number of
/// stored entries per indirection level for the "pointers" and the maximum
/// value of each tensor index over all dimensions for the "indices").
-enum class SparseIntType { kNative, kI64, kI32 };
+enum class SparseIntType { kNative, kI64, kI32, kI16, kI8 };
/// Sparsification options.
struct SparsificationOptions {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 898b15266072..cefcdcbed9ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -512,6 +512,10 @@ static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
return rewriter.getIntegerType(64);
case linalg::SparseIntType::kI32:
return rewriter.getIntegerType(32);
+ case linalg::SparseIntType::kI16:
+ return rewriter.getIntegerType(16);
+ case linalg::SparseIntType::kI8:
+ return rewriter.getIntegerType(8);
}
llvm_unreachable("unexpected SparseIntType");
}
diff --git a/mlir/test/Dialect/Linalg/sparse_storage.mlir b/mlir/test/Dialect/Linalg/sparse_storage.mlir
index 69b8e1903d69..ef5dc0d766e3 100644
--- a/mlir/test/Dialect/Linalg/sparse_storage.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_storage.mlir
@@ -6,6 +6,10 @@
// RUN: FileCheck %s --check-prefix=CHECK-TYPE2
// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=2" | \
// RUN: FileCheck %s --check-prefix=CHECK-TYPE3
+// RUN: mlir-opt %s -test-sparsification="ptr-type=3 ind-type=3" | \
+// RUN: FileCheck %s --check-prefix=CHECK-TYPE4
+// RUN: mlir-opt %s -test-sparsification="ptr-type=4 ind-type=4" | \
+// RUN: FileCheck %s --check-prefix=CHECK-TYPE5
#trait_mul_1d = {
indexing_maps = [
@@ -86,6 +90,38 @@
// CHECK-TYPE3: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK-TYPE3: }
+// CHECK-TYPE4-LABEL: func @mul_dd(
+// CHECK-TYPE4: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE4: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE4: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi16>
+// CHECK-TYPE4: %[[B0:.*]] = index_cast %[[P0]] : i16 to index
+// CHECK-TYPE4: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi16>
+// CHECK-TYPE4: %[[B1:.*]] = index_cast %[[P1]] : i16 to index
+// CHECK-TYPE4: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE4: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi16>
+// CHECK-TYPE4: %[[INDC:.*]] = index_cast %[[IND0]] : i16 to index
+// CHECK-TYPE4: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE4: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE4: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE4: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE4: }
+
+// CHECK-TYPE5-LABEL: func @mul_dd(
+// CHECK-TYPE5: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE5: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE5: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi8>
+// CHECK-TYPE5: %[[B0:.*]] = index_cast %[[P0]] : i8 to index
+// CHECK-TYPE5: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi8>
+// CHECK-TYPE5: %[[B1:.*]] = index_cast %[[P1]] : i8 to index
+// CHECK-TYPE5: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE5: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi8>
+// CHECK-TYPE5: %[[INDC:.*]] = index_cast %[[IND0]] : i8 to index
+// CHECK-TYPE5: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE5: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE5: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE5: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE5: }
+
func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>) -> tensor<32xf64> {
%0 = linalg.generic #trait_mul_1d
ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>)
diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
index 441c101d5cde..f5c0ae67836c 100644
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -82,6 +82,10 @@ struct TestSparsification
return linalg::SparseIntType::kI64;
case 2:
return linalg::SparseIntType::kI32;
+ case 3:
+ return linalg::SparseIntType::kI16;
+ case 4:
+ return linalg::SparseIntType::kI8;
}
}
More information about the llvm-branch-commits mailing list