[Mlir-commits] [mlir] 86e9bc1 - [mlir][sparse] add option for 32-bit indices in scatter/gather
Aart Bik
llvmlistbot at llvm.org
Fri Jun 4 16:57:31 PDT 2021
Author: Aart Bik
Date: 2021-06-04T16:57:12-07:00
New Revision: 86e9bc1a34a0eafcce52c0dfda0817b1465a0dc2
URL: https://github.com/llvm/llvm-project/commit/86e9bc1a34a0eafcce52c0dfda0817b1465a0dc2
DIFF: https://github.com/llvm/llvm-project/commit/86e9bc1a34a0eafcce52c0dfda0817b1465a0dc2.diff
LOG: [mlir][sparse] add option for 32-bit indices in scatter/gather
Controlled by a compiler option: if 32-bit indices can be handled
with zero- and sign-extension alike (i.e., all indices are known to be
non-negative), scatter/gather operations can use the more efficient
32-bit SIMD version.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D103632
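For example (mirroring the RUN lines added to sparse_vector.mlir below), the new
flag composes with the existing vectorization options on the mlir-opt pipeline;
input.mlir is just a placeholder here:

  mlir-opt input.mlir \
    -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true"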
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sparse_vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index ec5dcfea8afe3..fdafe1b274ffd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -48,15 +48,16 @@ enum class SparseVectorizationStrategy {
/// Sparsification options.
struct SparsificationOptions {
SparsificationOptions(SparseParallelizationStrategy p,
- SparseVectorizationStrategy v, unsigned vl)
- : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) {
- }
+ SparseVectorizationStrategy v, unsigned vl, bool e)
+ : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
+ enableSIMDIndex32(e) {}
SparsificationOptions()
: SparsificationOptions(SparseParallelizationStrategy::kNone,
- SparseVectorizationStrategy::kNone, 1u) {}
+ SparseVectorizationStrategy::kNone, 1u, false) {}
SparseParallelizationStrategy parallelizationStrategy;
SparseVectorizationStrategy vectorizationStrategy;
unsigned vectorLength;
+ bool enableSIMDIndex32;
};
/// Sets up sparsification rewriting rules with the given options.
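As a minimal sketch (not part of the commit), a client of this header could now
construct the extended options and register the rewriting rules along these lines;
the context ctx, the strategy choice, and the vector length are illustrative, and
kAnyStorageInnerLoop is assumed to be one of the SparseVectorizationStrategy values:

  // Sketch: enable 32-bit SIMD indexing when all indices are known non-negative.
  RewritePatternSet patterns(ctx);
  SparsificationOptions options(SparseParallelizationStrategy::kNone,
                                SparseVectorizationStrategy::kAnyStorageInnerLoop,
                                /*vl=*/16, /*e=*/true);  // e = enableSIMDIndex32
  populateSparsificationPatterns(patterns, options);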
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 67960d286fc9d..290e83df674bc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -21,6 +21,16 @@ def Sparsification : Pass<"sparsification", "ModuleOp"> {
"sparse_tensor::SparseTensorDialect",
"vector::VectorDialect",
];
+ let options = [
+ Option<"parallelization", "parallelization-strategy", "int32_t", "0",
+ "Set the parallelization strategy">,
+ Option<"vectorization", "vectorization-strategy", "int32_t", "0",
+ "Set the vectorization strategy">,
+ Option<"vectorLength", "vl", "int32_t", "1",
+ "Set the vector length">,
+ Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
+ "Enable i32 indexing into vectors (for efficiency)">
+ ];
}
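With these entries, the tablegen-generated SparsificationBase class now declares
the corresponding pass options, which is why the hand-written Option<> members are
deleted from SparseTensorPasses.cpp in the next hunk. The generated declarations
look roughly like the following sketch (not verbatim tblgen output):

  Option<int32_t> vectorLength{
      *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
  Option<bool> enableSIMDIndex32{
      *this, "enable-simd-index32",
      llvm::cl::desc("Enable i32 indexing into vectors (for efficiency)"),
      llvm::cl::init(false)};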
def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index dee741e851d28..a779c6ef2aae5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -35,17 +35,6 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
SparsificationPass(const SparsificationPass &pass)
: SparsificationBase<SparsificationPass>() {}
- Option<int32_t> parallelization{
- *this, "parallelization-strategy",
- llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
-
- Option<int32_t> vectorization{
- *this, "vectorization-strategy",
- llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
-
- Option<int32_t> vectorLength{
- *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
-
/// Returns parallelization strategy given on command line.
SparseParallelizationStrategy parallelOption() {
switch (parallelization) {
@@ -79,7 +68,7 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
RewritePatternSet patterns(ctx);
// Translate strategy flags to strategy options.
SparsificationOptions options(parallelOption(), vectorOption(),
- vectorLength);
+ vectorLength, enableSIMDIndex32);
// Apply rewriting.
populateSparsificationPatterns(patterns, options);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index b4bc58f5ced36..c99aafb29c3ee 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -768,9 +768,9 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
// zero extend the vector to an index width. For 8-bit and 16-bit values,
// a 32-bit index width suffices. For 32-bit values, zero extending the
// elements into 64-bit loses some performance since the 32-bit indexed
- // gather/scatter is more efficient than the 64-bit index variant (in
- // the future, we could introduce a flag that states the negative space
- // of 32-bit indices is unused). For 64-bit values, there is no good way
+ // gather/scatter is more efficient than the 64-bit index variant (if the
+ // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
+ preserve this performance). For 64-bit values, there is no good way
// to state that the indices are unsigned, which creates the potential of
// incorrect address calculations in the unlikely case we need such
// extremely large offsets.
@@ -780,7 +780,8 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
if (etp.getIntOrFloatBitWidth() < 32)
vload = rewriter.create<ZeroExtendIOp>(
loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
- else if (etp.getIntOrFloatBitWidth() < 64)
+ else if (etp.getIntOrFloatBitWidth() < 64 &&
+ !codegen.options.enableSIMDIndex32)
vload = rewriter.create<ZeroExtendIOp>(
loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
}
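Restated as a summary (not part of the commit), the index-width decision in
genLoad now becomes:

  // bitwidth(etp) < 32                        -> zero extend to a 32-bit index vector
  // bitwidth(etp) < 64 && !enableSIMDIndex32  -> zero extend to a 64-bit index vector
  // otherwise (i64, or i32 with the flag set) -> use the loaded index vector as is

so with enable-simd-index32, 32-bit stored indices feed the 32-bit gather/scatter
directly, as exercised by the new CHECK-VEC3 tests below.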
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 29dd1f9bb9eae..9674123d13bd2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -4,6 +4,8 @@
// RUN: FileCheck %s --check-prefix=CHECK-VEC1
// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" | \
// RUN: FileCheck %s --check-prefix=CHECK-VEC2
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" | \
+// RUN: FileCheck %s --check-prefix=CHECK-VEC3
#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
@@ -148,6 +150,27 @@ func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024
// CHECK-VEC2: }
// CHECK-VEC2: return
//
+// CHECK-VEC3-LABEL: func @mul_s
+// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC3: %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK-VEC3: %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC3: %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK-VEC3: %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK-VEC3: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC3: %[[sub:.*]] = subi %{{.*}}, %[[i]] : index
+// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3: }
+// CHECK-VEC3: return
+//
func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
%0 = linalg.generic #trait_mul_s
ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
@@ -310,6 +333,31 @@ func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>
// CHECK-VEC2: }
// CHECK-VEC2: return
//
+// CHECK-VEC3-LABEL: func @mul_ds
+// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC3: %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK-VEC3: %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK-VEC3: %[[a:.*]] = addi %[[i]], %[[c1]] : index
+// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
+// CHECK-VEC3: %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK-VEC3: %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK-VEC3: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC3: %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3: }
+// CHECK-VEC3: }
+// CHECK-VEC3: return
+//
func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
%0 = linalg.generic #trait_mul_ds
ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
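Note that in both new CHECK-VEC3 blocks the masked load of the i32 index vector
(vector<16xi32>) feeds vector.gather and vector.scatter directly; without
enable-simd-index32 those indices would first be zero-extended to vector<16xi64>,
as explained in the genLoad comment above.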