[Mlir-commits] [mlir] d5f0d0c - [mlir][sparse] add ability to select pointer/index storage type
Aart Bik
llvmlistbot at llvm.org
Wed Nov 25 17:33:05 PST 2020
Author: Aart Bik
Date: 2020-11-25T17:32:44-08:00
New Revision: d5f0d0c0c4117295d9e76bbafaf0597e01ef3c99
URL: https://github.com/llvm/llvm-project/commit/d5f0d0c0c4117295d9e76bbafaf0597e01ef3c99
DIFF: https://github.com/llvm/llvm-project/commit/d5f0d0c0c4117295d9e76bbafaf0597e01ef3c99.diff
LOG: [mlir][sparse] add ability to select pointer/index storage type
This change gives sparse compiler clients more control over selecting
individual types for the pointers and indices in the sparse storage schemes.
A narrower width results in a smaller memory footprint, but its range must
still suffice for the maximum number of stored entries (pointers) and the
maximum index value (indices).
Reviewed By: penpornk
Differential Revision: https://reviews.llvm.org/D92126
Added:
mlir/test/Dialect/Linalg/sparse_storage.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/lib/Transforms/TestSparsification.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7a96a5ed1390..b37a14f0eb7a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -821,18 +821,31 @@ enum class SparseVectorizationStrategy {
kAnyStorageInnerLoop
};
+/// Selects the integer element type of the "pointer" and "index" arrays in
+/// the sparse storage scheme: the native platform-dependent index width
+/// (kNative), 64-bit integers (kI64), or 32-bit integers (kI32).
+/// A narrower width reduces the memory footprint of the sparse storage
+/// scheme, but it must still cover the required range: the maximum number
+/// of stored entries per indirection level for the "pointers", and the
+/// maximum value of each tensor index over all dimensions for the "indices".
+/// With kI64 or kI32 the sparsification rewriting inserts an index_cast to
+/// `index` after every pointer/index load.
+enum class SparseIntType { kNative, kI64, kI32 };
+
/// Sparsification options.
struct SparsificationOptions {
+ /// Both storage-type arguments default to SparseIntType::kNative, so
+ /// existing three-argument callers keep compiling unchanged.
SparsificationOptions(SparseParallelizationStrategy p,
- SparseVectorizationStrategy v, unsigned vl)
- : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) {
- }
+ SparseVectorizationStrategy v, unsigned vl,
+ SparseIntType pt = SparseIntType::kNative,
+ SparseIntType it = SparseIntType::kNative)
+ : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
+ ptrType(pt), indType(it) {}
SparsificationOptions()
: SparsificationOptions(SparseParallelizationStrategy::kNone,
- SparseVectorizationStrategy::kNone, 1u) {}
+ SparseVectorizationStrategy::kNone, 1u,
+ SparseIntType::kNative, SparseIntType::kNative) {}
SparseParallelizationStrategy parallelizationStrategy;
SparseVectorizationStrategy vectorizationStrategy;
unsigned vectorLength;
+ /// Integer element type of the sparse "pointers" arrays.
+ SparseIntType ptrType;
+ /// Integer element type of the sparse "indices" arrays.
+ SparseIntType indType;
};
/// Set up sparsification rewriting rules with the given options.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 729268393ed9..07a3e1569622 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -420,16 +420,27 @@ static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
}
}
+/// Maps the sparse integer option to the actual integral storage type:
+/// kNative -> index, kI64 -> i64, kI32 -> i32.
+static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
+ switch (tp) {
+ case linalg::SparseIntType::kI64:
+ return rewriter.getIntegerType(64);
+ case linalg::SparseIntType::kI32:
+ return rewriter.getIntegerType(32);
+ case linalg::SparseIntType::kNative:
+ break;
+ }
+ // Returning after the switch guarantees every control path yields a value
+ // (no fall-off-the-end UB, no GCC -Wreturn-type warning). Listing every
+ // enumerator without a `default` keeps -Wswitch able to flag new ones.
+ return rewriter.getIndexType();
+}
+
/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
-void genBuffers(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
- linalg::GenericOp op) {
+static void genBuffers(Merger &merger, CodeGen &codegen,
+ PatternRewriter &rewriter, linalg::GenericOp op) {
Location loc = op.getLoc();
unsigned numTensors = op.getNumInputsAndOutputs();
unsigned numInputs = op.getNumInputs();
assert(numTensors == numInputs + 1);
- Type indexType = rewriter.getIndexType();
// For now, set all unknown dimensions to 999.
// TODO: compute these values (using sparsity or by reading tensor)
@@ -450,9 +461,13 @@ void genBuffers(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
// Handle sparse storage schemes.
if (merger.isSparseAccess(t, i)) {
allDense = false;
- auto dynTp = MemRefType::get({ShapedType::kDynamicSize}, indexType);
- codegen.pointers[t][i] = rewriter.create<AllocaOp>(loc, dynTp, unknown);
- codegen.indices[t][i] = rewriter.create<AllocaOp>(loc, dynTp, unknown);
+ auto dynShape = {ShapedType::kDynamicSize};
+ auto ptrTp = MemRefType::get(
+ dynShape, genIntType(rewriter, codegen.options.ptrType));
+ auto indTp = MemRefType::get(
+ dynShape, genIntType(rewriter, codegen.options.indType));
+ codegen.pointers[t][i] = rewriter.create<AllocaOp>(loc, ptrTp, unknown);
+ codegen.indices[t][i] = rewriter.create<AllocaOp>(loc, indTp, unknown);
}
// Find lower and upper bound in current dimension.
Value up;
@@ -516,6 +531,15 @@ static void genTensorStore(Merger &merger, CodeGen &codegen,
rewriter.create<StoreOp>(op.getLoc(), rhs, codegen.buffers[tensor], args);
}
+/// Loads element `s` of the pointer/index array `ptr` and yields it as an
+/// `index` value: a load that already produces `index` is returned as-is,
+/// any other element type is converted with an index_cast.
+static Value genIntLoad(PatternRewriter &rewriter, Location loc, Value ptr,
+ Value s) {
+ Value elem = rewriter.create<LoadOp>(loc, ptr, s);
+ if (elem.getType().isa<IndexType>())
+ return elem;
+ return rewriter.create<IndexCastOp>(loc, elem, rewriter.getIndexType());
+}
+
/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
linalg::GenericOp op, unsigned exp) {
@@ -551,7 +575,6 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
unsigned idx = topSort[at];
// Initialize sparse positions.
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
for (unsigned b = 0, be = inits.size(); b < be; b++) {
if (inits[b]) {
unsigned tensor = merger.tensor(b);
@@ -564,11 +587,12 @@ static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
break;
}
Value ptr = codegen.pointers[tensor][idx];
- Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
- : codegen.pidxs[tensor][topSort[pat - 1]];
- codegen.pidxs[tensor][idx] = rewriter.create<LoadOp>(loc, ptr, p);
- p = rewriter.create<AddIOp>(loc, p, one);
- codegen.highs[tensor][idx] = rewriter.create<LoadOp>(loc, ptr, p);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
+ : codegen.pidxs[tensor][topSort[pat - 1]];
+ codegen.pidxs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p0);
+ Value p1 = rewriter.create<AddIOp>(loc, p0, one);
+ codegen.highs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p1);
} else {
// Dense index still in play.
needsUniv = true;
@@ -723,15 +747,17 @@ static void genLocals(Merger &merger, CodeGen &codegen,
if (locals[b] && merger.isSparseBit(b)) {
unsigned tensor = merger.tensor(b);
assert(idx == merger.index(b));
- Value ld = rewriter.create<LoadOp>(loc, codegen.indices[tensor][idx],
- codegen.pidxs[tensor][idx]);
- codegen.idxs[tensor][idx] = ld;
+ Value ptr = codegen.indices[tensor][idx];
+ Value s = codegen.pidxs[tensor][idx];
+ Value load = genIntLoad(rewriter, loc, ptr, s);
+ codegen.idxs[tensor][idx] = load;
if (!needsUniv) {
if (min) {
- Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, ld, min);
- min = rewriter.create<SelectOp>(loc, cmp, ld, min);
+ Value cmp =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
+ min = rewriter.create<SelectOp>(loc, cmp, load, min);
} else {
- min = ld;
+ min = load;
}
}
}
diff --git a/mlir/test/Dialect/Linalg/sparse_storage.mlir b/mlir/test/Dialect/Linalg/sparse_storage.mlir
new file mode 100644
index 000000000000..c63bdb1e413d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/sparse_storage.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s -test-sparsification="ptr-type=1 ind-type=1" | \
+// RUN: FileCheck %s --check-prefix=CHECK-TYPE0
+// RUN: mlir-opt %s -test-sparsification="ptr-type=1 ind-type=2" | \
+// RUN: FileCheck %s --check-prefix=CHECK-TYPE1
+// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=1" | \
+// RUN: FileCheck %s --check-prefix=CHECK-TYPE2
+// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=2" | \
+// RUN: FileCheck %s --check-prefix=CHECK-TYPE3
+
+// Element-wise product x(i) = a(i) * b(i) over a 1-D iteration space, where
+// operand `a` is annotated sparse ("S") and `b` and the output `x` are dense.
+#trait_mul_1d = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a
+ affine_map<(i) -> (i)>, // b
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ sparse = [
+ [ "S" ], // a
+ [ "D" ], // b
+ [ "D" ] // x
+ ],
+ iterator_types = ["parallel"],
+ doc = "x(i) = a(i) * b(i)"
+}
+
+// CHECK-TYPE0-LABEL: func @mul_dd(
+// CHECK-TYPE0: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE0: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE0: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi64>
+// CHECK-TYPE0: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
+// CHECK-TYPE0: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi64>
+// CHECK-TYPE0: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
+// CHECK-TYPE0: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE0: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi64>
+// CHECK-TYPE0: %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
+// CHECK-TYPE0: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE0: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE0: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE0: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE0: }
+
+// CHECK-TYPE1-LABEL: func @mul_dd(
+// CHECK-TYPE1: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE1: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE1: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi64>
+// CHECK-TYPE1: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
+// CHECK-TYPE1: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi64>
+// CHECK-TYPE1: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
+// CHECK-TYPE1: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE1: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi32>
+// CHECK-TYPE1: %[[INDC:.*]] = index_cast %[[IND0]] : i32 to index
+// CHECK-TYPE1: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE1: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE1: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE1: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE1: }
+
+// CHECK-TYPE2-LABEL: func @mul_dd(
+// CHECK-TYPE2: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE2: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE2: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi32>
+// CHECK-TYPE2: %[[B0:.*]] = index_cast %[[P0]] : i32 to index
+// CHECK-TYPE2: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi32>
+// CHECK-TYPE2: %[[B1:.*]] = index_cast %[[P1]] : i32 to index
+// CHECK-TYPE2: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE2: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi64>
+// CHECK-TYPE2: %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
+// CHECK-TYPE2: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE2: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE2: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE2: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE2: }
+
+// CHECK-TYPE3-LABEL: func @mul_dd(
+// CHECK-TYPE3: %[[C0:.*]] = constant 0 : index
+// CHECK-TYPE3: %[[C1:.*]] = constant 1 : index
+// CHECK-TYPE3: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi32>
+// CHECK-TYPE3: %[[B0:.*]] = index_cast %[[P0]] : i32 to index
+// CHECK-TYPE3: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi32>
+// CHECK-TYPE3: %[[B1:.*]] = index_cast %[[P1]] : i32 to index
+// CHECK-TYPE3: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
+// CHECK-TYPE3: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi32>
+// CHECK-TYPE3: %[[INDC:.*]] = index_cast %[[IND0]] : i32 to index
+// CHECK-TYPE3: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
+// CHECK-TYPE3: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE3: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
+// CHECK-TYPE3: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
+// CHECK-TYPE3: }
+
+// Multiplies `a` and `b` element-wise via #trait_mul_1d. Because `a` is
+// annotated sparse, sparsification emits the pointer/index loads (and
+// index_casts) verified above for each ptr-type/ind-type combination.
+// NOTE(review): the `_dd` suffix suggests dense*dense although `a` is sparse;
+// consider renaming together with the FileCheck label directives above.
+func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>) -> tensor<32xf64> {
+ %0 = linalg.generic #trait_mul_1d
+ ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>) {
+ ^bb(%a: f64, %b: f64):
+ %0 = mulf %a, %b : f64
+ linalg.yield %0 : f64
+ } -> tensor<32xf64>
+ return %0 : tensor<32xf64>
+}
+
diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
index 7544e48174b8..a960ca75b9f3 100644
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -31,6 +31,14 @@ struct TestSparsification
Option<int32_t> vectorLength{
*this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
+ /// Storage type of the sparse "pointers" arrays, decoded by typeOption():
+ /// 1 = i64, 2 = i32, any other value (default 0) = native index type.
+ Option<int32_t> ptrType{*this, "ptr-type",
+ llvm::cl::desc("Set the pointer type"),
+ llvm::cl::init(0)};
+
+ /// Storage type of the sparse "indices" arrays; same encoding as
+ /// "ptr-type".
+ Option<int32_t> indType{*this, "ind-type",
+ llvm::cl::desc("Set the index type"),
+ llvm::cl::init(0)};
+
/// Registers all dialects required by testing.
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<scf::SCFDialect, vector::VectorDialect>();
@@ -64,13 +72,26 @@ struct TestSparsification
}
}
+ /// Translates the integer value of the "ptr-type"/"ind-type" options into
+ /// a sparse storage integer type: 1 selects i64, 2 selects i32, and any
+ /// other value (including the default 0) selects the native index type.
+ linalg::SparseIntType typeOption(int32_t option) {
+ if (option == 1)
+ return linalg::SparseIntType::kI64;
+ if (option == 2)
+ return linalg::SparseIntType::kI32;
+ return linalg::SparseIntType::kNative;
+ }
+
/// Runs the test on a function.
void runOnFunction() override {
auto *ctx = &getContext();
OwningRewritePatternList patterns;
// Translate strategy flags to strategy options.
linalg::SparsificationOptions options(parallelOption(), vectorOption(),
- vectorLength);
+ vectorLength, typeOption(ptrType),
+ typeOption(indType));
// Apply rewriting.
linalg::populateSparsificationPatterns(ctx, patterns, options);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
More information about the Mlir-commits
mailing list