[llvm-branch-commits] [mlir] 6728af1 - [mlir][vector] modified scatter/gather syntax, pass_thru mandatory
Aart Bik via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Jan 9 11:46:37 PST 2021
Author: Aart Bik
Date: 2021-01-09T11:41:37-08:00
New Revision: 6728af16cf987df3cf051f3a1f9c92ed2b8fbc2d
URL: https://github.com/llvm/llvm-project/commit/6728af16cf987df3cf051f3a1f9c92ed2b8fbc2d
DIFF: https://github.com/llvm/llvm-project/commit/6728af16cf987df3cf051f3a1f9c92ed2b8fbc2d.diff
LOG: [mlir][vector] modified scatter/gather syntax, pass_thru mandatory
This change makes the scatter/gather syntax more consistent with
the syntax of all the other memory operations in the Vector dialect
(order of types, use of [] for index, etc.). This will make the MLIR
code easier to read. In addition, the pass_thru parameter of the
gather has been made mandatory (there is very little benefit in
using the implicit "undefined" values).
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D94352
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-mem-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 0aa4950e0a9e..7f57dcd77def 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1419,7 +1419,7 @@ def Vector_GatherOp :
Arguments<(ins AnyMemRef:$base,
VectorOfRankAndType<[1], [AnyInteger]>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
- Variadic<VectorOfRank<[1]>>:$pass_thru)>,
+ VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
@@ -1428,10 +1428,8 @@ def Vector_GatherOp :
The gather operation gathers elements from memory into a 1-D vector as
defined by a base and a 1-D index vector, but only if the corresponding
bit is set in a 1-D mask vector. Otherwise, the element is taken from a
- 1-D pass-through vector, if provided, or left undefined. Informally the
- semantics are:
+ 1-D pass-through vector. Informally the semantics are:
```
- if (!defined(pass_thru)) pass_thru = [undef, .., undef]
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
etc.
@@ -1447,8 +1445,8 @@ def Vector_GatherOp :
Example:
```mlir
- %g = vector.gather %base, %indices, %mask, %pass_thru
- : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ %g = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
}];
let extraClassDeclaration = [{
@@ -1462,15 +1460,14 @@ def Vector_GatherOp :
return mask().getType().cast<VectorType>();
}
VectorType getPassThruVectorType() {
- return (llvm::size(pass_thru()) == 0)
- ? VectorType()
- : (*pass_thru().begin()).getType().cast<VectorType>();
+ return pass_thru().getType().cast<VectorType>();
}
VectorType getResultVectorType() {
return result().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
+ "type($base) `,` type($indices) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
}
@@ -1507,8 +1504,8 @@ def Vector_ScatterOp :
Example:
```mlir
- vector.scatter %base, %indices, %mask, %value
- : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
@@ -1525,8 +1522,8 @@ def Vector_ScatterOp :
return value().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` "
- "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
+ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
+ "type($base) `,` type($indices) `,` type($mask) `,` type($value)";
let hasCanonicalizer = 1;
}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
index 5ed8f3ee38f8..95df5aea06e4 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
@@ -3,18 +3,10 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
-func @gather8(%base: memref<?xf32>,
- %indices: vector<8xi32>, %mask: vector<8xi1>) -> vector<8xf32> {
- %g = vector.gather %base, %indices, %mask
- : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
- return %g : vector<8xf32>
-}
-
-func @gather_pass_thru8(%base: memref<?xf32>,
- %indices: vector<8xi32>, %mask: vector<8xi1>,
- %pass_thru: vector<8xf32>) -> vector<8xf32> {
- %g = vector.gather %base, %indices, %mask, %pass_thru
- : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> vector<8xf32>
+func @gather8(%base: memref<?xf32>, %indices: vector<8xi32>,
+ %mask: vector<8xi1>, %pass_thru: vector<8xf32>) -> vector<8xf32> {
+ %g = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
return %g : vector<8xf32>
}
@@ -63,31 +55,31 @@ func @entry() {
// Gather tests.
//
- %g1 = call @gather8(%A, %idx, %all)
- : (memref<?xf32>, vector<8xi32>, vector<8xi1>)
+ %g1 = call @gather8(%A, %idx, %all, %pass)
+ : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g1 : vector<8xf32>
// CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
- %g2 = call @gather_pass_thru8(%A, %idx, %none, %pass)
+ %g2 = call @gather8(%A, %idx, %none, %pass)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g2 : vector<8xf32>
// CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 )
- %g3 = call @gather_pass_thru8(%A, %idx, %some, %pass)
+ %g3 = call @gather8(%A, %idx, %some, %pass)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g3 : vector<8xf32>
// CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 )
- %g4 = call @gather_pass_thru8(%A, %idx, %more, %pass)
+ %g4 = call @gather8(%A, %idx, %more, %pass)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g4 : vector<8xf32>
// CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 )
- %g5 = call @gather_pass_thru8(%A, %idx, %all, %pass)
+ %g5 = call @gather8(%A, %idx, %all, %pass)
: (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-> (vector<8xf32>)
vector.print %g5 : vector<8xf32>
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
index 54171e744605..0666cc852c2a 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
@@ -6,8 +6,8 @@
func @scatter8(%base: memref<?xf32>,
%indices: vector<8xi32>,
%mask: vector<8xi1>, %value: vector<8xf32>) {
- vector.scatter %base, %indices, %mask, %value
- : vector<8xi32>, vector<8xi1>, vector<8xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>
return
}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
index 088e68e0e507..7940e8c68b1a 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
@@ -60,11 +60,12 @@ func @spmv8x8(%AVAL: memref<8xvector<4xf32>>,
%cn = constant 8 : index
%f0 = constant 0.0 : f32
%mask = vector.constant_mask [4] : vector<4xi1>
+ %pass = vector.broadcast %f0 : f32 to vector<4xf32>
scf.for %i = %c0 to %cn step %c1 {
%aval = load %AVAL[%i] : memref<8xvector<4xf32>>
%aidx = load %AIDX[%i] : memref<8xvector<4xi32>>
- %0 = vector.gather %X, %aidx, %mask
- : (memref<?xf32>, vector<4xi32>, vector<4xi1>) -> vector<4xf32>
+ %0 = vector.gather %X[%aidx], %mask, %pass
+ : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
%1 = vector.contract #dot_trait %aval, %0, %f0 : vector<4xf32>, vector<4xf32> into f32
store %1, %B[%i] : memref<?xf32>
}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
index 5b02b5df4b73..31f288e0f6c5 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
@@ -50,12 +50,15 @@ func @spmv8x8(%AVAL: memref<4xvector<8xf32>>,
%c0 = constant 0 : index
%c1 = constant 1 : index
%cn = constant 4 : index
+ %f0 = constant 0.0 : f32
%mask = vector.constant_mask [8] : vector<8xi1>
+ %pass = vector.broadcast %f0 : f32 to vector<8xf32>
%b = load %B[%c0] : memref<1xvector<8xf32>>
%b_out = scf.for %k = %c0 to %cn step %c1 iter_args(%b_iter = %b) -> (vector<8xf32>) {
%aval = load %AVAL[%k] : memref<4xvector<8xf32>>
%aidx = load %AIDX[%k] : memref<4xvector<8xi32>>
- %0 = vector.gather %X, %aidx, %mask : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
+ %0 = vector.gather %X[%aidx], %mask, %pass
+ : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
%b_new = vector.fma %aval, %0, %b_iter : vector<8xf32>
scf.yield %b_new : vector<8xf32>
}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 318ca1e27c88..91eab5027962 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2446,11 +2446,8 @@ static LogicalResult verify(GatherOp op) {
return op.emitOpError("expected result dim to match indices dim");
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return op.emitOpError("expected result dim to match mask dim");
- if (llvm::size(op.pass_thru()) != 0) {
- VectorType passVType = op.getPassThruVectorType();
- if (resVType != passVType)
- return op.emitOpError("expected pass_thru of same type as result type");
- }
+ if (resVType != op.getPassThruVectorType())
+ return op.emitOpError("expected pass_thru of same type as result type");
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 5c0c9651133d..ef4a85c1652c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1096,7 +1096,7 @@ func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<
// CHECK: llvm.return
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
- %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+ %0 = vector.gather %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
return %0 : vector<3xf32>
}
@@ -1106,7 +1106,7 @@ func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>,
// CHECK: llvm.return %[[G]] : !llvm.vec<3 x f32>
func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
- vector.scatter %arg0, %arg1, %arg2, %arg3 : vector<3xi32>, vector<3xi1>, vector<3xf32> into memref<?xf32>
+ vector.scatter %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
return
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 8cadafae1ec4..11100c4e615e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1238,65 +1238,83 @@ func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>,
// -----
-func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
// expected-error @+1 {{'vector.gather' op base and result element type should match}}
- %0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
-func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
// expected-error @+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}}
- %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<2x16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
}
// -----
-func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>) {
+func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
// expected-error @+1 {{'vector.gather' op expected result dim to match indices dim}}
- %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<17xi32>, vector<16xi1>) -> vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
-func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>) {
+func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
// expected-error @+1 {{'vector.gather' op expected result dim to match mask dim}}
- %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<17xi1>) -> vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
}
// -----
-func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
// expected-error @+1 {{'vector.gather' op expected pass_thru of same type as result type}}
- %0 = vector.gather %base, %indices, %mask, %pass_thru : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64>) -> vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf32>
}
// -----
-func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
// expected-error @+1 {{'vector.scatter' op base and value element type should match}}
- vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf64>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
}
// -----
-func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) {
+func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<2x16xf32>) {
// expected-error @+1 {{'vector.scatter' op operand #3 must be of ranks 1, but got 'vector<2x16xf32>'}}
- vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<2x16xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
}
// -----
-func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
// expected-error @+1 {{'vector.scatter' op expected value dim to match indices dim}}
- vector.scatter %base, %indices, %mask, %value : vector<17xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32>
}
// -----
-func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<17xi1>, %value: vector<16xf32>) {
// expected-error @+1 {{'vector.scatter' op expected value dim to match mask dim}}
- vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 60890e58aef5..7284cab523a7 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -461,21 +461,19 @@ func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthr
}
// CHECK-LABEL: @gather_and_scatter
-func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
- // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
- %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
- // CHECK: %[[Y:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}}, %[[X]] : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
- %1 = vector.gather %base, %indices, %mask, %0 : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
- // CHECK: vector.scatter %{{.*}}, %{{.*}}, %{{.*}}, %[[Y]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
- vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.gather %base[%indices], %mask, %pass_thru : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.scatter %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+ vector.scatter %base[%indices], %mask, %0 : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
// CHECK-LABEL: @expand_and_compress
-func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = constant 0 : index
// CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- %0 = vector.expandload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index f9d7903a343b..5c55cc5b6f0e 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -86,12 +86,12 @@ func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-NEXT: return %[[G]] : vector<16xf32>
func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%mask = vector.constant_mask [16] : vector<16xi1>
- %ld = vector.gather %base, %indices, %mask, %pass_thru
- : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ %ld = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
@@ -102,8 +102,8 @@ func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
// CHECK-NEXT: return %[[A2]] : vector<16xf32>
func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%mask = vector.constant_mask [0] : vector<16xi1>
- %ld = vector.gather %base, %indices, %mask, %pass_thru
- : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+ %ld = vector.gather %base[%indices], %mask, %pass_thru
+ : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
@@ -112,12 +112,12 @@ func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+// CHECK-NEXT: vector.scatter %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-NEXT: return
func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%mask = vector.constant_mask [16] : vector<16xi1>
- vector.scatter %base, %indices, %mask, %value
- : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
@@ -129,8 +129,8 @@ func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
%mask = vector.constant_mask [0] : vector<16xi1>
- vector.scatter %base, %indices, %mask, %value
- : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+ vector.scatter %base[%indices], %mask, %value
+ : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
More information about the llvm-branch-commits
mailing list