[Mlir-commits] [mlir] 1485fd2 - [mlir] [VectorOps] Improve scatter/gather CPU performance
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 22 23:47:50 PDT 2020
Author: aartbik
Date: 2020-07-22T23:47:36-07:00
New Revision: 1485fd295b2ab7fde3bf8d3ce33912c8a5fe1105
URL: https://github.com/llvm/llvm-project/commit/1485fd295b2ab7fde3bf8d3ce33912c8a5fe1105
DIFF: https://github.com/llvm/llvm-project/commit/1485fd295b2ab7fde3bf8d3ce33912c8a5fe1105.diff
LOG: [mlir] [VectorOps] Improve scatter/gather CPU performance
Replaced the linearized address with the proper LLVM way of
defining vector of base + indices in SIMD style. This yields
much better code. Some prototype results with microbenchmarking
sparse matrix x vector with 50% sparsity (about 2-3x faster):
           LINEARIZED      IMPROVED
GFLOPS     sdot  saxpy    sdot  saxpy
16x16      1.6   1.4      4.4   2.1
32x32      1.7   1.6      5.8   5.9
64x64      1.7   1.7      6.4   6.4
128x128    1.7   1.7      5.9   5.9
256x256    1.6   1.6      6.1   6.0
512x512    1.4   1.4      4.9   4.7
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D84368
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index fd3d190990d4..b49cc4a62a50 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1244,7 +1244,7 @@ def Vector_ScatterOp :
```mlir
vector.scatter %base, %indices, %mask, %value
- : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?f32>
+ : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
```
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a877bd12c2e1..5dbc8394b03a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -147,28 +147,13 @@ LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
offset != 0 || memRefType.getMemorySpace() != 0)
return failure();
- // Base pointer.
+ // Create a vector of pointers from base and indices.
MemRefDescriptor memRefDescriptor(memref);
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-
- // Create a vector of pointers from base and indices.
- //
- // TODO: this step serializes the address computations unfortunately,
- // ideally we would like to add splat(base) + index_vector
- // in SIMD form, but this does not match well with current
- // constraints of the standard and vector dialect....
- //
int64_t size = vType.getDimSize(0);
auto pType = memRefDescriptor.getElementType();
auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
- auto idxType = typeConverter.convertType(iType);
- ptrs = rewriter.create<LLVM::UndefOp>(loc, ptrsType);
- for (int64_t i = 0; i < size; i++) {
- Value off =
- extractOne(rewriter, typeConverter, loc, indices, idxType, 1, i);
- Value ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base, off);
- ptrs = insertOne(rewriter, typeConverter, loc, ptrs, ptr, ptrsType, 1, i);
- }
+ ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 69d3aeca3d95..c5259a17fcef 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -976,7 +976,8 @@ func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>,
}
// CHECK-LABEL: func @gather_op
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm<"<3 x float*>">, !llvm<"<3 x i1>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm<"float*">, !llvm<"<3 x i32>">) -> !llvm<"<3 x float*>">
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm<"<3 x float*>">, !llvm<"<3 x i1>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
// CHECK: llvm.return %[[G]] : !llvm<"<3 x float>">
func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
@@ -985,5 +986,6 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
}
// CHECK-LABEL: func @scatter_op
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>">
+// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm<"float*">, !llvm<"<3 x i32>">) -> !llvm<"<3 x float*>">
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>">
// CHECK: llvm.return
More information about the Mlir-commits
mailing list