[Mlir-commits] [mlir] 86e9bc1 - [mlir][sparse] add option for 32-bit indices in scatter/gather

Aart Bik llvmlistbot at llvm.org
Fri Jun 4 16:57:31 PDT 2021


Author: Aart Bik
Date: 2021-06-04T16:57:12-07:00
New Revision: 86e9bc1a34a0eafcce52c0dfda0817b1465a0dc2

URL: https://github.com/llvm/llvm-project/commit/86e9bc1a34a0eafcce52c0dfda0817b1465a0dc2
DIFF: https://github.com/llvm/llvm-project/commit/86e9bc1a34a0eafcce52c0dfda0817b1465a0dc2.diff

LOG: [mlir][sparse] add option for 32-bit indices in scatter/gather

Controlled by a compiler option: when 32-bit indices can be handled
with zero- and sign-extension alike (i.e., all indices are known to be
non-negative), scatter/gather operations can use the more efficient
32-bit SIMD version.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D103632
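
For reference, a minimal sketch of enabling the new behavior programmatically
(illustrative only; the helper function name, the chosen strategies, and the
vector length are assumptions, and the namespace qualification follows
SparseTensorPasses.cpp):

    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/IR/PatternMatch.h"

    // Illustrative helper (not part of this patch): populate sparsification
    // patterns with 32-bit SIMD indexing enabled, mirroring the constructor
    // added in Passes.h and its use in SparseTensorPasses.cpp.
    static void addIndex32SparsificationPatterns(mlir::RewritePatternSet &patterns) {
      using namespace mlir;
      using namespace mlir::sparse_tensor;
      SparsificationOptions options(SparseParallelizationStrategy::kNone,
                                    SparseVectorizationStrategy::kAnyStorageInnerLoop,
                                    /*vectorLength=*/16,
                                    /*enableSIMDIndex32=*/true);
      populateSparsificationPatterns(patterns, options);
    }

From the pass pipeline, the same behavior is exercised by the new test RUN line
added below: -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true".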

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
    mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/SparseTensor/sparse_vector.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index ec5dcfea8afe3..fdafe1b274ffd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -48,15 +48,16 @@ enum class SparseVectorizationStrategy {
 /// Sparsification options.
 struct SparsificationOptions {
   SparsificationOptions(SparseParallelizationStrategy p,
-                        SparseVectorizationStrategy v, unsigned vl)
-      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) {
-  }
+                        SparseVectorizationStrategy v, unsigned vl, bool e)
+      : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
+        enableSIMDIndex32(e) {}
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
-                              SparseVectorizationStrategy::kNone, 1u) {}
+                              SparseVectorizationStrategy::kNone, 1u, false) {}
   SparseParallelizationStrategy parallelizationStrategy;
   SparseVectorizationStrategy vectorizationStrategy;
   unsigned vectorLength;
+  bool enableSIMDIndex32;
 };
 
 /// Sets up sparsification rewriting rules with the given options.

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 67960d286fc9d..290e83df674bc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -21,6 +21,16 @@ def Sparsification : Pass<"sparsification", "ModuleOp"> {
     "sparse_tensor::SparseTensorDialect",
     "vector::VectorDialect",
   ];
+  let options = [
+    Option<"parallelization", "parallelization-strategy", "int32_t", "0",
+           "Set the parallelization strategy">,
+    Option<"vectorization", "vectorization-strategy", "int32_t", "0",
+           "Set the vectorization strategy">,
+    Option<"vectorLength", "vl", "int32_t", "1",
+           "Set the vector length">,
+    Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
+           "Enable i32 indexing into vectors (for efficiency)">
+  ];
 }
 
 def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index dee741e851d28..a779c6ef2aae5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -35,17 +35,6 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
   SparsificationPass(const SparsificationPass &pass)
       : SparsificationBase<SparsificationPass>() {}
 
-  Option<int32_t> parallelization{
-      *this, "parallelization-strategy",
-      llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
-
-  Option<int32_t> vectorization{
-      *this, "vectorization-strategy",
-      llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
-
-  Option<int32_t> vectorLength{
-      *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
-
   /// Returns parallelization strategy given on command line.
   SparseParallelizationStrategy parallelOption() {
     switch (parallelization) {
@@ -79,7 +68,7 @@ struct SparsificationPass : public SparsificationBase<SparsificationPass> {
     RewritePatternSet patterns(ctx);
     // Translate strategy flags to strategy options.
     SparsificationOptions options(parallelOption(), vectorOption(),
-                                  vectorLength);
+                                  vectorLength, enableSIMDIndex32);
     // Apply rewriting.
     populateSparsificationPatterns(patterns, options);
     vector::populateVectorToVectorCanonicalizationPatterns(patterns);

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index b4bc58f5ced36..c99aafb29c3ee 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -768,9 +768,9 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
     // zero extend the vector to an index width. For 8-bit and 16-bit values,
     // a 32-bit index width suffices. For 32-bit values, zero extending the
     // elements into 64-bit loses some performance since the 32-bit indexed
-    // gather/scatter is more efficient than the 64-bit index variant (in
-    // the future, we could introduce a flag that states the negative space
-    // of 32-bit indices is unused). For 64-bit values, there is no good way
+    // gather/scatter is more efficient than the 64-bit index variant (if the
+    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
+    // preserve this performance). For 64-bit values, there is no good way
     // to state that the indices are unsigned, which creates the potential of
     // incorrect address calculations in the unlikely case we need such
     // extremely large offsets.
@@ -780,7 +780,8 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
       if (etp.getIntOrFloatBitWidth() < 32)
         vload = rewriter.create<ZeroExtendIOp>(
             loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
-      else if (etp.getIntOrFloatBitWidth() < 64)
+      else if (etp.getIntOrFloatBitWidth() < 64 &&
+               !codegen.options.enableSIMDIndex32)
         vload = rewriter.create<ZeroExtendIOp>(
             loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
     }

diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 29dd1f9bb9eae..9674123d13bd2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -4,6 +4,8 @@
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC1
 // RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC2
+// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC3
 
 #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
 
@@ -148,6 +150,27 @@ func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @mul_s
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC3:       %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK-VEC3:       %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC3:       %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK-VEC3:       %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK-VEC3:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC3:         %[[sub:.*]] = subi %{{.*}}, %[[i]] : index
+// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3:       }
+// CHECK-VEC3:       return
+//
 func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_mul_s
     ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
@@ -310,6 +333,31 @@ func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @mul_ds
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG:   %[[c512:.*]] = constant 512 : index
+// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC3:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC3:         %[[a:.*]] = zexti %[[p]] : i32 to i64
+// CHECK-VEC3:         %[[q:.*]] = index_cast %[[a]] : i64 to index
+// CHECK-VEC3:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
+// CHECK-VEC3:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
+// CHECK-VEC3:         %[[b:.*]] = zexti %[[r]] : i32 to i64
+// CHECK-VEC3:         %[[s:.*]] = index_cast %[[b]] : i64 to index
+// CHECK-VEC3:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC3:           %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3:         }
+// CHECK-VEC3:       }
+// CHECK-VEC3:       return
+//
 func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
   %0 = linalg.generic #trait_mul_ds
     ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)


        

