[llvm-branch-commits] [mlir] 6728af1 - [mlir][vector] modified scatter/gather syntax, pass_thru mandatory

Aart Bik via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jan 9 11:46:37 PST 2021


Author: Aart Bik
Date: 2021-01-09T11:41:37-08:00
New Revision: 6728af16cf987df3cf051f3a1f9c92ed2b8fbc2d

URL: https://github.com/llvm/llvm-project/commit/6728af16cf987df3cf051f3a1f9c92ed2b8fbc2d
DIFF: https://github.com/llvm/llvm-project/commit/6728af16cf987df3cf051f3a1f9c92ed2b8fbc2d.diff

LOG: [mlir][vector] modified scatter/gather syntax, pass_thru mandatory

This change makes the scatter/gather syntax more consistent with
the syntax of all the other memory operations in the Vector dialect
(order of types, use of [] for index, etc.). This will make the MLIR
code easier to read. In addition, the pass_thru parameter of the
gather has been made mandatory (there is very little benefit in
using the implicit "undefined" values).

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D94352

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
    mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
    mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
    mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Dialect/Vector/vector-mem-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 0aa4950e0a9e..7f57dcd77def 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1419,7 +1419,7 @@ def Vector_GatherOp :
     Arguments<(ins AnyMemRef:$base,
                VectorOfRankAndType<[1], [AnyInteger]>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
-               Variadic<VectorOfRank<[1]>>:$pass_thru)>,
+               VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
 
   let summary = "gathers elements from memory into a vector as defined by an index vector and mask";
@@ -1428,10 +1428,8 @@ def Vector_GatherOp :
     The gather operation gathers elements from memory into a 1-D vector as
     defined by a base and a 1-D index vector, but only if the corresponding
     bit is set in a 1-D mask vector. Otherwise, the element is taken from a
-    1-D pass-through vector, if provided, or left undefined. Informally the
-    semantics are:
+    1-D pass-through vector. Informally the semantics are:
     ```
-    if (!defined(pass_thru)) pass_thru = [undef, .., undef]
     result[0] := mask[0] ? base[index[0]] : pass_thru[0]
     result[1] := mask[1] ? base[index[1]] : pass_thru[1]
     etc.
@@ -1447,8 +1445,8 @@ def Vector_GatherOp :
     Example:
 
     ```mlir
-    %g = vector.gather %base, %indices, %mask, %pass_thru
-        : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+    %g = vector.gather %base[%indices], %mask, %pass_thru
+       : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1462,15 +1460,14 @@ def Vector_GatherOp :
       return mask().getType().cast<VectorType>();
     }
     VectorType getPassThruVectorType() {
-      return (llvm::size(pass_thru()) == 0)
-        ? VectorType()
-        : (*pass_thru().begin()).getType().cast<VectorType>();
+      return pass_thru().getType().cast<VectorType>();
     }
     VectorType getResultVectorType() {
       return result().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
+    "type($base) `,` type($indices) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
 }
 
@@ -1507,8 +1504,8 @@ def Vector_ScatterOp :
     Example:
 
     ```mlir
-    vector.scatter %base, %indices, %mask, %value
-        : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+    vector.scatter %base[%indices], %mask, %value
+        : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
     ```
   }];
   let extraClassDeclaration = [{
@@ -1525,8 +1522,8 @@ def Vector_ScatterOp :
       return value().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` "
-    "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
+    "type($base) `,` type($indices) `,` type($mask) `,` type($value)";
   let hasCanonicalizer = 1;
 }
 

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
index 5ed8f3ee38f8..95df5aea06e4 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
@@ -3,18 +3,10 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-func @gather8(%base: memref<?xf32>,
-              %indices: vector<8xi32>, %mask: vector<8xi1>) -> vector<8xf32> {
-  %g = vector.gather %base, %indices, %mask
-    : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
-  return %g : vector<8xf32>
-}
-
-func @gather_pass_thru8(%base: memref<?xf32>,
-                        %indices: vector<8xi32>, %mask: vector<8xi1>,
-                        %pass_thru: vector<8xf32>) -> vector<8xf32> {
-  %g = vector.gather %base, %indices, %mask, %pass_thru
-    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> vector<8xf32>
+func @gather8(%base: memref<?xf32>, %indices: vector<8xi32>,
+              %mask: vector<8xi1>, %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %g = vector.gather %base[%indices], %mask, %pass_thru
+    : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
   return %g : vector<8xf32>
 }
 
@@ -63,31 +55,31 @@ func @entry() {
   // Gather tests.
   //
 
-  %g1 = call @gather8(%A, %idx, %all)
-    : (memref<?xf32>, vector<8xi32>, vector<8xi1>)
+  %g1 = call @gather8(%A, %idx, %all, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g1 : vector<8xf32>
   // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
 
-  %g2 = call @gather_pass_thru8(%A, %idx, %none, %pass)
+  %g2 = call @gather8(%A, %idx, %none, %pass)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g2 : vector<8xf32>
   // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 )
 
-  %g3 = call @gather_pass_thru8(%A, %idx, %some, %pass)
+  %g3 = call @gather8(%A, %idx, %some, %pass)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g3 : vector<8xf32>
   // CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 )
 
-  %g4 = call @gather_pass_thru8(%A, %idx, %more, %pass)
+  %g4 = call @gather8(%A, %idx, %more, %pass)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g4 : vector<8xf32>
   // CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 )
 
-  %g5 = call @gather_pass_thru8(%A, %idx, %all, %pass)
+  %g5 = call @gather8(%A, %idx, %all, %pass)
     : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g5 : vector<8xf32>

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
index 54171e744605..0666cc852c2a 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
@@ -6,8 +6,8 @@
 func @scatter8(%base: memref<?xf32>,
                %indices: vector<8xi32>,
                %mask: vector<8xi1>, %value: vector<8xf32>) {
-  vector.scatter %base, %indices, %mask, %value
-    : vector<8xi32>, vector<8xi1>, vector<8xf32> into memref<?xf32>
+  vector.scatter %base[%indices], %mask, %value
+    : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>
   return
 }
 

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
index 088e68e0e507..7940e8c68b1a 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir
@@ -60,11 +60,12 @@ func @spmv8x8(%AVAL: memref<8xvector<4xf32>>,
   %cn = constant 8 : index
   %f0 = constant 0.0 : f32
   %mask = vector.constant_mask [4] : vector<4xi1>
+  %pass = vector.broadcast %f0 : f32 to vector<4xf32>
   scf.for %i = %c0 to %cn step %c1 {
     %aval = load %AVAL[%i] : memref<8xvector<4xf32>>
     %aidx = load %AIDX[%i] : memref<8xvector<4xi32>>
-    %0 = vector.gather %X, %aidx, %mask
-       : (memref<?xf32>, vector<4xi32>, vector<4xi1>) -> vector<4xf32>
+    %0 = vector.gather %X[%aidx], %mask, %pass
+       : memref<?xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
     %1 = vector.contract #dot_trait %aval, %0, %f0 : vector<4xf32>, vector<4xf32> into f32
     store %1, %B[%i] : memref<?xf32>
   }

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
index 5b02b5df4b73..31f288e0f6c5 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir
@@ -50,12 +50,15 @@ func @spmv8x8(%AVAL: memref<4xvector<8xf32>>,
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %cn = constant 4 : index
+  %f0 = constant 0.0 : f32
   %mask = vector.constant_mask [8] : vector<8xi1>
+  %pass = vector.broadcast %f0 : f32 to vector<8xf32>
   %b = load %B[%c0] : memref<1xvector<8xf32>>
   %b_out = scf.for %k = %c0 to %cn step %c1 iter_args(%b_iter = %b) -> (vector<8xf32>) {
     %aval = load %AVAL[%k] : memref<4xvector<8xf32>>
     %aidx = load %AIDX[%k] : memref<4xvector<8xi32>>
-    %0 = vector.gather %X, %aidx, %mask : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
+    %0 = vector.gather %X[%aidx], %mask, %pass
+       : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
     %b_new = vector.fma %aval, %0, %b_iter : vector<8xf32>
     scf.yield %b_new : vector<8xf32>
   }

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 318ca1e27c88..91eab5027962 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2446,11 +2446,8 @@ static LogicalResult verify(GatherOp op) {
     return op.emitOpError("expected result dim to match indices dim");
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected result dim to match mask dim");
-  if (llvm::size(op.pass_thru()) != 0) {
-    VectorType passVType = op.getPassThruVectorType();
-    if (resVType != passVType)
-      return op.emitOpError("expected pass_thru of same type as result type");
-  }
+  if (resVType != op.getPassThruVectorType())
+    return op.emitOpError("expected pass_thru of same type as result type");
   return success();
 }
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 5c0c9651133d..ef4a85c1652c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1096,7 +1096,7 @@ func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<
 // CHECK: llvm.return
 
 func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
-  %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+  %0 = vector.gather %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
   return %0 : vector<3xf32>
 }
 
@@ -1106,7 +1106,7 @@ func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>,
 // CHECK: llvm.return %[[G]] : !llvm.vec<3 x f32>
 
 func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
-  vector.scatter %arg0, %arg1, %arg2, %arg3 : vector<3xi32>, vector<3xi1>, vector<3xf32> into memref<?xf32>
+  vector.scatter %arg0[%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
   return
 }
 

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 8cadafae1ec4..11100c4e615e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1238,65 +1238,83 @@ func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>,
 
 // -----
 
-func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   // expected-error@+1 {{'vector.gather' op base and result element type should match}}
-  %0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+  %0 = vector.gather %base[%indices], %mask, %pass_thru
+    : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
 // -----
 
-func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+                           %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   // expected-error@+1 {{'vector.gather' op result #0 must be  of ranks 1, but got 'vector<2x16xf32>'}}
-  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<2x16xf32>
+  %0 = vector.gather %base[%indices], %mask, %pass_thru
+    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32>
 }
 
 // -----
 
-func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>) {
+func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
+                                  %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}}
-  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<17xi32>, vector<16xi1>) -> vector<16xf32>
+  %0 = vector.gather %base[%indices], %mask, %pass_thru
+    : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
 // -----
 
-func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>) {
+func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+                               %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
   // expected-error@+1 {{'vector.gather' op expected result dim to match mask dim}}
-  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<17xi1>) -> vector<16xf32>
+  %0 = vector.gather %base[%indices], %mask, %pass_thru
+    : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
 }
 
 // -----
 
-func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+                                     %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
   // expected-error@+1 {{'vector.gather' op expected pass_thru of same type as result type}}
-  %0 = vector.gather %base, %indices, %mask, %pass_thru : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64>) -> vector<16xf32>
+  %0 = vector.gather %base[%indices], %mask, %pass_thru
+    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf32>
 }
 
 // -----
 
-func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
+                                 %mask: vector<16xi1>, %value: vector<16xf32>) {
   // expected-error@+1 {{'vector.scatter' op base and value element type should match}}
-  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf64>
+  vector.scatter %base[%indices], %mask, %value
+    : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 }
 
 // -----
 
-func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) {
+func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+                            %mask: vector<16xi1>, %value: vector<2x16xf32>) {
   // expected-error@+1 {{'vector.scatter' op operand #3 must be  of ranks 1, but got 'vector<2x16xf32>'}}
-  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<2x16xf32> into memref<?xf32>
+  vector.scatter %base[%indices], %mask, %value
+    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
 }
 
 // -----
 
-func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
+                                   %mask: vector<16xi1>, %value: vector<16xf32>) {
   // expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}}
-  vector.scatter %base, %indices, %mask, %value : vector<17xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+  vector.scatter %base[%indices], %mask, %value
+    : memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32>
 }
 
 // -----
 
-func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
+                                %mask: vector<17xi1>, %value: vector<16xf32>) {
   // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
-  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
+  vector.scatter %base[%indices], %mask, %value
+    : memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 60890e58aef5..7284cab523a7 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -461,21 +461,19 @@ func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthr
 }
 
 // CHECK-LABEL: @gather_and_scatter
-func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
-  // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
-  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
-  // CHECK: %[[Y:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}}, %[[X]] : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
-  %1 = vector.gather %base, %indices, %mask, %0 : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
-  // CHECK: vector.scatter %{{.*}}, %{{.*}}, %{{.*}}, %[[Y]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
-  vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.gather %base[%indices], %mask, %pass_thru : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.scatter %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  vector.scatter %base[%indices], %mask, %0 : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
 }
 
 // CHECK-LABEL: @expand_and_compress
-func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = constant 0 : index
   // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  %0 = vector.expandload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   // CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
   vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return

diff  --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index f9d7903a343b..5c55cc5b6f0e 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -86,12 +86,12 @@ func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>)  {
 // CHECK-SAME:                  %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
 // CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 // CHECK-NEXT:      return %[[G]] : vector<16xf32>
 func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %mask = vector.constant_mask [16] : vector<16xi1>
-  %ld = vector.gather %base, %indices, %mask, %pass_thru
-    : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  %ld = vector.gather %base[%indices], %mask, %pass_thru
+    : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
 
@@ -102,8 +102,8 @@ func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
 // CHECK-NEXT:      return %[[A2]] : vector<16xf32>
 func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %mask = vector.constant_mask [0] : vector<16xi1>
-  %ld = vector.gather %base, %indices, %mask, %pass_thru
-    : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  %ld = vector.gather %base[%indices], %mask, %pass_thru
+    : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
 
@@ -112,12 +112,12 @@ func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
 // CHECK-SAME:                   %[[A1:.*]]: vector<16xi32>,
 // CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
 // CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT:      vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+// CHECK-NEXT:      vector.scatter %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 // CHECK-NEXT:      return
 func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
   %mask = vector.constant_mask [16] : vector<16xi1>
-  vector.scatter %base, %indices, %mask, %value
-    : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+  vector.scatter %base[%indices], %mask, %value
+    : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
 }
 
@@ -129,8 +129,8 @@ func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
 func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
   %0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
   %mask = vector.constant_mask [0] : vector<16xi1>
-  vector.scatter %base, %indices, %mask, %value
-    : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+  vector.scatter %base[%indices], %mask, %value
+    : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
 }
 


        


More information about the llvm-branch-commits mailing list