[Mlir-commits] [mlir] 19dbb23 - [mlir] [VectorOps] Add scatter/gather operations to Vector dialect

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 21 10:58:07 PDT 2020


Author: aartbik
Date: 2020-07-21T10:57:40-07:00
New Revision: 19dbb230a245d3404a485d8684587c3d37c198d3

URL: https://github.com/llvm/llvm-project/commit/19dbb230a245d3404a485d8684587c3d37c198d3
DIFF: https://github.com/llvm/llvm-project/commit/19dbb230a245d3404a485d8684587c3d37c198d3.diff

LOG: [mlir] [VectorOps] Add scatter/gather operations to Vector dialect

Introduces the scatter/gather operations to the Vector dialect
(important memory operations for sparse computations), together
with a first reference implementation that lowers to the LLVM IR
dialect to enable running on CPU (and other targets that support
the corresponding LLVM IR intrinsics).

The operations can be used directly where applicable, or during
progressive lowering to bring other memory operations closer to
hardware ISA support for gather/scatter. The semantics of the operations
closely match those of the corresponding LLVM intrinsics.

Note that the operations allow for a dynamic index vector (which is
important for sparse computations). However, this first reference
lowering implementation "serializes" the address computation when
base + index_vector is converted to a vector of pointers. Exploring
how to use SIMD properly during this step is TBD. Support for more
general memrefs and idiomatic versions of striding is also TBD.

Reviewed By: arpith-jacob

Differential Revision: https://reviews.llvm.org/D84039

Added: 
    mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
    mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir
    mlir/test/Target/llvmir-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index ce0b3de82d2c..f421d2e46463 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -991,6 +991,35 @@ def LLVM_MaskedStoreOp
   let assemblyFormat = "$value `,` $data `,` $mask attr-dict `:` "
     "type($value) `,` type($mask) `into` type($data)";
 }
+
+/// Create a call to Masked Gather intrinsic.
+def LLVM_masked_gather
+    : LLVM_OneResultOp<"intr.masked.gather">,
+      Arguments<(ins LLVM_Type:$ptrs, LLVM_Type:$mask,
+                 Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment)> {
+  string llvmBuilder = [{
+    $res = $pass_thru.empty() ? builder.CreateMaskedGather(
+      $ptrs, llvm::Align($alignment.getZExtValue()), $mask) :
+      builder.CreateMaskedGather(
+        $ptrs, llvm::Align($alignment.getZExtValue()), $mask, $pass_thru[0]);
+  }];
+  let assemblyFormat =
+    "operands attr-dict `:` functional-type(operands, results)";
+}
+
+/// Create a call to Masked Scatter intrinsic.
+def LLVM_masked_scatter
+    : LLVM_ZeroResultOp<"intr.masked.scatter">,
+      Arguments<(ins LLVM_Type:$value, LLVM_Type:$ptrs, LLVM_Type:$mask,
+                 I32Attr:$alignment)> {
+  string llvmBuilder = [{
+    builder.CreateMaskedScatter(
+      $value, $ptrs, llvm::Align($alignment.getZExtValue()), $mask);
+  }];
+  let assemblyFormat = "$value `,` $ptrs `,` $mask attr-dict `:` "
+    "type($value) `,` type($mask) `into` type($ptrs)";
+}
+
 //
 // Atomic operations.
 //

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 10a4498b0bbd..fd3d190990d4 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1150,6 +1150,121 @@ def Vector_TransferWriteOp :
   let hasFolder = 1;
 }
 
+def Vector_GatherOp :
+  Vector_Op<"gather">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               Variadic<VectorOfRank<[1]>>:$pass_thru)>,
+    Results<(outs VectorOfRank<[1]>:$result)> {
+
+  let summary = "gathers elements from memory into a vector as defined by an index vector";
+
+  let description = [{
+    The gather operation gathers elements from memory into a 1-D vector as
+    defined by a base and a 1-D index vector, but only if the corresponding
+    bit is set in a 1-D mask vector. Otherwise, the element is taken from a
+    1-D pass-through vector, if provided, or left undefined. Informally the
+    semantics are:
+    ```
+    if (!defined(pass_thru)) pass_thru = [undef, .., undef]
+    result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0]
+    result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1]
+    etc.
+    ```
+    The vector dialect leaves out-of-bounds behavior undefined.
+
+    The gather operation can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a gather. The semantics of the operation closely
+    correspond to those of the `llvm.masked.gather`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
+
+    Example:
+
+    ```mlir
+    %g = vector.gather %base, %indices, %mask, %pass_thru
+        : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+    ```
+
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getIndicesVectorType() {
+      return indices().getType().cast<VectorType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getPassThruVectorType() {
+      return (llvm::size(pass_thru()) == 0)
+        ? VectorType()
+        : (*pass_thru().begin()).getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return result().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
+}
+
+def Vector_ScatterOp :
+  Vector_Op<"scatter">,
+    Arguments<(ins AnyMemRef:$base,
+               VectorOfRankAndType<[1], [AnyInteger]>:$indices,
+               VectorOfRankAndType<[1], [I1]>:$mask,
+               VectorOfRank<[1]>:$value)> {
+
+  let summary = "scatters elements from a vector into memory as defined by an index vector";
+
+  let description = [{
+    The scatter operation scatters elements from a 1-D vector into memory as
+    defined by a base and a 1-D index vector, but only if the corresponding
+    bit in a 1-D mask vector is set. Otherwise, no action is taken for that
+    element. Informally the semantics are:
+    ```
+    if (mask[0]) MEM[base + index[0]] = value[0]
+    if (mask[1]) MEM[base + index[1]] = value[1]
+    etc.
+    ```
+    The vector dialect leaves out-of-bounds and repeated index behavior
+    undefined. Underlying implementations may enforce strict sequential
+    semantics for the latter, though.
+    TODO: enforce the latter always?
+
+    The scatter operation can be used directly where applicable, or can be used
+    during progressive lowering to bring other memory operations closer to
+    hardware ISA support for a scatter. The semantics of the operation closely
+    correspond to those of the `llvm.masked.scatter`
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
+
+    Example:
+
+    ```mlir
+    vector.scatter %base, %indices, %mask, %value
+        : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return base().getType().cast<MemRefType>();
+    }
+    VectorType getIndicesVectorType() {
+      return indices().getType().cast<VectorType>();
+    }
+    VectorType getMaskVectorType() {
+      return mask().getType().cast<VectorType>();
+    }
+    VectorType getValueVectorType() {
+      return value().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = "$base `,` $indices `,` $mask `,` $value attr-dict `:` "
+    "type($indices) `,` type($mask) `,` type($value) `into` type($base)";
+}
+
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [NoSideEffect]>,
     Arguments<(ins AnyTypeOf<[AnyVector, TupleOf<[AnyVector]>]>:$source)>,

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
new file mode 100644
index 000000000000..5ed8f3ee38f8
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-gather.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @gather8(%base: memref<?xf32>,
+              %indices: vector<8xi32>, %mask: vector<8xi1>) -> vector<8xf32> {
+  %g = vector.gather %base, %indices, %mask
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>) -> vector<8xf32>
+  return %g : vector<8xf32>
+}
+
+func @gather_pass_thru8(%base: memref<?xf32>,
+                        %indices: vector<8xi32>, %mask: vector<8xi1>,
+                        %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %g = vector.gather %base, %indices, %mask, %pass_thru
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> vector<8xf32>
+  return %g : vector<8xf32>
+}
+
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c10 = constant 10: index
+  %A = alloc(%c10) : memref<?xf32>
+  scf.for %i = %c0 to %c10 step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    store %fi, %A[%i] : memref<?xf32>
+  }
+
+  // Set up idx vector.
+  %i0 = constant 0: i32
+  %i1 = constant 1: i32
+  %i2 = constant 2: i32
+  %i3 = constant 3: i32
+  %i4 = constant 4: i32
+  %i5 = constant 5: i32
+  %i6 = constant 6: i32
+  %i9 = constant 9: i32
+  %0 = vector.broadcast %i0 : i32 to vector<8xi32>
+  %1 = vector.insert %i6, %0[1] : i32 into vector<8xi32>
+  %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32>
+  %3 = vector.insert %i3, %2[3] : i32 into vector<8xi32>
+  %4 = vector.insert %i5, %3[4] : i32 into vector<8xi32>
+  %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32>
+  %6 = vector.insert %i9, %5[6] : i32 into vector<8xi32>
+  %idx = vector.insert %i2, %6[7] : i32 into vector<8xi32>
+
+  // Set up pass thru vector.
+  %u = constant -7.0: f32
+  %pass = vector.broadcast %u : f32 to vector<8xf32>
+
+  // Set up masks.
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<8xi1>
+  %all = vector.constant_mask [8] : vector<8xi1>
+  %some = vector.constant_mask [4] : vector<8xi1>
+  %more = vector.insert %t, %some[7] : i1 into vector<8xi1>
+
+  //
+  // Gather tests.
+  //
+
+  %g1 = call @gather8(%A, %idx, %all)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>)
+    -> (vector<8xf32>)
+  vector.print %g1 : vector<8xf32>
+  // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
+
+  %g2 = call @gather_pass_thru8(%A, %idx, %none, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    -> (vector<8xf32>)
+  vector.print %g2 : vector<8xf32>
+  // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 )
+
+  %g3 = call @gather_pass_thru8(%A, %idx, %some, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    -> (vector<8xf32>)
+  vector.print %g3 : vector<8xf32>
+  // CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 )
+
+  %g4 = call @gather_pass_thru8(%A, %idx, %more, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    -> (vector<8xf32>)
+  vector.print %g4 : vector<8xf32>
+  // CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 )
+
+  %g5 = call @gather_pass_thru8(%A, %idx, %all, %pass)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    -> (vector<8xf32>)
+  vector.print %g5 : vector<8xf32>
+  // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
+
+  return
+}

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
new file mode 100644
index 000000000000..6dd0cf169552
--- /dev/null
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-scatter.mlir
@@ -0,0 +1,135 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @scatter8(%base: memref<?xf32>,
+               %indices: vector<8xi32>,
+               %mask: vector<8xi1>, %value: vector<8xf32>) {
+  vector.scatter %base, %indices, %mask, %value
+    : vector<8xi32>, vector<8xi1>, vector<8xf32> into memref<?xf32>
+  return
+}
+
+func @printmem(%A: memref<?xf32>) {
+  %f = constant 0.0: f32
+  %0 = vector.broadcast %f : f32 to vector<8xf32>
+  %1 = constant 0: index
+  %2 = load %A[%1] : memref<?xf32>
+  %3 = vector.insert %2, %0[0] : f32 into vector<8xf32>
+  %4 = constant 1: index
+  %5 = load %A[%4] : memref<?xf32>
+  %6 = vector.insert %5, %3[1] : f32 into vector<8xf32>
+  %7 = constant 2: index
+  %8 = load %A[%7] : memref<?xf32>
+  %9 = vector.insert %8, %6[2] : f32 into vector<8xf32>
+  %10 = constant 3: index
+  %11 = load %A[%10] : memref<?xf32>
+  %12 = vector.insert %11, %9[3] : f32 into vector<8xf32>
+  %13 = constant 4: index
+  %14 = load %A[%13] : memref<?xf32>
+  %15 = vector.insert %14, %12[4] : f32 into vector<8xf32>
+  %16 = constant 5: index
+  %17 = load %A[%16] : memref<?xf32>
+  %18 = vector.insert %17, %15[5] : f32 into vector<8xf32>
+  %19 = constant 6: index
+  %20 = load %A[%19] : memref<?xf32>
+  %21 = vector.insert %20, %18[6] : f32 into vector<8xf32>
+  %22 = constant 7: index
+  %23 = load %A[%22] : memref<?xf32>
+  %24 = vector.insert %23, %21[7] : f32 into vector<8xf32>
+  vector.print %24 : vector<8xf32>
+  return
+}
+
+func @entry() {
+  // Set up memory.
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c8 = constant 8: index
+  %A = alloc(%c8) : memref<?xf32>
+  scf.for %i = %c0 to %c8 step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    store %fi, %A[%i] : memref<?xf32>
+  }
+
+  // Set up idx vector.
+  %i0 = constant 0: i32
+  %i1 = constant 1: i32
+  %i2 = constant 2: i32
+  %i3 = constant 3: i32
+  %i4 = constant 4: i32
+  %i5 = constant 5: i32
+  %i6 = constant 6: i32
+  %i7 = constant 7: i32
+  %0 = vector.broadcast %i7 : i32 to vector<8xi32>
+  %1 = vector.insert %i0, %0[1] : i32 into vector<8xi32>
+  %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32>
+  %3 = vector.insert %i6, %2[3] : i32 into vector<8xi32>
+  %4 = vector.insert %i2, %3[4] : i32 into vector<8xi32>
+  %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32>
+  %6 = vector.insert %i5, %5[6] : i32 into vector<8xi32>
+  %idx = vector.insert %i3, %6[7] : i32 into vector<8xi32>
+
+  // Set up value vector.
+  %f0 = constant 0.0: f32
+  %f1 = constant 1.0: f32
+  %f2 = constant 2.0: f32
+  %f3 = constant 3.0: f32
+  %f4 = constant 4.0: f32
+  %f5 = constant 5.0: f32
+  %f6 = constant 6.0: f32
+  %f7 = constant 7.0: f32
+  %7 = vector.broadcast %f0 : f32 to vector<8xf32>
+  %8 = vector.insert %f1, %7[1] : f32 into vector<8xf32>
+  %9 = vector.insert %f2, %8[2] : f32 into vector<8xf32>
+  %10 = vector.insert %f3, %9[3] : f32 into vector<8xf32>
+  %11 = vector.insert %f4, %10[4] : f32 into vector<8xf32>
+  %12 = vector.insert %f5, %11[5] : f32 into vector<8xf32>
+  %13 = vector.insert %f6, %12[6] : f32 into vector<8xf32>
+  %val = vector.insert %f7, %13[7] : f32 into vector<8xf32>
+
+  // Set up masks.
+  %t = constant 1: i1
+  %none = vector.constant_mask [0] : vector<8xi1>
+  %some = vector.constant_mask [4] : vector<8xi1>
+  %more = vector.insert %t, %some[7] : i1 into vector<8xi1>
+  %all = vector.constant_mask [8] : vector<8xi1>
+
+  //
+  // Scatter tests.
+  //
+
+  vector.print %idx : vector<8xi32>
+  // CHECK: ( 7, 0, 1, 6, 2, 4, 5, 3 )
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
+
+  call @scatter8(%A, %idx, %none, %val)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
+
+  call @scatter8(%A, %idx, %some, %val)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 1, 2, 2, 3, 4, 5, 3, 0 )
+
+  call @scatter8(%A, %idx, %more, %val)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 1, 2, 2, 7, 4, 5, 3, 0 )
+
+  call @scatter8(%A, %idx, %all, %val)
+    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>) -> ()
+
+  call @printmem(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 1, 2, 4, 7, 5, 6, 3, 0 )
+
+  return
+}

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a59f02681c54..a877bd12c2e1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -34,14 +34,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-template <typename T>
-static LLVM::LLVMType getPtrToElementType(T containerType,
-                                          LLVMTypeConverter &typeConverter) {
-  return typeConverter.convertType(containerType.getElementType())
-      .template cast<LLVM::LLVMType>()
-      .getPointerTo();
-}
-
 // Helper to reduce vector type by one rank at front.
 static VectorType reducedVectorTypeFront(VectorType tp) {
   assert((tp.getRank() > 1) && "unlowerable vector type");
@@ -124,11 +116,12 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
   return res;
 }
 
-template <typename TransferOp>
-LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter,
-                                         TransferOp xferOp, unsigned &align) {
+// Helper that returns data layout alignment of an operation with memref.
+template <typename T>
+LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
+                                 unsigned &align) {
   Type elementTy =
-      typeConverter.convertType(xferOp.getMemRefType().getElementType());
+      typeConverter.convertType(op.getMemRefType().getElementType());
   if (!elementTy)
     return failure();
 
@@ -138,13 +131,54 @@ LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
+// Helper that returns vector of pointers given a base and an index vector.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
+                             LLVMTypeConverter &typeConverter, Location loc,
+                             Value memref, Value indices, MemRefType memRefType,
+                             VectorType vType, Type iType, Value &ptrs) {
+  // Inspect stride and offset structure.
+  //
+  // TODO: flat memory only for now, generalize
+  //
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
+  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
+      offset != 0 || memRefType.getMemorySpace() != 0)
+    return failure();
+
+  // Base pointer.
+  MemRefDescriptor memRefDescriptor(memref);
+  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+
+  // Create a vector of pointers from base and indices.
+  //
+  // TODO: this step serializes the address computations unfortunately,
+  //       ideally we would like to add splat(base) + index_vector
+  //       in SIMD form, but this does not match well with current
+  //       constraints of the standard and vector dialects.
+  //
+  int64_t size = vType.getDimSize(0);
+  auto pType = memRefDescriptor.getElementType();
+  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+  auto idxType = typeConverter.convertType(iType);
+  ptrs = rewriter.create<LLVM::UndefOp>(loc, ptrsType);
+  for (int64_t i = 0; i < size; i++) {
+    Value off =
+        extractOne(rewriter, typeConverter, loc, indices, idxType, 1, i);
+    Value ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base, off);
+    ptrs = insertOne(rewriter, typeConverter, loc, ptrs, ptr, ptrsType, 1, i);
+  }
+  return success();
+}
+
 static LogicalResult
 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  LLVMTypeConverter &typeConverter, Location loc,
                                  TransferReadOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
     return failure();
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
@@ -165,7 +199,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
     return failure();
 
   unsigned align;
-  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
     return failure();
 
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
@@ -180,7 +214,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferWriteOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands);
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@@ -194,7 +228,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                             TransferWriteOp xferOp, ArrayRef<Value> operands,
                             Value dataPtr, Value mask) {
   unsigned align;
-  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
     return failure();
 
   auto adaptor = TransferWriteOpAdaptor(operands);
@@ -259,6 +293,83 @@ class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
   }
 };
 
+/// Conversion pattern for a vector.gather.
+class VectorGatherOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorGatherOpConversion(MLIRContext *context,
+                                    LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto gather = cast<vector::GatherOp>(op);
+    auto adaptor = vector::GatherOpAdaptor(operands);
+
+    // Resolve alignment.
+    unsigned align;
+    if (failed(getMemRefAlignment(typeConverter, gather, align)))
+      return failure();
+
+    // Get index ptrs.
+    VectorType vType = gather.getResultVectorType();
+    Type iType = gather.getIndicesVectorType().getElementType();
+    Value ptrs;
+    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
+                              adaptor.indices(), gather.getMemRefType(), vType,
+                              iType, ptrs)))
+      return failure();
+
+    // Replace with the gather intrinsic.
+    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
+                                                          : adaptor.pass_thru();
+    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
+        rewriter.getI32IntegerAttr(align));
+    return success();
+  }
+};
+
+/// Conversion pattern for a vector.scatter.
+class VectorScatterOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorScatterOpConversion(MLIRContext *context,
+                                     LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto scatter = cast<vector::ScatterOp>(op);
+    auto adaptor = vector::ScatterOpAdaptor(operands);
+
+    // Resolve alignment.
+    unsigned align;
+    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
+      return failure();
+
+    // Get index ptrs.
+    VectorType vType = scatter.getValueVectorType();
+    Type iType = scatter.getIndicesVectorType().getElementType();
+    Value ptrs;
+    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
+                              adaptor.indices(), scatter.getMemRefType(), vType,
+                              iType, ptrs)))
+      return failure();
+
+    // Replace with the scatter intrinsic.
+    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
+        scatter, adaptor.value(), ptrs, adaptor.mask(),
+        rewriter.getI32IntegerAttr(align));
+    return success();
+  }
+};
+
+/// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
@@ -1173,7 +1284,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
               VectorPrintOpConversion,
               VectorTransferConversion<TransferReadOp>,
               VectorTransferConversion<TransferWriteOp>,
-              VectorTypeCastOpConversion>(ctx, converter);
+              VectorTypeCastOpConversion,
+              VectorGatherOpConversion,
+              VectorScatterOpConversion>(ctx, converter);
   // clang-format on
 }
 

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 03c4079ef171..d16c7c3d6fdb 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1858,6 +1858,49 @@ Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
   return SmallVector<int64_t, 4>{s.begin(), s.end()};
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(GatherOp op) {
+  VectorType indicesVType = op.getIndicesVectorType();
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType resVType = op.getResultVectorType();
+
+  if (resVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and result element type should match");
+
+  if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
+    return op.emitOpError("expected result dim to match indices dim");
+  if (resVType.getDimSize(0) != maskVType.getDimSize(0))
+    return op.emitOpError("expected result dim to match mask dim");
+  if (llvm::size(op.pass_thru()) != 0) {
+    VectorType passVType = op.getPassThruVectorType();
+    if (resVType != passVType)
+      return op.emitOpError("expected pass_thru of same type as result type");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ScatterOp op) {
+  VectorType indicesVType = op.getIndicesVectorType();
+  VectorType maskVType = op.getMaskVectorType();
+  VectorType valueVType = op.getValueVectorType();
+
+  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+    return op.emitOpError("base and value element type should match");
+
+  if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
+    return op.emitOpError("expected value dim to match indices dim");
+  if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+    return op.emitOpError("expected value dim to match mask dim");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 874cb5cca141..69d3aeca3d95 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -969,3 +969,21 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 // CHECK-SAME:      {columns = 4 : i32, rows = 4 : i32} :
 // CHECK-SAME:      !llvm<"<16 x float>"> into !llvm<"<16 x float>">
 // CHECK:       llvm.return %[[T]] : !llvm<"<16 x float>">
+
+func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  %0 = vector.gather %arg0, %arg1, %arg2, %arg3 : (memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+  return %0 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_op
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm<"<3 x float*>">, !llvm<"<3 x i1>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+// CHECK: llvm.return %[[G]] : !llvm<"<3 x float>">
+
+func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+  vector.scatter %arg0, %arg1, %arg2, %arg3 : vector<3xi32>, vector<3xi1>, vector<3xf32> into memref<?xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_op
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : !llvm<"<3 x float>">, !llvm<"<3 x i1>"> into !llvm<"<3 x float*>">
+// CHECK: llvm.return

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 916403800fe1..ea354f51645a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1177,3 +1177,66 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
   // expected-error@+1 {{expects operand to be a memref with no layout}}
   %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>
 }
+
+// -----
+
+func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+  // expected-error@+1 {{'vector.gather' op base and result element type should match}}
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf64>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+  // expected-error@+1 {{'vector.gather' op result #0 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<2x16xf32>
+}
+
+// -----
+
+func @gather_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>) {
+  // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}}
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<17xi32>, vector<16xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>) {
+  // expected-error@+1 {{'vector.gather' op expected result dim to match mask dim}}
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<17xi1>) -> vector<16xf32>
+}
+
+// -----
+
+func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
+  // expected-error@+1 {{'vector.gather' op expected pass_thru of same type as result type}}
+  %0 = vector.gather %base, %indices, %mask, %pass_thru : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf64>) -> vector<16xf32>
+}
+
+// -----
+
+func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.scatter' op base and value element type should match}}
+  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf64>
+}
+
+// -----
+
+func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) {
+  // expected-error@+1 {{'vector.scatter' op operand #3 must be  of ranks 1, but got 'vector<2x16xf32>'}}
+  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<16xi1>, vector<2x16xf32> into memref<?xf32>
+}
+
+// -----
+
+func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}}
+  vector.scatter %base, %indices, %mask, %value : vector<17xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+}
+
+// -----
+
+func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+  // expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
+  vector.scatter %base, %indices, %mask, %value : vector<16xi32>, vector<17xi1>, vector<16xf32> into memref<?xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4ea72864ea94..0bf4ed8f84c7 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -368,3 +368,14 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
   // CHECK: return %[[X]] : vector<16xi32>
   return %0 : vector<16xi32>
 }
+
+// CHECK-LABEL: @gather_and_scatter
+func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>) {
+  // CHECK: %[[X:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}} : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+  %0 = vector.gather %base, %indices, %mask : (memref<?xf32>, vector<16xi32>, vector<16xi1>) -> vector<16xf32>
+  // CHECK: %[[Y:.*]] = vector.gather %{{.*}}, %{{.*}}, %{{.*}}, %[[X]] : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  %1 = vector.gather %base, %indices, %mask, %0 : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+  // CHECK: vector.scatter %{{.*}}, %{{.*}}, %{{.*}}, %[[Y]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+  vector.scatter %base, %indices, %mask, %1 : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
+  return
+}

diff  --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
index 1595edffb45b..79b7edb3c701 100644
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -206,6 +206,20 @@ llvm.func @masked_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>">
   llvm.return
 }
 
+// CHECK-LABEL: @masked_gather_scatter_intrinsics
+llvm.func @masked_gather_scatter_intrinsics(%M: !llvm<"<7 x float*>">, %mask: !llvm<"<7 x i1>">) {
+  // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> undef)
+  %a = llvm.intr.masked.gather %M, %mask { alignment = 1: i32} :
+      (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">) -> !llvm<"<7 x float>">
+  // CHECK: call <7 x float> @llvm.masked.gather.v7f32.v7p0f32(<7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+  %b = llvm.intr.masked.gather %M, %mask, %a { alignment = 1: i32} :
+      (!llvm<"<7 x float*>">, !llvm<"<7 x i1>">, !llvm<"<7 x float>">) -> !llvm<"<7 x float>">
+  // CHECK: call void @llvm.masked.scatter.v7f32.v7p0f32(<7 x float> %{{.*}}, <7 x float*> %{{.*}}, i32 1, <7 x i1> %{{.*}})
+  llvm.intr.masked.scatter %b, %M, %mask { alignment = 1: i32} :
+      !llvm<"<7 x float>">, !llvm<"<7 x i1>"> into !llvm<"<7 x float*>">
+  llvm.return
+}
+
 // CHECK-LABEL: @memcpy_test
 llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">) {
   // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}})


        


More information about the Mlir-commits mailing list