[Mlir-commits] [mlir] d2613d5 - [mlir][tensor] Add gather/scatter op definitions to the tensor dialect.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Sep 5 02:02:30 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-05T02:02:22-07:00
New Revision: d2613d5bb5dca0624833e4747f67db6fe3236ce8
URL: https://github.com/llvm/llvm-project/commit/d2613d5bb5dca0624833e4747f67db6fe3236ce8
DIFF: https://github.com/llvm/llvm-project/commit/d2613d5bb5dca0624833e4747f67db6fe3236ce8.diff
LOG: [mlir][tensor] Add gather/scatter op definitions to the tensor dialect.
Gather/Scatter are examined from first principles in light of our recent progress on tensor-based codegen
and in-place bufferization.
In the future, lowering of these abstractions to operate **in place** on buffers
will likely require a more powerful buffer representation than strided memref.
General context: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707
Relevant TL;DR parts of the proposal:
- gather: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-gatherop-and-friends-10
- need for more expressive types: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-bufferization-copy-view-and-the-need-for-more-expressive-types-12
- jagged buffer discussion: https://discourse.llvm.org/t/rfc-structured-codegen-beyond-rectangular-arrays/64707#proposal-first-class-jagged-buffer-13
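As a quick orientation, here is a condensed sketch of the two new ops, adapted
from the op documentation added below (shapes and SSA names are illustrative):
```mlir
// Gather the 2-D slices of %source along dim 1 at the 6x7 coordinates
// held in %indices.
%g = tensor.gather %source[%indices] gather_dims([1]) :
    (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
// Scatter them back into %dest at the same coordinates; `unique` is required.
%s = tensor.scatter %g into %dest[%indices] scatter_dims([1]) unique :
    (tensor<6x7x 3x1x5xf32>, tensor<3x4x5xf32>, tensor<6x7x 1xindex>)
    -> tensor<3x4x5xf32>
```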
Differential Revision: https://reviews.llvm.org/D130348
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/invalid.mlir
mlir/test/Dialect/Tensor/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9a0bbf690cf02..51ab419b58025 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -232,7 +232,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
Example:
- ```
+ ```mlir
// Rank-reducing extract_slice.
%1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
tensor<8x16x4xf32> to tensor<16x4xf32>
@@ -372,8 +372,8 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
"$_self.cast<ShapedType>().getNumElements(), "
"$_self.cast<ShapedType>().getElementType())">
]> {
- string summary = "tensor from elements operation.";
- string description = [{
+ let summary = "tensor from elements operation.";
+ let description = [{
Create an N-D tensor from a range of same-type arguments. The number of
provided `elements` should equal the number of elements in the result
type. The `elements` correspond to a flattened tensor.
@@ -406,6 +406,144 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_GatherOp : Tensor_Op<"gather", [
+ NoSideEffect
+ ]> {
+ let summary = "gather a subset of a tensor at specified indices";
+ let description = [{
+ The `gather` operation extracts a subset of the elements from a `source`
+ tensor at the given indices.
+
+ In its most general form, the tensor of indices specifies all the coordinates
+ of every element to extract (i.e. COO format, without the payload).
+ The indices are expected to be confined to coordinate values that fit the
+ range of the `source` tensor; otherwise, the behavior is undefined.
+
+ The leading dimensions of the index tensor give the result tensor its leading
+ dimensions. The trailing dimensions of the result tensor are obtained from
+ the source tensor by omitting the dimensions specified in `gather_dims`
+ (rank-reducing semantics) or setting them to `1` (rank-preserving semantics)
+ (see examples).
+ The trailing dimension of the index tensor contains the coordinates and is
+ expected to have its size equal to the number of dimensions being gathered.
+ This convention allows an idiomatic specification and lowering of "gathering
+ multiple N-D slices from the source tensor".
+
+ Note: in the examples below, the indexing part of the tensor type is
+ separated by a whitespace for readability.
+
+ Example:
+
+ ```mlir
+ // For each 1x2 triple of coordinates in %indices, extract the
+ // element (i.e. 0-D subset) at the coordinates triple in %source.
+ //
+ %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]) :
+ (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
+
+ // Note: result type may be further rank-reduced to tensor<1x2x f32>.
+ ```
+
+ A slice variant is provided to allow specifying whole slices of the source
+ tensor.
+
+ Example:
+
+ ```mlir
+ // For each 6x7 singleton of coordinates in %indices, extract the 2-D
+ // slice %source[*, %indices[...]:%indices[...] + 1, *] with the indices
+ // corresponding to the `gather_dims` attribute specified by %indices.
+ //
+ %out = tensor.gather %source[%indices] gather_dims([1]) :
+ (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
+
+ // Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>.
+ ```
+
+ The dimensions specified in the gather_dims attribute are ones for which
+ the result tensor has size `1`.
+ I.e. if the source type is `axbxcxd` and the gather_dims are [1, 3], then
+ the shape suffix is `ax1xcx1`.
+ Gather also allows rank-reducing semantics where the shape `ax1xcx1` can be
+ further simplified to `axc`.
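+
+ For instance, here is a concrete instantiation of the `axbxcxd` example
+ above (an illustrative sketch, not taken from this patch's tests):
+
+ ```mlir
+ // Rank-preserving: gathered dims 1 and 3 are kept with size 1.
+ %rp = tensor.gather %source[%indices] gather_dims([1, 3]) :
+ (tensor<2x3x4x5xf32>, tensor<6x 2xindex>) -> tensor<6x 2x1x4x1xf32>
+
+ // Rank-reducing: gathered dims 1 and 3 are omitted from the result.
+ %rr = tensor.gather %source[%indices] gather_dims([1, 3]) :
+ (tensor<2x3x4x5xf32>, tensor<6x 2xindex>) -> tensor<6x 2x4xf32>
+ ```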
+
+ The element type of the indices tensor can be any integer type.
+ In the absence of target-specific or problem-specific information, the
+ default type one should use is `index`.
+
+ This operation does not support unranked tensors.
+
+ An optional `unique` unit attribute may be specified to indicate that the
+ coordinates in `indices` are statically guaranteed to be unique at runtime.
+ Incorrectly setting the `unique` attribute when the coordinates are not truly
+ unique is undefined behavior.
+
+ Only full slices are meant to be supported by this op; if one desires
+ partial slices (e.g. strided windows), one should compose this op with
+ other tensor ops (e.g. tensor.extract_slice). This is to avoid a slippery
+ slope of complexity that would make the op unusable in practice.
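+
+ For example, a strided window can be obtained by composing a full-slice
+ gather with tensor.extract_slice (an illustrative sketch reusing the slice
+ example above):
+
+ ```mlir
+ %full = tensor.gather %source[%indices] gather_dims([1]) :
+ (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
+ // Keep every other element along the trailing dimension (stride 2).
+ %window = tensor.extract_slice %full[0, 0, 0, 0, 0][6, 7, 3, 1, 3][1, 1, 1, 1, 2] :
+ tensor<6x7x3x1x5xf32> to tensor<6x7x3x1x3xf32>
+ ```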
+
+ At the tensor level, the index tensor is specified in an AoS form (i.e.
+ the coordinate tuple is the most-minor dimension). It is the responsibility
+ of further lowerings and bufferization to implement various concrete
+ layouts.
+
+ Note: As currently specified, the operation must lower to an abstraction
+ that performs copies to the output tensor. This is because the buffer type
+ system is currently not rich enough to allow multiple non-contiguous views
+ in the same type. This is more clearly visible in a notional buffer version
+ of the op:
+
+ ```mlir
+ // memref<?x4x1xf32> is a contiguous buffer of ?x4x1 elements.
+ // gather from random source slices must copy to the contiguous output.
+ %out = memref.gather %source[%indices] gather_dims([1]) :
+ (memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x 4x1xf32>
+
+ // Nested buffer support would allow gather to directly index into the
+ // source buffer (i.e. represent a jagged view into the source).
+ %out = memref.gather %source[%indices] gather_dims([1]) :
+ (memref<4x4xf32>, memref<?x 1xindex>) -> memref<? x memref<4x1xf32>>
+ ```
+ }];
+
+ let arguments = (ins AnyRankedTensor:$source,
+ RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
+ DenseI64ArrayAttr:$gather_dims,
+ UnitAttr:$unique);
+ let results = (outs AnyRankedTensor:$result);
+
+ let assemblyFormat = [{
+ $source `[` $indices `]`
+ `gather_dims` `(` $gather_dims `)`
+ (`unique` $unique^)?
+ attr-dict
+ `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ // TODO: InferTypeOpInterface once enough confidence is built with
+ // tensor<tensor> and its lowering to memref<memref>.
+ static RankedTensorType inferResultType(RankedTensorType sourceType,
+ RankedTensorType indicesType,
+ ArrayRef<int64_t> gatherDims,
+ bool rankReduced);
+ RankedTensorType getIndicesType() {
+ return getIndices().getType().cast<RankedTensorType>();
+ }
+ RankedTensorType getSourceType() {
+ return getSource().getType().cast<RankedTensorType>();
+ }
+ RankedTensorType getResultType() {
+ return getResult().getType().cast<RankedTensorType>();
+ }
+ }];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
@@ -414,8 +552,8 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
[RecursiveSideEffects,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
- string summary = "Creates a dynamically sized tensor from elements";
- string description = [{
+ let summary = "Creates a dynamically sized tensor from elements";
+ let description = [{
This operation creates a dynamically sized tensor with elements of any type.
It expects one index operand per dynamic extent of the result tensor.
@@ -560,7 +698,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
Example:
- ```
+ ```mlir
// Rank-altering insert_slice.
%1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
tensor<16x4xf32> into tensor<8x16x4xf32>
@@ -1210,6 +1348,147 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
let hasVerifier = 1;
}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ScatterOp : Tensor_Op<"scatter", [
+ NoSideEffect
+ ]> {
+ let summary =
+ "scatter a tensor into a destination tensor at specified indices";
+ let description = [{
+ The `scatter` operation inserts a `source` tensor into a `dest` tensor at
+ the given indices.
+
+ In its most general form, the tensor of indices specifies all the coordinates
+ of every element to insert (i.e. COO format, without the payload).
+ The indices are expected to be confined to coordinate values that fit the
+ range of the `dest` tensor; otherwise, the behavior is undefined.
+
+ The leading dimensions of the index tensor must match those of the source
+ tensor. The trailing dimensions of the source tensor must match those of
+ the dest tensor, with the dimensions specified in scatter_dims omitted
+ (rank-reducing semantics) or set to `1` (rank-preserving semantics)
+ (see examples).
+ This convention allows an idiomatic specification and lowering of
+ "scattering multiple N-D slices into the dest tensor".
+ The result type must match the type of the dest tensor.
+
+ Note: in the examples below, the indexing part of the tensor type is
+ separated by a whitespace for readability.
+
+ Example:
+
+ ```mlir
+ // For each 1x2 triple of coordinates in %indices, insert the
+ // element (i.e. 0-D subset) at the coordinates triple in %dest.
+ //
+ %out = tensor.scatter %source into %dest[%indices]
+ scatter_dims([0, 1, 2]) unique :
+ (tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>)
+ -> tensor<4x4x4xf32>
+
+ // Note: source type may be further rank-reduced to tensor<1x2x f32>.
+ ```
+
+ A slice variant is provided to allow specifying insertion of whole tensor
+ slices into the `dest` tensor.
+
+ Example:
+
+ ```mlir
+ // For each of the 3 singletons of coordinates in %indices, insert the 2-D
+ // slice into %dest[*, %indices[...]:%indices[...] + 1, *] with the
+ // indices corresponding to the scatter_dims attribute specified by
+ // %indices.
+ //
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique :
+ (tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
+ -> tensor<4x5x6xf32>
+ ```
+
+ The dimensions specified in the scatter_dims attribute are ones for which
+ the source tensor has size `1`.
+ I.e. if the dest type is `axbxcxd` and the scatter_dims are [1, 3], then
+ the source type suffix is `ax1xcx1`.
+ Scatter also allows rank-reducing semantics where the shape `ax1xcx1` can
+ be further simplified to `axc`.
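+
+ For instance, here is a concrete instantiation of the `axbxcxd` example
+ above (an illustrative sketch, not taken from this patch's tests):
+
+ ```mlir
+ // dest is 2x3x4x5 and scatter_dims([1, 3]): the rank-preserving source
+ // carries size-1 dims at positions 1 and 3 of its suffix; the rank-reduced
+ // form would use tensor<6x 2x4xf32> instead.
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 3]) unique :
+ (tensor<6x 2x1x4x1xf32>, tensor<2x3x4x5xf32>, tensor<6x 2xindex>)
+ -> tensor<2x3x4x5xf32>
+ ```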
+
+ The element type of the indices tensor can be any integer type.
+ In the absence of target-specific or problem-specific information, the
+ default type one should use is `index`.
+
+ This operation does not support unranked tensors.
+
+ A `unique` unit attribute must be specified to indicate that the
+ coordinates are statically guaranteed to be unique at runtime. If the
+ coordinates are not truly unique at runtime, the behavior is undefined.
+
+ Only full slices are meant to be supported by this op; if one desires
+ partial slices (e.g. strided windows), one should compose this op with
+ other tensor ops (e.g. tensor.insert_slice). This is to avoid a slippery
+ slope of complexity that would make the op unusable in practice.
+
+ At the tensor level, the index tensor is specified in an AoS form (i.e.
+ the coordinate tuple is the most-minor dimension). It is the responsibility
+ of further lowerings and bufferization to implement various concrete
+ layouts.
+
+ Note: As currently specified, the operation must lower to an abstraction
+ that performs copies to the output tensor. This is because the buffer type
+ system is currently not rich enough to allow multiple non-contiguous views
+ in the same type. This is more clearly visible in a notional buffer version
+ of the op:
+
+ ```mlir
+ // memref<?x 4xf32> is a contiguous buffer of ?x4 elements. A scatter into
+ // random dest slices must copy to the contiguous dest.
+ //
+ some_side_effecting_op_writing_into %source, ...: memref<3x 4xf32>
+ memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
+ (memref<3x 4xf32>, memref<?x 4xf32>, memref<?x 1xindex>)
+
+ // Nested buffer support in the producing op would allow writing directly
+ // into the dest buffer.
+ %v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) unique :
+ memref<? x memref<4xf32>>
+ some_side_effecting_op_writing_into %v, ...: memref<? x memref<4xf32>>
+ ```
+ }];
+
+ let arguments = (ins AnyRankedTensor:$source,
+ AnyRankedTensor:$dest,
+ RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
+ DenseI64ArrayAttr:$scatter_dims,
+ UnitAttr:$unique);
+ let results = (outs AnyRankedTensor:$result);
+
+ let assemblyFormat = [{
+ $source `into` $dest `[` $indices `]`
+ `scatter_dims` `(` $scatter_dims `)`
+ (`unique` $unique^)?
+ attr-dict
+ `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ RankedTensorType getDestType() {
+ return getDest().getType().cast<RankedTensorType>();
+ }
+ RankedTensorType getIndicesType() {
+ return getIndices().getType().cast<RankedTensorType>();
+ }
+ RankedTensorType getSourceType() {
+ return getSource().getType().cast<RankedTensorType>();
+ }
+ RankedTensorType getResultType() {
+ return getResult().getType().cast<RankedTensorType>();
+ }
+ }];
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd4c4b9f03a59..7b3bc79451a4f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -17,8 +17,11 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/StringRef.h"
+#include <algorithm>
using namespace mlir;
using namespace mlir::tensor;
@@ -543,6 +546,89 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExtractElementFromIndexCast>(context);
}
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+/// Return the inferred result type for a GatherOp where:
+/// - sourceType is the type of the source tensor gathered from
+/// - indicesType is the type of the indices used to gather
+/// - gatherDims are the dims along which the gather occurs.
+/// Return a full-rank or rank-reduced variant of the type depending on
+/// the value of rankReduced.
+///
+/// The leading dimensions of the index tensor give the result tensor its
+/// leading dimensions.
+/// The trailing dimensions of the result tensor are obtained from the source
+/// tensor by setting the dimensions specified in gather_dims to `1` (if
+/// rankReduced is false), or skipping them (otherwise).
+RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
+ RankedTensorType indicesType,
+ ArrayRef<int64_t> gatherDims,
+ bool rankReduced) {
+ SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
+ resultShape.reserve(resultShape.size() + sourceType.getRank());
+ for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
+ if (!rankReduced)
+ resultShape.push_back(1);
+ continue;
+ }
+ resultShape.push_back(sourceType.getDimSize(idx));
+ }
+ return RankedTensorType::Builder(sourceType).setShape(resultShape);
+}
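+
+// Illustrative trace (mirroring the round-trip test in ops.mlir): for
+// sourceType = tensor<4x5x6xf32>, indicesType = tensor<1x3x2xindex> and
+// gatherDims = [1, 2], the leading result dims 1x3 come from the indices
+// shape minus its coordinate dim; rankReduced=false then yields
+// tensor<1x3x4x1x1xf32>, while rankReduced=true yields tensor<1x3x4xf32>.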
+
+static LogicalResult
+verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
+ StringRef gatherOrScatter, StringRef sourceOrDest) {
+ if (dims.empty())
+ return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
+
+ int64_t numGatherDims = dims.size();
+ if (numGatherDims > rank)
+ return op->emitOpError(gatherOrScatter)
+ << "_dims overflow " << sourceOrDest << " rank";
+ for (int64_t val : dims) {
+ if (val < 0)
+ return op->emitOpError(gatherOrScatter)
+ << "_dims value must be non-negative";
+ if (val >= rank)
+ return op->emitOpError(gatherOrScatter)
+ << "_dims value must be smaller than " << sourceOrDest << " rank";
+ }
+ for (int64_t i = 1; i < numGatherDims; ++i) {
+ if (dims[i - 1] >= dims[i])
+ return op->emitOpError(gatherOrScatter)
+ << "_dims values must be strictly increasing";
+ }
+ return success();
+}
+
+LogicalResult GatherOp::verify() {
+ int64_t sourceRank = getSourceType().getRank();
+ ArrayRef<int64_t> gatherDims = getGatherDims();
+ if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
+ "gather", "source")))
+ return failure();
+
+ RankedTensorType expectedResultType = GatherOp::inferResultType(
+ getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
+ RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
+ getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
+ if (getResultType() != expectedResultType &&
+ getResultType() != expectedRankReducedResultType) {
+ return emitOpError("result type "
+ "mismatch: "
+ "expected ")
+ << expectedResultType << " or its rank-reduced variant "
+ << expectedRankReducedResultType << " (got: " << getResultType()
+ << ")";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
@@ -2306,6 +2392,42 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScatterOp::verify() {
+ int64_t destRank = getDestType().getRank();
+ ArrayRef<int64_t> scatterDims = getScatterDims();
+ if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
+ "scatter", "dest")))
+ return failure();
+
+ if (!getUnique())
+ return emitOpError("requires 'unique' attribute to be set");
+ // TODO: we could also check statically that there are fewer leading index
+ // tensor dims than the dest dims. If this is not the case, the unique
+ // attribute cannot be true.
+
+ // Use the GatherOp::inferResultType on the `dest` type and verify the
+ // expected type matches the source type.
+ RankedTensorType expectedSourceType = GatherOp::inferResultType(
+ getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
+ RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
+ getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
+ if (getSourceType() != expectedSourceType &&
+ getSourceType() != expectedRankReducedSourceType) {
+ return emitOpError("source type "
+ "mismatch: "
+ "expected ")
+ << expectedSourceType << " or its rank-reduced variant "
+ << expectedRankReducedSourceType << " (got: " << getSourceType()
+ << ")";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 961d1def31590..6b0dfb89952c2 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -377,3 +377,140 @@ func.func @invalid_splat(%v : vector<8xf32>) {
%w = tensor.splat %v : tensor<8xvector<8xf32>>
return
}
+
+// -----
+
+func.func @gather_empty_dims(
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{gather_dims must be non-empty}}
+ %out = tensor.gather %source[%indices] gather_dims([]):
+ (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+ return
+}
+
+// -----
+
+func.func @gather_coordinate_rank_overflow(
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{gather_dims overflow source rank}}
+ %out = tensor.gather %source[%indices] gather_dims([0, 1, 2, 3]):
+ (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+ return
+}
+
+// -----
+
+func.func @gather_coordinate_negative(
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{gather_dims value must be non-negative}}
+ %out = tensor.gather %source[%indices] gather_dims([-1]):
+ (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ return
+}
+
+// -----
+
+func.func @gather_coordinate_overflow(
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{gather_dims value must be smaller than source rank}}
+ %out = tensor.gather %source[%indices] gather_dims([42]):
+ (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ return
+}
+
+// -----
+
+func.func @gather_coordinate_overflow(
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{gather_dims values must be strictly increasing}}
+ %out = tensor.gather %source[%indices] gather_dims([1, 0]):
+ (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ return
+}
+
+// -----
+
+func.func @gather_wrong_result_type(
+ %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}}
+ %out = tensor.gather %source[%indices] gather_dims([0, 2]):
+ (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+ return
+}
+
+// -----
+
+func.func @scatter_empty_dims(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{scatter_dims must be non-empty}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([]) unique:
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+ return
+}
+
+// -----
+
+func.func @scatter_coordinate_rank_overflow(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{scatter_dims overflow dest rank}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2, 3]) unique:
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+ return
+}
+
+// -----
+
+func.func @scatter_coordinate_negative(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{scatter_dims value must be non-negative}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique:
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ return
+}
+
+// -----
+
+func.func @scatter_coordinate_overflow(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{scatter_dims value must be smaller than dest rank}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique:
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ return
+}
+
+// -----
+
+func.func @scatter_coordinate_overflow(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{scatter_dims values must be strictly increasing}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique:
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+ return
+}
+
+// -----
+
+func.func @scatter_missing_unique(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{requires 'unique' attribute to be set}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]):
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+ return
+}
+
+// -----
+
+func.func @scatter_wrong_result_type(
+ %source : tensor<f32>,
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
+ // expected-error @+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<f32>')}}
+ %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique:
+ (tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
+ return
+}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 3f5fa6bc60d2b..436a0397c09ab 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -260,3 +260,22 @@ func.func @test_splat_op(%s : f32) {
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: func @gather_scatter
+func.func @gather_scatter(
+ %dest : tensor<4x5x6xf32>, %indices: tensor<1x3x2xindex>, %indices_i32: tensor<1x3x2xi32>) {
+ %gathered = tensor.gather %dest[%indices_i32] gather_dims([1, 2]) unique:
+ (tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4x1x1xf32>
+ %rank_reduced_gathered = tensor.gather %dest[%indices] gather_dims([1, 2]) unique:
+ (tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<1x3x4xf32>
+
+ %scattered = tensor.scatter %gathered into %dest[%indices]
+ scatter_dims([1, 2]) unique:
+ (tensor<1x3x4x1x1xf32>, tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<4x5x6xf32>
+ %rank_reduced_scattered = tensor.scatter %rank_reduced_gathered into %dest[%indices_i32]
+ scatter_dims([1, 2]) unique:
+ (tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32>
+ return
+}