[Mlir-commits] [mlir] 92b0a9d - [mlir][sparse] remove restriction on vectorization of index type
Aart Bik
llvmlistbot at llvm.org
Thu Apr 15 10:27:12 PDT 2021
Author: Aart Bik
Date: 2021-04-15T10:27:04-07:00
New Revision: 92b0a9d7d496bb382a24eb8340fec1c1cc7300e1
URL: https://github.com/llvm/llvm-project/commit/92b0a9d7d496bb382a24eb8340fec1c1cc7300e1
DIFF: https://github.com/llvm/llvm-project/commit/92b0a9d7d496bb382a24eb8340fec1c1cc7300e1.diff
LOG: [mlir][sparse] remove restriction on vectorization of index type
Rationale:
Now that vector<?xindex> is allowed, the restriction on vectorization
of index types in the sparse compiler can be removed. This also
requires generalizing the index types accepted by the scatter/gather ops.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D100522
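For illustration, here is a minimal sketch (function and value names are
hypothetical) of the kind of IR this change makes legal: a vector.gather
whose index_vec operand carries index elements rather than a fixed-width
integer type, mirroring the CHECK-VEC3 patterns in the test diff below.

  func @gather_with_index_vec(%base: memref<1024xf32>,
                              %idxs: vector<16xindex>,
                              %mask: vector<16xi1>,
                              %pass: vector<16xf32>) -> vector<16xf32> {
    %c0 = constant 0 : index
    // Before this change, the index_vec operand was restricted to
    // fixed-width integer vectors; vector<16xindex> is now accepted.
    %g = vector.gather %base[%c0] [%idxs], %mask, %pass
        : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
    return %g : vector<16xf32>
  }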
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/sparse_vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7b7d6accd6a07..4d45642b1d983 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1098,13 +1098,7 @@ struct SparsificationOptions {
SparseVectorizationStrategy v, unsigned vl,
SparseIntType pt, SparseIntType it, bool fo)
: parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
- ptrType(pt), indType(it), fastOutput(fo) {
- // TODO: remove restriction when vectors with index elements are supported
- assert((v != SparseVectorizationStrategy::kAnyStorageInnerLoop ||
- (ptrType != SparseIntType::kNative &&
- indType != SparseIntType::kNative)) &&
- "This combination requires support for vectors with index elements");
- }
+ ptrType(pt), indType(it), fastOutput(fo) {}
SparsificationOptions()
: SparsificationOptions(SparseParallelizationStrategy::kNone,
SparseVectorizationStrategy::kNone, 1u,
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 0a1228599b601..b0ff956fe49de 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1684,7 +1684,7 @@ def Vector_GatherOp :
Vector_Op<"gather">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
+ VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1749,7 +1749,7 @@ def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
+ VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$valueToStore)> {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index aa162bf83e61e..99024c32104ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -771,12 +771,14 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
// extremely large offsets.
Type etp = ptr.getType().cast<MemRefType>().getElementType();
Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
- if (etp.getIntOrFloatBitWidth() < 32)
- vload = rewriter.create<ZeroExtendIOp>(
- loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
- else if (etp.getIntOrFloatBitWidth() < 64)
- vload = rewriter.create<ZeroExtendIOp>(
- loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
+ if (!etp.isa<IndexType>()) {
+ if (etp.getIntOrFloatBitWidth() < 32)
+ vload = rewriter.create<ZeroExtendIOp>(
+ loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
+ else if (etp.getIntOrFloatBitWidth() < 64)
+ vload = rewriter.create<ZeroExtendIOp>(
+ loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
+ }
return vload;
}
// For the scalar case, we simply zero extend narrower indices into 64-bit
diff --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir
index 87f1a406179d7..b22a4ebd7df26 100644
--- a/mlir/test/Dialect/Linalg/sparse_vector.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir
@@ -4,6 +4,8 @@
// RUN: FileCheck %s --check-prefix=CHECK-VEC1
// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=2 ptr-type=2 ind-type=2 vl=16" | \
// RUN: FileCheck %s --check-prefix=CHECK-VEC2
+// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=2 ptr-type=0 ind-type=0 vl=16" | \
+// RUN: FileCheck %s --check-prefix=CHECK-VEC3
#trait_scale_d = {
indexing_maps = [
@@ -54,6 +56,18 @@
// CHECK-VEC2: }
// CHECK-VEC2: return
//
+// CHECK-VEC3-LABEL: func @scale_d
+// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG: %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
+// CHECK-VEC3: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC3: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
+// CHECK-VEC3: %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
+// CHECK-VEC3: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC3: }
+// CHECK-VEC3: return
+//
func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
%0 = linalg.generic #trait_scale_d
ins(%arga: tensor<1024xf32>)
@@ -143,6 +157,23 @@ func @scale_d(%arga: tensor<1024xf32>, %scale: f32, %argx: tensor<1024xf32>) ->
// CHECK-VEC2: }
// CHECK-VEC2: return
//
+// CHECK-VEC3-LABEL: func @mul_s
+// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xindex>
+// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xindex>
+// CHECK-VEC3: scf.for %[[i:.*]] = %[[p]] to %[[r]] step %[[c16]] {
+// CHECK-VEC3: %[[sub:.*]] = subi %[[r]], %[[i]] : index
+// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3: }
+// CHECK-VEC3: return
+//
func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
%0 = linalg.generic #trait_mul_s
ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
@@ -177,6 +208,24 @@ func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<1024
// CHECK-VEC2: }
// CHECK-VEC2: return
//
+// CHECK-VEC3-LABEL: func @mul_s_alt
+// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xindex>
+// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xindex>
+// CHECK-VEC3: scf.for %[[i:.*]] = %[[p]] to %[[r]] step %[[c16]] {
+// CHECK-VEC3: %[[sub:.*]] = subi %[[r]], %[[i]] : index
+// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3: }
+// CHECK-VEC3: return
+//
+//
!SparseTensor = type !llvm.ptr<i8>
func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
%arga = linalg.sparse_tensor %argA : !SparseTensor to tensor<1024xf32>
@@ -250,6 +299,21 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
// CHECK-VEC2: %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
// CHECK-VEC2: return
//
+// CHECK-VEC3-LABEL: func @reduction_d
+// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG: %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC3-DAG: %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-VEC3: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
+// CHECK-VEC3: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC3: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3: %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
+// CHECK-VEC3: scf.yield %[[a]] : vector<16xf32>
+// CHECK-VEC3: }
+// CHECK-VEC3: %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
+// CHECK-VEC3: return
+//
func @reduction_d(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
%0 = linalg.generic #trait_reduction_d
ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
@@ -383,6 +447,27 @@ func @reduction_17(%arga: tensor<17xf32>, %argb: tensor<17xf32>, %argx: tensor<f
// CHECK-VEC2: }
// CHECK-VEC2: return
//
+// CHECK-VEC3-LABEL: func @mul_ds
+// CHECK-VEC3-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC3-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC3-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC3-DAG: %[[c512:.*]] = constant 512 : index
+// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
+// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
+// CHECK-VEC3: %[[a:.*]] = addi %[[i]], %[[c1]] : index
+// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xindex>
+// CHECK-VEC3: scf.for %[[j:.*]] = %[[p]] to %[[r]] step %[[c16]] {
+// CHECK-VEC3: %[[sub:.*]] = subi %[[r]], %[[j]] : index
+// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC3: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC3: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC3: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC3: }
+// CHECK-VEC3: }
+// CHECK-VEC3: return
+//
func @mul_ds(%arga: tensor<512x1024xf32>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
%0 = linalg.generic #trait_mul_ds
ins(%arga, %argb: tensor<512x1024xf32>, tensor<512x1024xf32>)
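As the new RUN lines in sparse_vector.mlir show, the native index
configuration that previously tripped the assertion removed from
SparsificationOptions can now be exercised directly, e.g.:

  mlir-opt sparse_vector.mlir \
    -test-sparsification="vectorization-strategy=2 ptr-type=0 ind-type=0 vl=16"

Here ptr-type=0 and ind-type=0 select SparseIntType::kNative, i.e.
memref<?xindex> storage for pointers and indices, which the vectorizer
now handles with vector<16xindex> gathers and scatters.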