[Mlir-commits] [mlir] 92b0a9d - [mlir][sparse] remove restriction on vectorization of index type

Aart Bik llvmlistbot at llvm.org
Thu Apr 15 10:27:12 PDT 2021


Author: Aart Bik
Date: 2021-04-15T10:27:04-07:00
New Revision: 92b0a9d7d496bb382a24eb8340fec1c1cc7300e1

URL: https://github.com/llvm/llvm-project/commit/92b0a9d7d496bb382a24eb8340fec1c1cc7300e1
DIFF: https://github.com/llvm/llvm-project/commit/92b0a9d7d496bb382a24eb8340fec1c1cc7300e1.diff

LOG: [mlir][sparse] remove restriction on vectorization of index type

Rationale:
Now that vector<?xindex> is allowed, the restriction on vectorization
of index types in the sparse compiler can be removed. This also requires
generalizing the index vector types accepted by the scatter/gather ops.
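
For illustration, a gather whose index vector has index element type
now verifies; a minimal sketch modeled on the CHECK-VEC3 lines in the
test below (all SSA value names are placeholders):

    %0 = vector.gather %base[%c0] [%idx_vec], %mask, %pass_thru
        : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>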

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D100522

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/test/Dialect/Linalg/sparse_vector.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7b7d6accd6a07..4d45642b1d983 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1098,13 +1098,7 @@ struct SparsificationOptions {
                         SparseVectorizationStrategy v, unsigned vl,
                         SparseIntType pt, SparseIntType it, bool fo)
       : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
-        ptrType(pt), indType(it), fastOutput(fo) {
-    // TODO: remove restriction when vectors with index elements are supported
-    assert((v != SparseVectorizationStrategy::kAnyStorageInnerLoop ||
-            (ptrType != SparseIntType::kNative &&
-             indType != SparseIntType::kNative)) &&
-           "This combination requires support for vectors with index elements");
-  }
+        ptrType(pt), indType(it), fastOutput(fo) {}
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
                               SparseVectorizationStrategy::kNone, 1u,

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 0a1228599b601..b0ff956fe49de 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1684,7 +1684,7 @@ def Vector_GatherOp :
   Vector_Op<"gather">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-	       VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
+	       VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1749,7 +1749,7 @@ def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-	       VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
+	       VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$valueToStore)> {
 

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index aa162bf83e61e..99024c32104ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -771,12 +771,14 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
     // extremely large offsets.
     Type etp = ptr.getType().cast<MemRefType>().getElementType();
     Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
-    if (etp.getIntOrFloatBitWidth() < 32)
-      vload = rewriter.create<ZeroExtendIOp>(
-          loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
-    else if (etp.getIntOrFloatBitWidth() < 64)
-      vload = rewriter.create<ZeroExtendIOp>(
-          loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
+    if (!etp.isa<IndexType>()) {
+      if (etp.getIntOrFloatBitWidth() < 32)
+        vload = rewriter.create<ZeroExtendIOp>(
+            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
+      else if (etp.getIntOrFloatBitWidth() < 64)
+        vload = rewriter.create<ZeroExtendIOp>(
+            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
+    }
     return vload;
   }
   // For the scalar case, we simply zero extend narrower indices into 64-bit
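
Roughly, the effect of this guard: a vectorized load from an
index-typed memref keeps its native element type and receives no zero
extension, so the emitted IR looks like the sketch below (mirroring the
CHECK-VEC3 lines in the test; value names are placeholders):

    %li = vector.maskedload %ptr[%i], %mask, %pass
        : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>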

diff --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir
index 87f1a406179d7..b22a4ebd7df26 100644
--- a/mlir/test/Dialect/Linalg/sparse_vector.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir
@@ -4,6 +4,8 @@
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC1
 // RUN: mlir-opt %s -test-sparsification="vectorization-strategy=2 ptr-type=2 ind-type=2 vl=16" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-VEC2
+// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=2 ptr-type=0 ind-type=0 vl=16" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC3
 
 #trait_scale_d = {
   indexing_maps = [
@@ -54,6 +56,18 @@
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @scale_d
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG:   %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
+// CHECK-VEC3:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC3:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
+// CHECK-VEC3:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
+// CHECK-VEC3:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC3:       }
+// CHECK-VEC3:       return
+//
 func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_scale_d
     ins(%arga: tensor<1024xf32>)
@@ -143,6 +157,23 @@ func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) ->
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @mul_s
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xindex>
+// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xindex>
+// CHECK-VEC3:       scf.for %[[i:.*]] = %[[p]] to %[[r]] step %[[c16]] {
+// CHECK-VEC3:         %[[sub:.*]] = subi %[[r]], %[[i]] : index
+// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3:       }
+// CHECK-VEC3:       return
+//
 func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %0 = linalg.generic #trait_mul_s
     ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
@@ -177,6 +208,24 @@ func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<1024
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @mul_s_alt
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xindex>
+// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xindex>
+// CHECK-VEC3:       scf.for %[[i:.*]] = %[[p]] to %[[r]] step %[[c16]] {
+// CHECK-VEC3:         %[[sub:.*]] = subi %[[r]], %[[i]] : index
+// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3:       }
+// CHECK-VEC3:       return
+//
+//
 !SparseTensor = type !llvm.ptr<i8>
 func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
   %arga = linalg.sparse_tensor %argA : !SparseTensor to tensor<1024xf32>
@@ -250,6 +299,21 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
 // CHECK-VEC2:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @reduction_d
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG:   %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC3-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-VEC3:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
+// CHECK-VEC3:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC3:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
+// CHECK-VEC3:         scf.yield %[[a]] : vector<16xf32>
+// CHECK-VEC3:       }
+// CHECK-VEC3:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
+// CHECK-VEC3:       return
+//
 func @reduction_d(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_reduction_d
     ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
@@ -383,6 +447,27 @@ func @reduction_17(%arga: tensor<17xf32>, %argb: tensor<17xf32>, %argx: tensor<f
 // CHECK-VEC2:       }
 // CHECK-VEC2:       return
 //
+// CHECK-VEC3-LABEL: func @mul_ds
+// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG:   %[[c512:.*]] = constant 512 : index
+// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC3:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
+// CHECK-VEC3:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
+// CHECK-VEC3:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xindex>
+// CHECK-VEC3:         scf.for %[[j:.*]] = %[[p]] to %[[r]] step %[[c16]] {
+// CHECK-VEC3:           %[[sub:.*]] = subi %[[r]], %[[j]] : index
+// CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+// CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3:         }
+// CHECK-VEC3:       }
+// CHECK-VEC3:       return
+//
 func @mul_ds(%arga: tensor<512x1024xf32>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
   %0 = linalg.generic #trait_mul_ds
     ins(%arga, %argb: tensor<512x1024xf32>, tensor<512x1024xf32>)


        

