[Mlir-commits] [mlir] e8dcf5f - [mlir] [VectorOps] Add expand/compress operations to Vector dialect
llvmlistbot at llvm.org
Tue Aug 4 12:00:56 PDT 2020
Author: aartbik
Date: 2020-08-04T12:00:42-07:00
New Revision: e8dcf5f87dc20b3f08005ac767ff934e36bf2a5b
URL: https://github.com/llvm/llvm-project/commit/e8dcf5f87dc20b3f08005ac767ff934e36bf2a5b
DIFF: https://github.com/llvm/llvm-project/commit/e8dcf5f87dc20b3f08005ac767ff934e36bf2a5b.diff
LOG: [mlir] [VectorOps] Add expand/compress operations to Vector dialect
Introduces the expand and compress operations to the Vector dialect
(important memory operations for sparse computations), together
with a first reference implementation that lowers to the LLVM IR
dialect to enable running on CPU (and other targets that support
the corresponding LLVM IR intrinsics).
Reviewed By: reidtatge
Differential Revision: https://reviews.llvm.org/D84888
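For illustration only (not part of the patch; %base, %mask, and %pass_thru are placeholder SSA values), the two new ops compose as follows, using the syntax documented in VectorOps.td below:

  // Read elements contiguously from %base wherever %mask is set;
  // masked-off lanes take their value from %pass_thru.
  %0 = vector.expandload %base, %mask, %pass_thru
      : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  // Write the lanes of %0 selected by %mask back to %base contiguously.
  vector.compressstore %base, %mask, %0
      : memref<?xf32>, vector<16xi1>, vector<16xf32>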
Added:
mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Target/llvmir-intrinsics.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 4b1a6efe002f..768d8db121df 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1042,6 +1042,16 @@ def LLVM_masked_scatter
"type($value) `,` type($mask) `into` type($ptrs)";
}
+/// Create a call to Masked Expand Load intrinsic.
+def LLVM_masked_expandload
+ : LLVM_IntrOp<"masked.expandload", [0], [], [], 1>,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+/// Create a call to Masked Compress Store intrinsic.
+def LLVM_masked_compressstore
+ : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0>,
+ Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
//
// Atomic operations.
//
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index b49cc4a62a50..89a2b1226e1e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1158,7 +1158,7 @@ def Vector_GatherOp :
Variadic<VectorOfRank<[1]>>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
- let summary = "gathers elements from memory into a vector as defined by an index vector";
+ let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
let description = [{
The gather operation gathers elements from memory into a 1-D vector as
@@ -1186,7 +1186,6 @@ def Vector_GatherOp :
%g = vector.gather %base, %indices, %mask, %pass_thru
: (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
```
-
}];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -1217,7 +1216,7 @@ def Vector_ScatterOp :
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
- let summary = "scatters elements from a vector into memory as defined by an index vector";
+ let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
let description = [{
The scatter operation scatters elements from a 1-D vector into memory as
@@ -1265,6 +1264,108 @@ def Vector_ScatterOp :
"type($indices) `,` type($mask) `,` type($value) `into` type($base)";
}
+def Vector_ExpandLoadOp :
+ Vector_Op<"expandload">,
+ Arguments<(ins AnyMemRef:$base,
+ VectorOfRankAndType<[1], [I1]>:$mask,
+ VectorOfRank<[1]>:$pass_thru)>,
+ Results<(outs VectorOfRank<[1]>:$result)> {
+
+ let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
+
+ let description = [{
+ The expand load reads elements from memory into a 1-D vector as defined
+ by a base and a 1-D mask vector. When the mask is set, the next element
+ is read from memory. Otherwise, the corresponding element is taken from
+ a 1-D pass-through vector. Informally the semantics are:
+ ```
+ index = base
+ result[0] := mask[0] ? MEM[index++] : pass_thru[0]
+ result[1] := mask[1] ? MEM[index++] : pass_thru[1]
+ etc.
+ ```
+ Note that the index increment is done conditionally.
+
+ The expand load can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+ hardware ISA support for an expand. The semantics of the operation closely
+ correspond to those of the `llvm.masked.expandload`
+ [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
+
+ Example:
+
+ ```mlir
+ %0 = vector.expandload %base, %mask, %pass_thru
+ : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return base().getType().cast<MemRefType>();
+ }
+ VectorType getMaskVectorType() {
+ return mask().getType().cast<VectorType>();
+ }
+ VectorType getPassThruVectorType() {
+ return pass_thru().getType().cast<VectorType>();
+ }
+ VectorType getResultVectorType() {
+ return result().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+ "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
+}
+
+def Vector_CompressStoreOp :
+ Vector_Op<"compressstore">,
+ Arguments<(ins AnyMemRef:$base,
+ VectorOfRankAndType<[1], [I1]>:$mask,
+ VectorOfRank<[1]>:$value)> {
+
+  let summary = "selectively writes elements from a vector into memory as defined by a mask";
+
+ let description = [{
+ The compress store operation writes elements from a 1-D vector into memory
+    as defined by a base and a 1-D mask vector. When the mask is set, the
+    corresponding element from the vector is written to the next memory location.
+    Otherwise, no action is taken for the element. Informally the semantics are:
+ ```
+ index = base
+ if (mask[0]) MEM[index++] = value[0]
+ if (mask[1]) MEM[index++] = value[1]
+ etc.
+ ```
+ Note that the index increment is done conditionally.
+
+ The compress store can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+ hardware ISA support for a compress. The semantics of the operation closely
+ correspond to those of the `llvm.masked.compressstore`
+ [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
+
+ Example:
+
+ ```mlir
+ vector.compressstore %base, %mask, %value
+ : memref<?xf32>, vector<8xi1>, vector<8xf32>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return base().getType().cast<MemRefType>();
+ }
+ VectorType getMaskVectorType() {
+ return mask().getType().cast<VectorType>();
+ }
+ VectorType getValueVectorType() {
+ return value().getType().cast<VectorType>();
+ }
+ }];
+ let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
+ "type($base) `,` type($mask) `,` type($value)";
+}
+
def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [NoSideEffect]>,
Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
new file mode 100644
index 000000000000..6310d6ee8790
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @compress16(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
+ vector.compressstore %base, %mask, %value
+ : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+func @printmem16(%A: memref<?xf32>) {
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c16 = constant 16: index
+ %z = constant 0.0: f32
+ %m = vector.broadcast %z : f32 to vector<16xf32>
+ %mem = scf.for %i = %c0 to %c16 step %c1
+ iter_args(%m_iter = %m) -> (vector<16xf32>) {
+ %c = load %A[%i] : memref<?xf32>
+ %i32 = index_cast %i : index to i32
+ %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32>
+ scf.yield %m_new : vector<16xf32>
+ }
+ vector.print %mem : vector<16xf32>
+ return
+}
+
+func @entry() {
+ // Set up memory.
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c16 = constant 16: index
+ %A = alloc(%c16) : memref<?xf32>
+ %z = constant 0.0: f32
+ %v = vector.broadcast %z : f32 to vector<16xf32>
+ %value = scf.for %i = %c0 to %c16 step %c1
+ iter_args(%v_iter = %v) -> (vector<16xf32>) {
+ store %z, %A[%i] : memref<?xf32>
+ %i32 = index_cast %i : index to i32
+ %fi = sitofp %i32 : i32 to f32
+ %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32>
+ scf.yield %v_new : vector<16xf32>
+ }
+
+ // Set up masks.
+ %f = constant 0: i1
+ %t = constant 1: i1
+ %none = vector.constant_mask [0] : vector<16xi1>
+ %all = vector.constant_mask [16] : vector<16xi1>
+ %some1 = vector.constant_mask [4] : vector<16xi1>
+ %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
+ %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
+ %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
+ %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
+ %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
+ %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
+
+ //
+  // Compress store tests.
+ //
+
+ call @compress16(%A, %none, %value)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+ call @printmem16(%A) : (memref<?xf32>) -> ()
+ // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+
+ call @compress16(%A, %all, %value)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+ call @printmem16(%A) : (memref<?xf32>) -> ()
+ // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+ call @compress16(%A, %some3, %value)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+ call @printmem16(%A) : (memref<?xf32>) -> ()
+ // CHECK-NEXT: ( 1, 3, 7, 11, 13, 15, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+ call @compress16(%A, %some2, %value)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+ call @printmem16(%A) : (memref<?xf32>) -> ()
+ // CHECK-NEXT: ( 1, 2, 3, 7, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+ call @compress16(%A, %some1, %value)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+ call @printmem16(%A) : (memref<?xf32>) -> ()
+ // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+ return
+}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
new file mode 100644
index 000000000000..74118fc1125b
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @expand16(%base: memref<?xf32>,
+ %mask: vector<16xi1>,
+ %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %e = vector.expandload %base, %mask, %pass_thru
+ : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %e : vector<16xf32>
+}
+
+func @entry() {
+ // Set up memory.
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c16 = constant 16: index
+ %A = alloc(%c16) : memref<?xf32>
+ scf.for %i = %c0 to %c16 step %c1 {
+ %i32 = index_cast %i : index to i32
+ %fi = sitofp %i32 : i32 to f32
+ store %fi, %A[%i] : memref<?xf32>
+ }
+
+ // Set up pass thru vector.
+ %u = constant -7.0: f32
+ %v = constant 7.7: f32
+ %pass = vector.broadcast %u : f32 to vector<16xf32>
+
+ // Set up masks.
+ %f = constant 0: i1
+ %t = constant 1: i1
+ %none = vector.constant_mask [0] : vector<16xi1>
+ %all = vector.constant_mask [16] : vector<16xi1>
+ %some1 = vector.constant_mask [4] : vector<16xi1>
+ %0 = vector.insert %f, %some1[0] : i1 into vector<16xi1>
+ %1 = vector.insert %t, %0[7] : i1 into vector<16xi1>
+ %2 = vector.insert %t, %1[11] : i1 into vector<16xi1>
+ %3 = vector.insert %t, %2[13] : i1 into vector<16xi1>
+ %some2 = vector.insert %t, %3[15] : i1 into vector<16xi1>
+ %some3 = vector.insert %f, %some2[2] : i1 into vector<16xi1>
+
+ //
+ // Expanding load tests.
+ //
+
+ %e1 = call @expand16(%A, %none, %pass)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+ vector.print %e1 : vector<16xf32>
+ // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+ %e2 = call @expand16(%A, %all, %pass)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+ vector.print %e2 : vector<16xf32>
+ // CHECK-NEXT: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+
+ %e3 = call @expand16(%A, %some1, %pass)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+ vector.print %e3 : vector<16xf32>
+ // CHECK-NEXT: ( 0, 1, 2, 3, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
+ %e4 = call @expand16(%A, %some2, %pass)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+ vector.print %e4 : vector<16xf32>
+ // CHECK-NEXT: ( -7, 0, 1, 2, -7, -7, -7, 3, -7, -7, -7, 4, -7, 5, -7, 6 )
+
+ %e5 = call @expand16(%A, %some3, %pass)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+ vector.print %e5 : vector<16xf32>
+ // CHECK-NEXT: ( -7, 0, -7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, -7, 5 )
+
+ %4 = vector.insert %v, %pass[1] : f32 into vector<16xf32>
+ %5 = vector.insert %v, %4[2] : f32 into vector<16xf32>
+ %alt_pass = vector.insert %v, %5[14] : f32 into vector<16xf32>
+ %e6 = call @expand16(%A, %some3, %alt_pass)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+ vector.print %e6 : vector<16xf32>
+ // CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
+
+ return
+}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
index 6dd0cf169552..54171e744605 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
@@ -11,34 +11,20 @@ func @scatter8(%base: memref<?xf32>,
return
}
-func @printmem(%A: memref<?xf32>) {
- %f = constant 0.0: f32
- %0 = vector.broadcast %f : f32 to vector<8xf32>
- %1 = constant 0: index
- %2 = load %A[%1] : memref<?xf32>
- %3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
- %4 = constant 1: index
- %5 = load %A[%4] : memref<?xf32>
- %6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
- %7 = constant 2: index
- %8 = load %A[%7] : memref<?xf32>
- %9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
- %10 = constant 3: index
- %11 = load %A[%10] : memref<?xf32>
- %12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
- %13 = constant 4: index
- %14 = load %A[%13] : memref<?xf32>
- %15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
- %16 = constant 5: index
- %17 = load %A[%16] : memref<?xf32>
- %18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
- %19 = constant 6: index
- %20 = load %A[%19] : memref<?xf32>
- %21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
- %22 = constant 7: index
- %23 = load %A[%22] : memref<?xf32>
- %24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
- vector.print %24 : vector<8xf32>
+func @printmem8(%A: memref<?xf32>) {
+ %c0 = constant 0: index
+ %c1 = constant 1: index
+ %c8 = constant 8: index
+ %z = constant 0.0: f32
+ %m = vector.broadcast %z : f32 to vector<8xf32>
+ %mem = scf.for %i = %c0 to %c8 step %c1
+ iter_args(%m_iter = %m) -> (vector<8xf32>) {
+ %c = load %A[%i] : memref<?xf32>
+ %i32 = index_cast %i : index to i32
+ %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32>
+ scf.yield %m_new : vector<8xf32>
+ }
+ vector.print %mem : vector<8xf32>
return
}
@@ -104,31 +90,27 @@ func @entry() {
vector.print %idx : vector<8xi32>
// CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )
- call @printmem(%A) : (memref<?xf32>) -> ()
+ call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
call @scatter8(%A, %idx, %none, %val)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
- call @printmem(%A) : (memref<?xf32>) -> ()
+ call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
call @scatter8(%A, %idx, %some, %val)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
- call @printmem(%A) : (memref<?xf32>) -> ()
+ call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )
call @scatter8(%A, %idx, %more, %val)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
- call @printmem(%A) : (memref<?xf32>) -> ()
+ call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )
call @scatter8(%A, %idx, %all, %val)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
-
- call @printmem(%A) : (memref<?xf32>) -> ()
+ call @printmem8(%A) : (memref<?xf32>) -> ()
// CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )
return
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 3dbfaf88a443..23373f5c7edf 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -134,11 +134,9 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
return success();
}
-// Helper that returns vector of pointers given a base and an index vector.
-LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
- LLVMTypeConverter &typeConverter, Location loc,
- Value memref, Value indices, MemRefType memRefType,
- VectorType vType, Type iType, Value &ptrs) {
+// Helper that returns the base address of a memref.
+LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
+ Value memref, MemRefType memRefType, Value &base) {
// Inspect stride and offset structure.
//
// TODO: flat memory only for now, generalize
@@ -149,13 +147,31 @@ LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
offset != 0 || memRefType.getMemorySpace() != 0)
return failure();
+ base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
+ return success();
+}
+
+// Helper that returns a pointer given a memref base.
+LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
+ Value memref, MemRefType memRefType, Value &ptr) {
+ Value base;
+ if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+ return failure();
+ auto pType = MemRefDescriptor(memref).getElementType();
+ ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
+ return success();
+}
- // Create a vector of pointers from base and indices.
- MemRefDescriptor memRefDescriptor(memref);
- Value base = memRefDescriptor.alignedPtr(rewriter, loc);
- int64_t size = vType.getDimSize(0);
- auto pType = memRefDescriptor.getElementType();
- auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+// Helper that returns vector of pointers given a memref base and an index
+// vector.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+ Value memref, Value indices, MemRefType memRefType,
+ VectorType vType, Type iType, Value &ptrs) {
+ Value base;
+ if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+ return failure();
+ auto pType = MemRefDescriptor(memref).getElementType();
+ auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
return success();
}
@@ -305,9 +321,8 @@ class VectorGatherOpConversion : public ConvertToLLVMPattern {
VectorType vType = gather.getResultVectorType();
Type iType = gather.getIndicesVectorType().getElementType();
Value ptrs;
- if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
- adaptor.indices(), gather.getMemRefType(), vType,
- iType, ptrs)))
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+ gather.getMemRefType(), vType, iType, ptrs)))
return failure();
// Replace with the gather intrinsic.
@@ -344,9 +359,8 @@ class VectorScatterOpConversion : public ConvertToLLVMPattern {
VectorType vType = scatter.getValueVectorType();
Type iType = scatter.getIndicesVectorType().getElementType();
Value ptrs;
- if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
- adaptor.indices(), scatter.getMemRefType(), vType,
- iType, ptrs)))
+ if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+ scatter.getMemRefType(), vType, iType, ptrs)))
return failure();
// Replace with the scatter intrinsic.
@@ -357,6 +371,60 @@ class VectorScatterOpConversion : public ConvertToLLVMPattern {
}
};
+/// Conversion pattern for a vector.expandload.
+class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
+public:
+ explicit VectorExpandLoadOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
+ typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto expand = cast<vector::ExpandLoadOp>(op);
+ auto adaptor = vector::ExpandLoadOpAdaptor(operands);
+
+ Value ptr;
+ if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
+ ptr)))
+ return failure();
+
+ auto vType = expand.getResultVectorType();
+ rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
+ op, typeConverter.convertType(vType), ptr, adaptor.mask(),
+ adaptor.pass_thru());
+ return success();
+ }
+};
+
+/// Conversion pattern for a vector.compressstore.
+class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
+public:
+ explicit VectorCompressStoreOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
+ context, typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto compress = cast<vector::CompressStoreOp>(op);
+ auto adaptor = vector::CompressStoreOpAdaptor(operands);
+
+ Value ptr;
+ if (failed(getBasePtr(rewriter, loc, adaptor.base(),
+ compress.getMemRefType(), ptr)))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
+ op, adaptor.value(), ptr, adaptor.mask());
+ return success();
+ }
+};
+
/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
@@ -1274,7 +1342,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorTransferConversion<TransferWriteOp>,
VectorTypeCastOpConversion,
VectorGatherOpConversion,
- VectorScatterOpConversion>(ctx, converter);
+ VectorScatterOpConversion,
+ VectorExpandLoadOpConversion,
+ VectorCompressStoreOpConversion>(ctx, converter);
// clang-format on
}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index c788d4ccb4a0..9e64ff9af80a 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1898,6 +1898,41 @@ static LogicalResult verify(ScatterOp op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// ExpandLoadOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ExpandLoadOp op) {
+ VectorType maskVType = op.getMaskVectorType();
+ VectorType passVType = op.getPassThruVectorType();
+ VectorType resVType = op.getResultVectorType();
+
+ if (resVType.getElementType() != op.getMemRefType().getElementType())
+ return op.emitOpError("base and result element type should match");
+
+ if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+ return op.emitOpError("expected result dim to match mask dim");
+ if (resVType != passVType)
+ return op.emitOpError("expected pass_thru of same type as result type");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// CompressStoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(CompressStoreOp op) {
+ VectorType maskVType = op.getMaskVectorType();
+ VectorType valueVType = op.getValueVectorType();
+
+ if (valueVType.getElementType() != op.getMemRefType().getElementType())
+ return op.emitOpError("base and value element type should match");
+
+ if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+ return op.emitOpError("expected value dim to match mask dim");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ShapeCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2e5aae886c38..be70c08bc948 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -989,3 +989,23 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<float>, !llvm.vec<3 x i32>) -> !llvm.vec<3 x ptr<float>>
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<3 x float>, !llvm.vec<3 x i1> into !llvm.vec<3 x ptr<float>>
// CHECK: llvm.return
+
+func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
+ %0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
+ return %0 : vector<11xf32>
+}
+
+// CHECK-LABEL: func @expand_load_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
+// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<float>, !llvm.vec<11 x i1>, !llvm.vec<11 x float>) -> !llvm.vec<11 x float>
+// CHECK: llvm.return %[[E]] : !llvm.vec<11 x float>
+
+func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
+ vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
+ return
+}
+
+// CHECK-LABEL: func @compress_store_op
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<float>) -> !llvm.ptr<float>
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x float>, !llvm.ptr<float>, !llvm.vec<11 x i1>) -> ()
+// CHECK: llvm.return
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea354f51645a..651fe27cd36c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1240,3 +1240,38 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
// expected-error at +1 {{'vector.scatter' op expected value dim to match mask dim}}
vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
}
+
+// -----
+
+func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ // expected-error at +1 {{'vector.expandload' op base and result element type should match}}
+ %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
+ // expected-error at +1 {{'vector.expandload' op expected result dim to match mask dim}}
+ %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
+ // expected-error at +1 {{'vector.expandload' op expected pass_thru of same type as result type}}
+ %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
+}
+
+// -----
+
+func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+ // expected-error at +1 {{'vector.compressstore' op base and value element type should match}}
+ vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+ // expected-error at +1 {{'vector.compressstore' op expected value dim to match mask dim}}
+ vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 0bf4ed8f84c7..d4d1abe8e646 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -379,3 +379,12 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
return
}
+
+// CHECK-LABEL: @expand_and_compress
+func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+ // CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
index fc286599ee95..6bf9b9768dd3 100644
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -237,8 +237,8 @@ llvm.func @matrix_intrinsics(%A: !llvm.vec<64 x float>, %B: !llvm.vec<48 x float
llvm.return
}
-// CHECK-LABEL: @masked_intrinsics
-llvm.func @masked_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
+// CHECK-LABEL: @masked_load_store_intrinsics
+llvm.func @masked_load_store_intrinsics(%A: !llvm.ptr<vec<7 x float>>, %mask: !llvm.vec<7 x i1>) {
// CHECK: call <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>* %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
%a = llvm.intr.masked.load %A, %mask { alignment = 1: i32} :
(!llvm.ptr<vec<7 x float>>, !llvm.vec<7 x i1>) -> !llvm.vec<7 x float>
@@ -265,6 +265,17 @@ llvm.func @masked_gather_scatter_intrinsics(%M: !llvm.vec<7 x ptr<float>>, %mask
llvm.return
}
+// CHECK-LABEL: @masked_expand_compress_intrinsics
+llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr<float>, %mask: !llvm.vec<7 x i1>, %passthru: !llvm.vec<7 x float>) {
+ // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(float* %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+ %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru)
+ : (!llvm.ptr<float>, !llvm.vec<7 x i1>, !llvm.vec<7 x float>) -> (!llvm.vec<7 x float>)
+ // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, float* %{{.*}}, <7 x i1> %{{.*}})
+ "llvm.intr.masked.compressstore"(%0, %ptr, %mask)
+ : (!llvm.vec<7 x float>, !llvm.ptr<float>, !llvm.vec<7 x i1>) -> ()
+ llvm.return
+}
+
// CHECK-LABEL: @memcpy_test
llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm.ptr<i8>, %arg3: !llvm.ptr<i8>) {
// CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})