[Mlir-commits] [mlir] d2613d5 - [mlir][tensor] Add gather/scatter op definitions to the tensor dialect.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Sep 5 02:02:30 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-05T02:02:22-07:00
New Revision: d2613d5bb5dca0624833e4747f67db6fe3236ce8

URL: https://github.com/llvm/llvm-project/commit/d2613d5bb5dca0624833e4747f67db6fe3236ce8
DIFF: https://github.com/llvm/llvm-project/commit/d2613d5bb5dca0624833e4747f67db6fe3236ce8.diff

LOG: [mlir][tensor] Add gather/scatter op definitions to the tensor dialect.

Gather/Scatter are examined from first principles in light of our recent progress on tensor-based codegen
and in-place bufferization.

In the future, lowering of these abstractions to operate **inplace** on buffers
will likely require a more powerful buffer representation than strided memref.

General context: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707
Relevant TL;DR parts of the proposal:
- gather: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-gatherop-and-friends-10
- need for more expressive types: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-bufferization-copy-view-and-the-need-for-more-expressive-types-12
- jagged buffer discussion: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-first-class-jagged-buffer-13

Differential Revision: https://reviews.llvm.org/D130348

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9a0bbf690cf02..51ab419b58025 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -232,7 +232,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
 
     Example:
 
-    ```
+    ```mlir
     // Rank-reducing extract_slice.
     %1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
       tensor<8x16x4xf32> to tensor<16x4xf32>
@@ -372,8 +372,8 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
                    "$_self.cast<ShapedType>().getNumElements(), "
                    "$_self.cast<ShapedType>().getElementType())">
   ]> {
-  string summary = "tensor from elements operation.";
-  string description = [{
+  let summary = "tensor from elements operation.";
+  let description = [{
     Create a N-D tensor from a range of same-type arguments. The number of
     provided `elements` should equal to the number of the elements in the
     result type. The `elements` correspond to a flattened tensor.
@@ -406,6 +406,144 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_GatherOp : Tensor_Op<"gather", [
+    NoSideEffect
+  ]> {
+  let summary = "gather a subset of a tensor at specified indices";
+  let description = [{
+    The `gather` operation extracts a subset of the elements from a `source`
+    tensor at the given indices.
+
+    In its most general form, the tensor of indices specifies all the coordinates
+    of every element to extract (i.e. COO format, without the payload).
+    The indices are expected to be confined to coordinate values that fit the
+    range of the `source` tensor, otherwise the behavior is undefined.
+
+    The leading dimensions of the index tensor give the result tensor its leading
+    dimensions. The trailing dimensions of the result tensor are obtained from
+    the source tensor by omitting the dimensions specified in `gather_dims`
+    (rank-reducing semantics) or setting them to `1` (rank-preserving semantics)
+    (see examples).
+    The trailing dimension of the index tensor contains the coordinates and is
+    expected to have its size equal to the number of dimensions being gathered.
+    This convention allows an idiomatic specification and lowering of "gathering
+    multiple N-D slices from the source tensor".
+
+    Note: in the examples below, we separate out the indexing part of the tensor
+    type by a whitespace for readability purposes.
+
+    Example:
+
+    ```mlir
+        // For each 1x2 triple of coordinates in %indices, extract the
+        // element (i.e. 0-D subset) at the coordinates triple in %source.
+        //
+        %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]) :
+          (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
+
+        // Note: result type may be further rank-reduced to tensor<1x2x f32>.
+    ```
+
+    A slice variant is provided to allow specifying whole slices of the source
+    tensor.
+
+    Example:
+
+    ```mlir
+        // For each 6x7 singleton of coordinates in %indices, extract the 2-D
+        // slice %source[*, %indices[...]:%indices[...] + 1, *] with the indices
+        // corresponding to the `gather_dims` attribute specified by %indices.
+        //
+        %out = tensor.gather %source[%indices] gather_dims([1]) :
+          (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
+
+        // Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>.
+    ```
+
+    The dimensions specified in the gather_dims attribute are ones for which the
+    result tensor has size `1`.
+    I.e. if the source type is `axbxcxd` and `gather_dims` is [1, 3], then
+    the shape suffix is `ax1xcx1`.
+    Gather also allows rank-reducing semantics where the shape `ax1xcx1` can be
+    further simplified to `axc`.
+
+    The element type of the indices tensor can be any signless integer type or
+    `index`. In the absence of target-specific or problem-specific information,
+    the default type one should use is `index`.
+
+    This operation does not support unranked tensors.
+
+    An optional `unique` unit attribute may be specified to indicate that the
+    coordinates in `indices` are statically guaranteed to be unique at runtime.
+    Incorrectly setting the `unique` attribute when the coordinates are not truly
+    unique is undefined behavior.
+
+    Only full slices are meant to be supported by this op; if one desires
+    partial slices (e.g. strided windows) one should compose this op with other
+    tensor ops (e.g. tensor.extract_slice). This is to avoid a slippery slope of
+    complexity that would make the op unusable in practice.
+
+    At the tensor level, the index tensor is specified in an AoS form (i.e. the
+    coordinate tuple is the most minor). It is the responsibility of further
+    lowerings and bufferization to implement various concrete layouts.
+
+    Note: As currently specified, the operation must lower to an abstraction that
+    performs copies to the output tensor. This is because the buffer type system
+    is currently not rich enough to allow multiple non-contiguous views in the
+    same type. This is visible more clearly in a notional buffer version of the
+    op:
+
+    ```mlir
+        // memref<?x4x1xf32> is a contiguous buffer of ?x4x1 elements.
+        // Gathering from random source slices must copy to the contiguous output.
+        %out = memref.gather %source[%indices] gather_dims([1]) :
+          (memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x 4x1xf32>
+
+        // Nested buffer support would allow gather to directly index into the
+        // source buffer (i.e. represent a jagged view into the source).
+        %out = memref.gather %source[%indices] gather_dims([1]) :
+          (memref<4x4xf32>, memref<?x 1xindex>) -> memref<? x memref<4x1xf32>>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedTensor:$source,
+                       RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
+                       DenseI64ArrayAttr:$gather_dims,
+                       UnitAttr:$unique);
+  let results = (outs AnyRankedTensor:$result);
+
+  let assemblyFormat = [{
+    $source `[` $indices `]`
+      `gather_dims` `(` $gather_dims `)`
+      (`unique` $unique^)?
+      attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    // TODO: InferTypeOpInterface once enough confidence is built with
+    // tensor<tensor> and its lowering to memref<memref>.
+    static RankedTensorType inferResultType(RankedTensorType sourceType,
+                                            RankedTensorType indicesType,
+                                            ArrayRef<int64_t> gatherDims,
+                                            bool rankReduced);
+    RankedTensorType getIndicesType() {
+      return getIndices().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getResultType() {
+      return getResult().getType().cast<RankedTensorType>();
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
@@ -414,8 +552,8 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
     [RecursiveSideEffects,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
      SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
-  string summary = "Creates a dynamically sized tensor from elements";
-  string description = [{
+  let summary = "Creates a dynamically sized tensor from elements";
+  let description = [{
     This operation creates a dynamically sized tensor with elements of any type.
     It expects one index operand per dynamic extent of the result tensor.
 
@@ -560,7 +698,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
 
     Example:
 
-    ```
+    ```mlir
     // Rank-altering insert_slice.
     %1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
       tensor<16x4xf32> into tensor<8x16x4xf32>
@@ -1210,6 +1348,147 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
   let hasVerifier = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ScatterOp : Tensor_Op<"scatter", [
+    NoSideEffect
+  ]> {
+  let summary =
+    "scatter a tensor into a destination tensor at specified indices";
+  let description = [{
+    The `scatter` operation inserts a `source` tensor into a `dest` tensor at
+    the given indices.
+
+    In its most general form, the tensor of indices specifies all the coordinates
+    of every element to insert (i.e. COO format, without the payload).
+    The indices are expected to be confined to coordinate values that fit the
+    range of the `dest` tensor, otherwise the behavior is undefined.
+
+    The leading dimensions of the index tensor must match those of the source
+    tensor. The trailing dimensions of the source tensor must match those of
+    the dest tensor after omitting the dimensions specified in `scatter_dims`
+    (rank-reducing semantics) or setting them to `1` (rank-preserving semantics)
+    (see examples).
+    This convention allows an idiomatic specification and lowering of
+    "scattering multiple N-D slices into the dest tensor".
+    The result type must match the type of the dest tensor.
+
+    Note: in the examples below, we separate out the indexing part of the tensor
+    type by a whitespace for readability purposes.
+
+    Example:
+
+    ```mlir
+        // For each 1x2 triple of coordinates in %indices, insert the
+        // element (i.e. 0-D subset) at the coordinates triple in %dest.
+        //
+        %out = tensor.scatter %source into %dest[%indices]
+            scatter_dims([0, 1, 2]) unique :
+          (tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>)
+            -> tensor<4x4x4xf32>
+
+        // Note: source type may be further rank-reduced to tensor<1x2x f32>.
+    ```
+
+    A slice variant is provided to allow specifying insertion of whole tensor
+    slices into the `dest` tensor.
+
+    Example:
+
+    ```mlir
+        // For each 3 singleton of coordinates in %indices, insert the 2-D
+        // slice into %dest[*, %indices[...]:%indices[...] + 1, *] with the
+        // indices corresponding to the scatter_dims attribute specified by
+        // %indices.
+        //
+        %out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique :
+          (tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
+            -> tensor<4x5x6xf32>
+    ```
+
+    The dimensions specified in the scatter_dims attribute are ones for which the
+    source tensor has size `1`.
+    I.e. if the dest type is `axbxcxd` and `scatter_dims` is [1, 3], then
+    the source type suffix is `ax1xcx1`.
+    Scatter also allows rank-reducing semantics where the shape `ax1xcx1` can be
+    further simplified to `axc`.
+
+    The element type of the indices tensor can be any signless integer type or
+    `index`. In the absence of target-specific or problem-specific information,
+    the default type one should use is `index`.
+
+    This operation does not support unranked tensors.
+
+    A `unique` unit attribute must be specified to indicate that the
+    coordinates are statically guaranteed to be unique at runtime. If coordinates
+    are not truly unique at runtime, the behavior is undefined.
+
+    Only full slices are meant to be supported by this op; if one desires
+    partial slices (e.g. strided windows) one should compose this op with other
+    tensor ops (e.g. tensor.insert_slice). This is to avoid a slippery slope of
+    complexity that would make the op unusable in practice.
+
+    At the tensor level, the index tensor is specified in an AoS form (i.e. the
+    coordinate tuple is the most minor). It is the responsibility of further
+    lowerings and bufferization to implement various concrete layouts.
+
+    Note: As currently specified, the operation must lower to an abstraction that
+    performs copies to the output tensor. This is because the buffer type system
+    is currently not rich enough to allow multiple non-contiguous views in the
+    same type. This is visible more clearly in a notional buffer version of the
+    op:
+
+    ```mlir
+        // memref<?x 4xf32> is a contiguous buffer of ?x4 elements; scattering
+        // into random dest slices must copy into the contiguous dest.
+        //
+        some_side_effecting_op_writing_into %source, ...: memref<3x 4xf32>
+        memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
+          (memref<3x 4xf32>, memref<?x 4xf32>, memref<?x 1xindex>)
+
+        // Nested buffer support in the producing op would allow writing directly
+        // into the dest buffer.
+        %v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) unique :
+          memref<? x memref<4xf32>>
+        some_side_effecting_op_writing_into %v, ...: memref<? x memref<4xf32>>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedTensor:$source,
+                       AnyRankedTensor:$dest,
+                       RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
+                       DenseI64ArrayAttr:$scatter_dims,
+                       UnitAttr:$unique);
+  let results = (outs AnyRankedTensor:$result);
+
+  let assemblyFormat = [{
+    $source `into` $dest `[` $indices `]`
+      `scatter_dims` `(` $scatter_dims `)`
+      (`unique` $unique^)?
+      attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    RankedTensorType getDestType() {
+      return getDest().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getIndicesType() {
+      return getIndices().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getResultType() {
+      return getResult().getType().cast<RankedTensorType>();
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd4c4b9f03a59..7b3bc79451a4f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -17,8 +17,11 @@
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/StringRef.h"
+#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -543,6 +546,116 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ExtractElementFromIndexCast>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+/// Return the inferred result type for a gatherOp where:
+///   - sourceType is the type of the source tensor gathered from
+///   - indicesType is the type of the indices used to gather
+///   - gatherDims are the dims along which the gather occurs.
+/// Return a full rank or ranked-reduced variant of the type depending on
+/// the value of rankReduced.
+///
+/// The leading dimensions of the index tensor give the result tensor its
+/// leading dimensions.
+/// The trailing dimensions of the result tensor are obtained from the source
+/// tensor by setting the dimensions specified in gather_dims to `1` (if
+/// rankReduced is false), or skipping them (otherwise).
+///
+/// `gatherDims` must be sorted in increasing order since it is searched with
+/// std::binary_search. A rank-0 `indicesType` contributes no leading
+/// dimensions instead of triggering an out-of-bounds `drop_back`.
+RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
+                                           RankedTensorType indicesType,
+                                           ArrayRef<int64_t> gatherDims,
+                                           bool rankReduced) {
+  // All but the innermost (coordinate) dimension of the indices become the
+  // leading dimensions of the result.
+  ArrayRef<int64_t> indicesShape = indicesType.getShape();
+  SmallVector<int64_t> resultShape(
+      indicesShape.empty() ? indicesShape : indicesShape.drop_back());
+  resultShape.reserve(resultShape.size() + sourceType.getRank());
+  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
+    if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
+      if (!rankReduced)
+        resultShape.push_back(1);
+      continue;
+    }
+    resultShape.push_back(sourceType.getDimSize(idx));
+  }
+  return RankedTensorType::Builder(sourceType).setShape(resultShape);
+}
+
+/// Verify the `gather_dims`/`scatter_dims` attribute `dims` of a gather or
+/// scatter op against `rank`, the rank of the indexed tensor (`source` for
+/// gather, `dest` for scatter): `dims` must be non-empty, have at most `rank`
+/// entries, each in [0, rank), and be strictly increasing. `gatherOrScatter`
+/// and `sourceOrDest` are only used to spell the diagnostics.
+static LogicalResult
+verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
+                          StringRef gatherOrScatter, StringRef sourceOrDest) {
+  if (dims.empty())
+    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
+
+  int64_t numDims = dims.size();
+  if (numDims > rank)
+    return op->emitOpError(gatherOrScatter)
+           << "_dims overflow " << sourceOrDest << " rank";
+  for (int64_t val : dims) {
+    if (val < 0)
+      return op->emitOpError(gatherOrScatter)
+             << "_dims value must be non-negative";
+    if (val >= rank)
+      return op->emitOpError(gatherOrScatter)
+             << "_dims value must be smaller than " << sourceOrDest << " rank";
+  }
+  for (int64_t i = 1; i < numDims; ++i) {
+    if (dims[i - 1] >= dims[i])
+      return op->emitOpError(gatherOrScatter)
+             << "_dims values must be strictly increasing";
+  }
+  return success();
+}
+
+LogicalResult GatherOp::verify() {
+  int64_t sourceRank = getSourceType().getRank();
+  ArrayRef<int64_t> gatherDims = getGatherDims();
+  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
+                                       "gather", "source")))
+    return failure();
+
+  // The innermost dimension of `indices` holds the coordinates, so a rank-0
+  // indices tensor has nothing to gather with.
+  RankedTensorType indicesType = getIndicesType();
+  if (indicesType.getRank() == 0)
+    return emitOpError("indices must have rank >= 1");
+
+  RankedTensorType expectedResultType = GatherOp::inferResultType(
+      getSourceType(), indicesType, gatherDims, /*rankReduced=*/false);
+  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
+      getSourceType(), indicesType, gatherDims, /*rankReduced=*/true);
+  if (getResultType() != expectedResultType &&
+      getResultType() != expectedRankReducedResultType) {
+    return emitOpError("result type mismatch: expected ")
+           << expectedResultType << " or its rank-reduced variant "
+           << expectedRankReducedResultType << " (got: " << getResultType()
+           << ")";
+  }
+
+  // Each coordinate tuple must hold exactly one value per gathered dimension.
+  // This is checked after the type check so that type mismatches are still
+  // reported with the expected types. A dynamic size cannot be checked here.
+  int64_t coordinateDim = indicesType.getRank() - 1;
+  if (!indicesType.isDynamicDim(coordinateDim) &&
+      indicesType.getDimSize(coordinateDim) !=
+          static_cast<int64_t>(gatherDims.size()))
+    return emitOpError("gather_dims length must match the size of last "
+                       "dimension of indices");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InsertOp
 //===----------------------------------------------------------------------===//
@@ -2306,6 +2392,60 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
               InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScatterOp::verify() {
+  int64_t destRank = getDestType().getRank();
+  ArrayRef<int64_t> scatterDims = getScatterDims();
+  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
+                                       "scatter", "dest")))
+    return failure();
+
+  if (!getUnique())
+    return emitOpError("requires 'unique' attribute to be set");
+  // TODO: we could also check statically that there are fewer leading index
+  // tensor dims than the dest dims. If this is not the case, the unique
+  // attribute cannot be true.
+
+  // The innermost dimension of `indices` holds the coordinates, so a rank-0
+  // indices tensor has nothing to scatter with.
+  RankedTensorType indicesType = getIndicesType();
+  if (indicesType.getRank() == 0)
+    return emitOpError("indices must have rank >= 1");
+
+  // Use the GatherOp::inferResultType on the `dest` type and verify the
+  // expected type matches the source type.
+  RankedTensorType expectedSourceType = GatherOp::inferResultType(
+      getDestType(), indicesType, scatterDims, /*rankReduced=*/false);
+  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
+      getDestType(), indicesType, scatterDims, /*rankReduced=*/true);
+  if (getSourceType() != expectedSourceType &&
+      getSourceType() != expectedRankReducedSourceType) {
+    return emitOpError("source type mismatch: expected ")
+           << expectedSourceType << " or its rank-reduced variant "
+           << expectedRankReducedSourceType << " (got: " << getSourceType()
+           << ")";
+  }
+
+  // Each coordinate tuple must hold exactly one value per scattered dimension.
+  // A dynamic size cannot be checked here.
+  int64_t coordinateDim = indicesType.getRank() - 1;
+  if (!indicesType.isDynamicDim(coordinateDim) &&
+      indicesType.getDimSize(coordinateDim) !=
+          static_cast<int64_t>(scatterDims.size()))
+    return emitOpError("scatter_dims length must match the size of last "
+                       "dimension of indices");
+
+  // The result type must match the type of the dest tensor.
+  if (getResultType() != getDestType())
+    return emitOpError("result type mismatch: expected ")
+           << getDestType() << " (got: " << getResultType() << ")";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 961d1def31590..6b0dfb89952c2 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -377,3 +377,140 @@ func.func @invalid_splat(%v : vector<8xf32>) {
   %w = tensor.splat %v : tensor<8xvector<8xf32>>
   return
 }
+
+// -----
+
+func.func @gather_empty_dims(
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error @+1 {{gather_dims must be non-empty}}
+  %out = tensor.gather %source[%indices] gather_dims([]):
+    (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+  return
+}
+
+// -----
+
+func.func @gather_coordinate_rank_overflow(
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error @+1 {{gather_dims overflow source rank}}
+  %out = tensor.gather %source[%indices] gather_dims([0, 1, 2, 3]):
+    (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+  return
+}
+
+// -----
+
+func.func @gather_coordinate_negative(
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
+  // expected-error @+1 {{gather_dims value must be non-negative}}
+  %out = tensor.gather %source[%indices] gather_dims([-1]):
+    (tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1x1x1xf32>
+  return
+}
+
+// -----
+
+func.func @gather_coordinate_overflow(
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
+  // expected-error @+1 {{gather_dims value must be smaller than source rank}}
+  %out = tensor.gather %source[%indices] gather_dims([42]):
+    (tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1x1x1xf32>
+  return
+}
+
+// -----
+
+func.func @gather_coordinate_not_increasing(
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+  // expected-error @+1 {{gather_dims values must be strictly increasing}}
+  %out = tensor.gather %source[%indices] gather_dims([1, 0]):
+    (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1x1xf32>
+  return
+}
+
+// -----
+
+func.func @gather_wrong_result_type(
+    %source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+  // expected-error @+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}}
+  %out = tensor.gather %source[%indices] gather_dims([0, 2]):
+    (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
+  return
+}
+
+// -----
+
+func.func @scatter_empty_dims(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error @+1 {{scatter_dims must be non-empty}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([]) unique:
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<4x5x6xf32>
+  return
+}
+
+// -----
+
+func.func @scatter_coordinate_rank_overflow(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error @+1 {{scatter_dims overflow dest rank}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2, 3]) unique:
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<4x5x6xf32>
+  return
+}
+
+// -----
+
+func.func @scatter_coordinate_negative(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
+  // expected-error @+1 {{scatter_dims value must be non-negative}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique:
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<4x5x6xf32>
+  return
+}
+
+// -----
+
+func.func @scatter_coordinate_overflow(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
+  // expected-error @+1 {{scatter_dims value must be smaller than dest rank}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique:
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<4x5x6xf32>
+  return
+}
+
+// -----
+
+func.func @scatter_coordinate_not_increasing(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+  // expected-error @+1 {{scatter_dims values must be strictly increasing}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique:
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<4x5x6xf32>
+  return
+}
+
+// -----
+
+func.func @scatter_missing_unique(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+  // expected-error @+1 {{requires 'unique' attribute to be set}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]):
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<4x5x6xf32>
+  return
+}
+
+// -----
+
+func.func @scatter_wrong_source_type(
+    %source : tensor<f32>,
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
+  // expected-error @+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<f32>')}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique:
+    (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<4x5x6xf32>
+  return
+}

diff  --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 3f5fa6bc60d2b..436a0397c09ab 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -260,3 +260,26 @@ func.func @test_splat_op(%s : f32) {
   %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @gather_scatter
+func.func @gather_scatter(
+    %dest : tensor<4x5x6xf32>, %indices: tensor<1x3x2xindex>, %indices_i32: tensor<1x3x2xi32>) {
+  // CHECK: tensor.gather {{.*}} gather_dims([1, 2]) unique {{.*}} -> tensor<1x3x4x1x1xf32>
+  %gathered = tensor.gather %dest[%indices_i32] gather_dims([1, 2]) unique:
+    (tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4x1x1xf32>
+  // CHECK: tensor.gather {{.*}} gather_dims([1, 2]) unique {{.*}} -> tensor<1x3x4xf32>
+  %rank_reduced_gathered = tensor.gather %dest[%indices] gather_dims([1, 2]) unique:
+    (tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<1x3x4xf32>
+
+  // CHECK: tensor.scatter {{.*}} into {{.*}} scatter_dims([1, 2]) unique {{.*}} -> tensor<4x5x6xf32>
+  %scattered = tensor.scatter %gathered into %dest[%indices]
+      scatter_dims([1, 2]) unique:
+    (tensor<1x3x4x1x1xf32>, tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<4x5x6xf32>
+  // CHECK: tensor.scatter {{.*}} into {{.*}} scatter_dims([1, 2]) unique {{.*}} -> tensor<4x5x6xf32>
+  %rank_reduced_scattered = tensor.scatter %rank_reduced_gathered into %dest[%indices_i32]
+      scatter_dims([1, 2]) unique:
+    (tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32>
+  return
+}


        


More information about the Mlir-commits mailing list