[Mlir-commits] [mlir] b614ada - [mlir] add support for index type in vectors.
Tobias Gysi
llvmlistbot at llvm.org
Thu Apr 8 01:32:27 PDT 2021
Author: Tobias Gysi
Date: 2021-04-08T08:17:13Z
New Revision: b614ada0e80fe1af00294e8460f987dc6a7e4d5b
URL: https://github.com/llvm/llvm-project/commit/b614ada0e80fe1af00294e8460f987dc6a7e4d5b
DIFF: https://github.com/llvm/llvm-project/commit/b614ada0e80fe1af00294e8460f987dc6a7e4d5b.diff
LOG: [mlir] add support for index type in vectors.
The patch enables the use of the index type in vectors. It is a prerequisite for supporting the vectorization of indexed Linalg operations. This refactoring became possible due to the newly introduced data layout infrastructure. The data layout of a module defines the bitwidth of the index type, which is needed to verify bitcasts and similar vector operations.
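As a minimal sketch of what this enables (mirroring the tests added below; the function name is illustrative):

    func @index_vector(%arg0: vector<4xindex>) -> vector<4xindex> {
      // Vectors with index elements can now be materialized and used in
      // arithmetic; the module data layout fixes their bitwidth on lowering.
      %cst = constant dense<[0, 1, 2, 3]> : vector<4xindex>
      %0 = addi %arg0, %cst : vector<4xindex>
      return %0 : vector<4xindex>
    }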
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D99948
Added:
mlir/test/Integration/Dialect/Vector/CPU/test-index-vectors.mlir
Modified:
mlir/docs/Rationale/Rationale.md
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/Interfaces/DataLayoutInterfaces.cpp
mlir/lib/Parser/TypeParser.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/IR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md
index f09f03b946a80..c159d3eda59f2 100644
--- a/mlir/docs/Rationale/Rationale.md
+++ b/mlir/docs/Rationale/Rationale.md
@@ -202,39 +202,39 @@ and described in
interest
[starts here](https://www.google.com/url?q=https://youtu.be/Ntj8ab-5cvE?t%3D596&sa=D&ust=1529450150971000&usg=AFQjCNFQHEWL7m8q3eO-1DiKw9zqC2v24Q).
-### Index type disallowed in vector types
-
-Index types are not allowed as elements of `vector` types. Index
-types are intended to be used for platform-specific "size" values and may appear
-in subscripts, sizes of aggregate types and affine expressions. They are also
-tightly coupled with `affine.apply` and affine.load/store operations; having
-`index` type is a necessary precondition of a value to be acceptable by these
-operations.
-
-We allow `index` types in tensors and memrefs as a code generation strategy has
-to map `index` to an implementation type and hence needs to be able to
-materialize corresponding values. However, the target might lack support for
+### Index type usage and limitations
+
+Index types are intended to be used for platform-specific "size" values and may
+appear in subscripts, sizes of aggregate types, and affine expressions. They are
+also tightly coupled with the `affine.apply` and `affine.load`/`affine.store`
+operations; having `index` type is a necessary precondition for a value to be
+accepted by these operations.
+
+We allow `index` types in tensors, vectors, and memrefs because a code
+generation strategy has to map `index` to an implementation type and hence
+needs to be able to materialize the corresponding values. However, the target
+might lack support for
`vector` values with the target specific equivalent of the `index` type.
-### Bit width of a non-primitive type and `index` is undefined
-
-The bit width of a compound type is not defined by MLIR, it may be defined by a
-specific lowering pass. In MLIR, bit width is a property of certain primitive
-_type_, in particular integers and floats. It is equal to the number that
-appears in the type definition, e.g. the bit width of `i32` is `32`, so is the
-bit width of `f32`. The bit width is not _necessarily_ related to the amount of
-memory (in bytes) or the size of register (in bits) that is necessary to store
-the value of the given type. These quantities are target and ABI-specific and
-should be defined during the lowering process rather than imposed from above.
-For example, `vector<3xi57>` is likely to be lowered to a vector of four 64-bit
-integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes, rather
-than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the
-bitwidth. Individual components of MLIR that allocate space for storing values
-may use the bit size as the baseline and query the target description when it is
-introduced.
-
-The bit width is not defined for dialect-specific types at MLIR level. Dialects
-are free to define their own quantities for type sizes.
+### Data layout of non-primitive types
+
+Data layout information such as the bit width or the alignment of types may be
+target and ABI-specific and thus should be configurable rather than imposed by
+the compiler. Notably, the layout of compound or `index` types may vary. MLIR
+specifies default bit widths for certain primitive _types_, in particular for
+integers and floats. The bit width equals the number that appears in the type
+definition, e.g. the bit width of `i32` is `32`, and so is the bit width of
+`f32`.
+The bit width is not _necessarily_ related to the amount of memory (in bytes) or
+the register size (in bits) that is necessary to store the value of the given
+type. For example, `vector<3xi57>` is likely to be lowered to a vector of four
+64-bit integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes,
+rather than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the
+bit width. MLIR makes such [data layout information](../DataLayout.md)
+configurable using attributes that can be queried during lowering, for example,
+when allocating a compound type.
+
+The data layout of dialect-specific types is undefined at the MLIR level.
+However, dialects are free to define their own quantities and make them
+available via the data layout infrastructure.
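As context for the section above, the index bitwidth can be configured per module via the data layout infrastructure. A hypothetical sketch, assuming the DLTI dialect attribute syntax for layout entries (the exact entry syntax is an assumption, not part of this patch):

    // Declare a 32-bit index type for everything inside this module.
    module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>> } {
      func @f(%arg0: vector<4xindex>) {
        return
      }
    }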
### Integer signedness semantics
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index fcfe8f1850e95..6d058f4261416 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1738,8 +1738,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
let summary = "splat or broadcast operation";
let description = [{
Broadcast the operand to all elements of the result vector or tensor. The
- operand has to be of either integer or float type. When the result is a
- tensor, it has to be statically shaped.
+ operand has to be of integer/index/float type. When the result is a tensor,
+ it has to be statically shaped.
Example:
@@ -1761,8 +1761,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
```
}];
- let arguments = (ins AnyTypeOf<[AnySignlessInteger, AnyFloat],
- "integer or float type">:$input);
+ let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
+ "integer/index/float type">:$input);
let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);
let builders = [
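With the relaxed constraint, `splat` now accepts an index operand, as exercised by the updated tests below:

    %i0 = constant 0 : index
    %vi0 = splat %i0 : vector<4x3xindex>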
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 14afe95048067..0a1228599b601 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2307,13 +2307,13 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
Arguments<(
// TODO: tighten vector element types that make sense.
ins VectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs,
+ [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
VectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs,
+ [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
Results<(
outs VectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)>
+ [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
{
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
" MLIR vectors";
@@ -2370,11 +2370,11 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
Arguments<(
// TODO: tighten vector element types that make sense.
ins VectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$matrix,
+ [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
I32Attr:$rows, I32Attr:$columns)>,
Results<(
outs VectorOfRankAndType<[1],
- [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> {
+ [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
let description = [{
This is the counterpart of llvm.matrix.transpose in MLIR. It serves
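Both ops now also admit index vectors; for reference, a matrix multiply on flattened index vectors taken from the tests below:

    %C = vector.matrix_multiply %A, %B
           { lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32 } :
           (vector<64xindex>, vector<48xindex>) -> vector<12xindex>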
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 22d194db3b685..f271c56f41627 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -874,7 +874,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> {
```
vector-type ::= `vector` `<` static-dimension-list vector-element-type `>`
- vector-element-type ::= float-type | integer-type
+ vector-element-type ::= float-type | integer-type | index-type
static-dimension-list ::= (decimal-literal `x`)+
```
@@ -911,9 +911,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> {
];
let extraClassDeclaration = [{
/// Returns true if the given type can be used as an element of a vector
- /// type. In particular, vectors can consist of integer or float primitives.
+ /// type. In particular, vectors can consist of integer, index, or float
+ /// primitives.
static bool isValidElementType(Type t) {
- return t.isa<IntegerType, FloatType>();
+ return t.isa<IntegerType, IndexType, FloatType>();
}
/// Get or create a new VectorType with the same shape as `this` and an
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 3ea9bb41518ee..a2469dc5bee32 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -758,11 +758,11 @@ def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
"bool-like">;
// Type constraint for signless-integer-like types: signless integers, indices,
-// vectors of signless integers, tensors of signless integers.
+// vectors of signless integers or indices, tensors of signless integers.
def SignlessIntegerLike : TypeConstraint<Or<[
AnySignlessInteger.predicate, Index.predicate,
- VectorOf<[AnySignlessInteger]>.predicate,
- TensorOf<[AnySignlessInteger]>.predicate]>,
+ VectorOf<[AnySignlessInteger, Index]>.predicate,
+ TensorOf<[AnySignlessInteger, Index]>.predicate]>,
"signless-integer-like">;
// Type constraint for float-like types: floats, vectors or tensors thereof.
diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
index 87cac054c55ec..0633eb341f11e 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
@@ -144,6 +144,9 @@ class DataLayout {
explicit DataLayout(DataLayoutOpInterface op);
explicit DataLayout(ModuleOp op);
+ /// Returns the layout of the closest parent operation carrying layout info.
+ static DataLayout closest(Operation *op);
+
/// Returns the size of the given type in the current scope.
unsigned getTypeSize(Type t) const;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 0c752c33ff164..9ecee857e2e5c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -163,13 +163,13 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
TransferReadOp xferOp, ArrayRef<Value> operands,
Value dataPtr, Value mask) {
- VectorType fillType = xferOp.getVectorType();
- Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
-
Type vecTy = typeConverter.convertType(xferOp.getVectorType());
if (!vecTy)
return failure();
+ auto adaptor = TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
+ Value fill = rewriter.create<SplatOp>(loc, vecTy, adaptor.padding());
+
unsigned align;
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 1c895f950c289..b74c9a5a823a0 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRVector
MLIRMemRef
MLIRSCF
MLIRLoopAnalysis
+ MLIRDataLayoutInterfaces
MLIRSideEffectInterfaces
MLIRVectorInterfaces
)
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index cff5fcb5649eb..0ad89109b3ada 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2202,12 +2202,15 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
return op->emitOpError(
"requires source to be a memref or ranked tensor type");
auto elementType = shapedType.getElementType();
+ DataLayout dataLayout = DataLayout::closest(op);
if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
// Memref or tensor has vector element type.
- unsigned sourceVecSize = vectorElementType.getElementTypeBitWidth() *
- vectorElementType.getShape().back();
+ unsigned sourceVecSize =
+ dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
+ vectorElementType.getShape().back();
unsigned resultVecSize =
- vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+ dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
+ vectorType.getShape().back();
if (resultVecSize % sourceVecSize != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
@@ -2226,8 +2229,9 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
} else {
// Memref or tensor has scalar element type.
unsigned resultVecSize =
- vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
- if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0)
+ dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
+ vectorType.getShape().back();
+ if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
"multiple of the bitwidth of the source element type");
@@ -3233,9 +3237,10 @@ static LogicalResult verify(BitCastOp op) {
return op.emitOpError("dimension size mismatch at: ") << i;
}
- if (sourceVectorType.getElementTypeBitWidth() *
+ DataLayout dataLayout = DataLayout::closest(op);
+ if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) *
sourceVectorType.getShape().back() !=
- resultVectorType.getElementTypeBitWidth() *
+ dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) *
resultVectorType.getShape().back())
return op.emitOpError(
"source/result bitwidth of the minor 1-D vectors must be equal");
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index ba8ca26b336e4..3bb333cf786d4 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1388,7 +1388,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
VectorType resType = op.getVectorType();
Type eltType = resType.getElementType();
- bool isInt = eltType.isa<IntegerType>();
+ bool isInt = eltType.isa<IntegerType, IndexType>();
Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
vector::CombiningKind kind = op.kind();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8ef7c26741846..ce6d4b3a603ba 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -693,7 +693,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
"expected attribute value to have element type");
if (eltType.isa<FloatType>())
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
- else if (eltType.isa<IntegerType>())
+ else if (eltType.isa<IntegerType, IndexType>())
intVal = values[i].cast<IntegerAttr>().getValue();
else
llvm_unreachable("unexpected element type");
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index f792dfeacbf12..da1453367c7ac 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -392,7 +392,7 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "vector types must have at least one dimension";
if (!isValidElementType(elementType))
- return emitError() << "vector elements must be int or float type";
+ return emitError() << "vector elements must be int/index/float type";
if (any_of(shape, [](int64_t i) { return i <= 0; }))
return emitError() << "vector types must have positive constant sizes";
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 9f5c75a425fb7..3369d61a834bb 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -264,6 +264,19 @@ mlir::DataLayout::DataLayout(ModuleOp op)
#endif
}
+mlir::DataLayout mlir::DataLayout::closest(Operation *op) {
+ // Search for the closest ancestor that is either a module operation or
+ // implements the data layout interface.
+ while (op) {
+ if (auto module = dyn_cast<ModuleOp>(op))
+ return DataLayout(module);
+ if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
+ return DataLayout(iface);
+ op = op->getParentOp();
+ }
+ return DataLayout();
+}
+
void mlir::DataLayout::checkValid() const {
#ifndef NDEBUG
SmallVector<DataLayoutSpecInterface> specs;
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 378b82f3bb1f6..d81cb53060b18 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -472,7 +472,7 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
if (!VectorType::isValidElementType(elementType))
- return emitError(typeLoc, "vector elements must be int or float type"),
+ return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
return VectorType::get(dimensions, elementType);
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 5eca81dcad003..1d12eb937378f 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -248,3 +248,16 @@ func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
%1 = fmaf %arg1, %arg1, %arg1 : vector<4xf32>
std.return
}
+
+// -----
+
+// CHECK-LABEL: func @index_vector(
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xi64>
+func @index_vector(%arg0: vector<4xindex>) {
+ // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
+ %0 = constant dense<[0, 1, 2, 3]> : vector<4xindex>
+ // CHECK: %[[V:.*]] = llvm.add %[[ARG0]], %[[CST]] : vector<4xi64>
+ %1 = addi %arg0, %0 : vector<4xindex>
+ std.return
+}
+
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 9faf7caa3439d..c3ca8ef095e5f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -23,18 +23,40 @@ func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
// -----
+func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8> {
+ %0 = vector.bitcast %input : vector<16xindex> to vector<128xi8>
+ return %0 : vector<128xi8>
+}
+
+// CHECK-LABEL: @bitcast_index_to_i8_vector
+// CHECK-SAME: %[[input:.*]]: vector<16xindex>
+// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[input]] : vector<16xindex> to vector<16xi64>
+// CHECK: llvm.bitcast %[[T0]] : vector<16xi64> to vector<128xi8>
+
+// -----
-func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
+func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2xf32>
return %0 : vector<2xf32>
}
-// CHECK-LABEL: @broadcast_vec1d_from_scalar
+// CHECK-LABEL: @broadcast_vec1d_from_f32
// CHECK-SAME: %[[A:.*]]: f32)
// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>
// -----
+func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<2xindex>
+ return %0 : vector<2xindex>
+}
+// CHECK-LABEL: @broadcast_vec1d_from_index
+// CHECK-SAME: %[[A:.*]]: index)
+// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xindex>
+// CHECK: return %[[T0]] : vector<2xindex>
+
+// -----
+
func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<2x3xf32>
return %0 : vector<2x3xf32>
@@ -83,6 +105,22 @@ func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
// -----
+func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x2xindex> {
+ %0 = vector.broadcast %arg0 : vector<2xindex> to vector<3x2xindex>
+ return %0 : vector<3x2xindex>
+}
+// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d(
+// CHECK-SAME: %[[A:.*]]: vector<2xindex>)
+// CHECK: %[[T0:.*]] = constant dense<0> : vector<3x2xindex>
+// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[A]] : vector<2xindex> to vector<2xi64>
+// CHECK: %[[T2:.*]] = llvm.mlir.cast %[[T0]] : vector<3x2xindex> to !llvm.array<3 x vector<2xi64>>
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<2xi64>>
+
+// CHECK: %[[T4:.*]] = llvm.mlir.cast %{{.*}} : !llvm.array<3 x vector<2xi64>> to vector<3x2xindex>
+// CHECK: return %[[T4]] : vector<3x2xindex>
+
+// -----
+
func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32>
return %0 : vector<4x3x2xf32>
@@ -264,6 +302,26 @@ func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32
// -----
+func @outerproduct_index(%arg0: vector<2xindex>, %arg1: vector<3xindex>) -> vector<2x3xindex> {
+ %2 = vector.outerproduct %arg0, %arg1 : vector<2xindex>, vector<3xindex>
+ return %2 : vector<2x3xindex>
+}
+// CHECK-LABEL: @outerproduct_index(
+// CHECK-SAME: %[[A:.*]]: vector<2xindex>,
+// CHECK-SAME: %[[B:.*]]: vector<3xindex>)
+// CHECK: %[[T0:.*]] = constant dense<0> : vector<2x3xindex>
+// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[A]] : vector<2xindex> to vector<2xi64>
+// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]]{{\[}}%[[T2]] : i64] : vector<2xi64>
+// CHECK: %[[T4:.*]] = llvm.mlir.cast %[[T3]] : i64 to index
+// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xindex>
+// CHECK: %[[T6:.*]] = muli %[[T5]], %[[B]] : vector<3xindex>
+// CHECK: %[[T7:.*]] = llvm.mlir.cast %[[T6]] : vector<3xindex> to vector<3xi64>
+// CHECK: %[[T8:.*]] = llvm.mlir.cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>>
+// CHECK: %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<3xi64>>
+
+// -----
+
func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
%2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
return %2 : vector<2x3xf32>
@@ -305,6 +363,21 @@ func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2x
// -----
+func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex>) -> vector<2xindex> {
+ %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xindex>, vector<2xindex>
+ return %1 : vector<2xindex>
+}
+// CHECK-LABEL: @shuffle_1D_index_direct(
+// CHECK-SAME: %[[A:.*]]: vector<2xindex>,
+// CHECK-SAME: %[[B:.*]]: vector<2xindex>)
+// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<2xindex> to vector<2xi64>
+// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[B]] : vector<2xindex> to vector<2xi64>
+// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [0, 1] : vector<2xi64>, vector<2xi64>
+// CHECK: %[[T3:.*]] = llvm.mlir.cast %[[T2]] : vector<2xi64> to vector<2xindex>
+// CHECK: return %[[T3]] : vector<2xindex>
+
+// -----
+
func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
%1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
return %1 : vector<5xf32>
@@ -382,6 +455,20 @@ func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
// -----
+func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index {
+ %0 = vector.extract %arg0[15]: vector<16xindex>
+ return %0 : index
+}
+// CHECK-LABEL: @extract_index_element_from_vec_1d(
+// CHECK-SAME: %[[A:.*]]: vector<16xindex>)
+// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<16xindex> to vector<16xi64>
+// CHECK: %[[T1:.*]] = llvm.mlir.constant(15 : i64) : i64
+// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<16xi64>
+// CHECK: %[[T3:.*]] = llvm.mlir.cast %[[T2]] : i64 to index
+// CHECK: return %[[T3]] : index
+
+// -----
+
func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> {
%0 = vector.extract %arg0[0]: vector<4x3x16xf32>
return %0 : vector<3x16xf32>
@@ -439,6 +526,22 @@ func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf
// -----
+func @insert_index_element_into_vec_1d(%arg0: index, %arg1: vector<4xindex>) -> vector<4xindex> {
+ %0 = vector.insert %arg0, %arg1[3] : index into vector<4xindex>
+ return %0 : vector<4xindex>
+}
+// CHECK-LABEL: @insert_index_element_into_vec_1d(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: vector<4xindex>)
+// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : index to i64
+// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[B]] : vector<4xindex> to vector<4xi64>
+// CHECK: %[[T3:.*]] = llvm.mlir.constant(3 : i64) : i64
+// CHECK: %[[T4:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T3]] : i64] : vector<4xi64>
+// CHECK: %[[T5:.*]] = llvm.mlir.cast %[[T4]] : vector<4xi64> to vector<4xindex>
+// CHECK: return %[[T5]] : vector<4xindex>
+
+// -----
+
func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> {
%0 = vector.insert %arg0, %arg1[3] : vector<8x16xf32> into vector<4x8x16xf32>
return %0 : vector<4x8x16xf32>
@@ -489,6 +592,18 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
// -----
+func @vector_index_type_cast(%arg0: memref<8x8x8xindex>) -> memref<vector<8x8x8xindex>> {
+ %0 = vector.type_cast %arg0: memref<8x8x8xindex> to memref<vector<8x8x8xindex>>
+ return %0 : memref<vector<8x8x8xindex>>
+}
+// CHECK-LABEL: @vector_index_type_cast(
+// CHECK-SAME: %[[A:.*]]: memref<8x8x8xindex>)
+// CHECK: %{{.*}} = llvm.mlir.cast %[[A]] : memref<8x8x8xindex> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<3 x i64>, array<3 x i64>)>
+
+// CHECK: %{{.*}} = llvm.mlir.cast %{{.*}} : !llvm.struct<(ptr<array<8 x array<8 x vector<8xi64>>>>, ptr<array<8 x array<8 x vector<8xi64>>>>, i64)> to memref<vector<8x8x8xindex>>
+
+// -----
+
func @vector_type_cast_non_zero_addrspace(%arg0: memref<8x8x8xf32, 3>) -> memref<vector<8x8x8xf32>, 3> {
%0 = vector.type_cast %arg0: memref<8x8x8xf32, 3> to memref<vector<8x8x8xf32>, 3>
return %0 : memref<vector<8x8x8xf32>, 3>
@@ -723,6 +838,20 @@ func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
// -----
+func @extract_strided_index_slice1(%arg0: vector<4xindex>) -> vector<2xindex> {
+ %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xindex> to vector<2xindex>
+ return %0 : vector<2xindex>
+}
+// CHECK-LABEL: @extract_strided_index_slice1(
+// CHECK-SAME: %[[A:.*]]: vector<4xindex>)
+// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<4xindex> to vector<4xi64>
+// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[A]] : vector<4xindex> to vector<4xi64>
+// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [2, 3] : vector<4xi64>, vector<4xi64>
+// CHECK: %[[T3:.*]] = llvm.mlir.cast %[[T2]] : vector<2xi64> to vector<2xindex>
+// CHECK: return %[[T3]] : vector<2xindex>
+
+// -----
+
func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
return %0 : vector<2x8xf32>
@@ -772,6 +901,16 @@ func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vecto
// -----
+func @insert_strided_index_slice1(%b: vector<4x4xindex>, %c: vector<4x4x4xindex>) -> vector<4x4x4xindex> {
+ %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xindex> into vector<4x4x4xindex>
+ return %0 : vector<4x4x4xindex>
+}
+// CHECK-LABEL: @insert_strided_index_slice1(
+// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
+// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
+
+// -----
+
func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<4x4xf32> {
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
return %0 : vector<4x4xf32>
@@ -1019,6 +1158,18 @@ func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
// CHECK: return %[[V]] : i64
+// -----
+
+func @reduce_index(%arg0: vector<16xindex>) -> index {
+ %0 = vector.reduction "add", %arg0 : vector<16xindex> into index
+ return %0 : index
+}
+// CHECK-LABEL: @reduce_index(
+// CHECK-SAME: %[[A:.*]]: vector<16xindex>)
+// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<16xindex> to vector<16xi64>
+// CHECK: %[[T1:.*]] = "llvm.intr.vector.reduce.add"(%[[T0]])
+// CHECK: %[[T2:.*]] = llvm.mlir.cast %[[T1]] : i64 to index
+// CHECK: return %[[T2]] : index
// 4x16 16x3 4x3
// -----
@@ -1036,6 +1187,19 @@ func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
// -----
+func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> {
+ %C = vector.matrix_multiply %A, %B
+ { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
+ (vector<64xindex>, vector<48xindex>) -> vector<12xindex>
+ return %C: vector<12xindex>
+}
+// CHECK-LABEL: @matrix_ops_index
+// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
+// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
+// CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64>
+
+// -----
+
func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
%f7 = constant 7.0: f32
%f = vector.transfer_read %A[%base], %f7
@@ -1108,6 +1272,29 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// -----
+func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xindex> {
+ %f7 = constant 7: index
+ %f = vector.transfer_read %A[%base], %f7
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xindex>, vector<17xindex>
+ vector.transfer_write %f, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<17xindex>, memref<?xindex>
+ return %f: vector<17xindex>
+}
+// CHECK-LABEL: func @transfer_read_index_1d
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
+// CHECK: %[[C7:.*]] = constant 7
+// CHECK: %{{.*}} = llvm.mlir.cast %[[C7]] : index to i64
+
+// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
+// CHECK-SAME: (!llvm.ptr<vector<17xi64>>, vector<17xi1>, vector<17xi64>) -> vector<17xi64>
+
+// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} :
+// CHECK-SAME: vector<17xi64>, vector<17xi1> into !llvm.ptr<vector<17xi64>>
+
+// -----
+
func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
%f7 = constant 7.0: f32
%f = vector.transfer_read %A[%base0, %base1], %f7
@@ -1258,6 +1445,22 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
// -----
+func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> {
+ %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
+ : vector<16xindex> -> vector<16xindex>
+ return %0 : vector<16xindex>
+}
+// CHECK-LABEL: func @flat_transpose_index
+// CHECK-SAME: %[[A:.*]]: vector<16xindex>
+// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<16xindex> to vector<16xi64>
+// CHECK: %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]]
+// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
+// CHECK-SAME: vector<16xi64> into vector<16xi64>
+// CHECK: %[[T2:.*]] = llvm.mlir.cast %[[T1]] : vector<16xi64> to vector<16xindex>
+// CHECK: return %[[T2]] : vector<16xindex>
+
+// -----
+
func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
return %0 : vector<8xf32>
@@ -1271,6 +1474,19 @@ func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> v
// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
// CHECK: llvm.load %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
+// -----
+
+func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) -> vector<8xindex> {
+ %0 = vector.load %memref[%i, %j] : memref<200x100xindex>, vector<8xindex>
+ return %0 : vector<8xindex>
+}
+// CHECK-LABEL: func @vector_load_op_index
+// CHECK: %[[T0:.*]] = llvm.load %{{.*}} {alignment = 8 : i64} : !llvm.ptr<vector<8xi64>>
+// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[T0]] : vector<8xi64> to vector<8xindex>
+// CHECK: return %[[T1]] : vector<8xindex>
+
+// -----
+
func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
%val = constant dense<11.0> : vector<4xf32>
vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
@@ -1285,6 +1501,18 @@ func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 4 : i64} : !llvm.ptr<vector<4xf32>>
+// -----
+
+func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) {
+ %val = constant dense<11> : vector<4xindex>
+ vector.store %val, %memref[%i, %j] : memref<200x100xindex>, vector<4xindex>
+ return
+}
+// CHECK-LABEL: func @vector_store_op_index
+// CHECK: llvm.store %{{.*}}, %{{.*}} {alignment = 8 : i64} : !llvm.ptr<vector<4xi64>>
+
+// -----
+
func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0: index
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -1301,6 +1529,16 @@ func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<1
// -----
+func @masked_load_op_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> {
+ %c0 = constant 0: index
+ %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex> into vector<16xindex>
+ return %0 : vector<16xindex>
+}
+// CHECK-LABEL: func @masked_load_op_index
+// CHECK: %{{.*}} = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr<vector<16xi64>>, vector<16xi1>, vector<16xi64>) -> vector<16xi64>
+
+// -----
+
func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
%c0 = constant 0: index
vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
@@ -1316,6 +1554,16 @@ func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<
// -----
+func @masked_store_op_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) {
+ %c0 = constant 0: index
+ vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xi1>, vector<16xindex>
+ return
+}
+// CHECK-LABEL: func @masked_store_op_index
+// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xi64>, vector<16xi1> into !llvm.ptr<vector<16xi64>>
+
+// -----
+
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
%0 = constant 0: index
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
@@ -1329,6 +1577,16 @@ func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>,
// -----
+func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> {
+ %0 = constant 0: index
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xi32>, vector<3xi1>, vector<3xindex> into vector<3xindex>
+ return %1 : vector<3xindex>
+}
+// CHECK-LABEL: func @gather_op_index
+// CHECK: %{{.*}} = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr<i64>>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
+
+// -----
+
func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
%0 = constant 3 : index
%1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
@@ -1355,6 +1613,16 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
// -----
+func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) {
+ %0 = constant 0: index
+ vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xindex>, vector<3xi32>, vector<3xi1>, vector<3xindex>
+ return
+}
+// CHECK-LABEL: func @scatter_op_index
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr<i64>>
+
+// -----
+
func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) {
%0 = constant 3 : index
vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>
@@ -1383,6 +1651,16 @@ func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<1
// -----
+func @expand_load_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>) -> vector<11xindex> {
+ %c0 = constant 0: index
+ %0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<11xi1>, vector<11xindex> into vector<11xindex>
+ return %0 : vector<11xindex>
+}
+// CHECK-LABEL: func @expand_load_op_index
+// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i64>, vector<11xi1>, vector<11xi64>) -> vector<11xi64>
+
+// -----
+
func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
%c0 = constant 0: index
vector.compressstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
@@ -1394,3 +1672,13 @@ func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vecto
// CHECK: %[[C:.*]] = llvm.mlir.cast %[[CO]] : index to i64
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (vector<11xf32>, !llvm.ptr<f32>, vector<11xi1>) -> ()
+
+// -----
+
+func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>) {
+ %c0 = constant 0: index
+ vector.compressstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<11xi1>, vector<11xindex>
+ return
+}
+// CHECK-LABEL: func @compress_store_op_index
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) : (vector<11xi64>, !llvm.ptr<i64>, vector<11xi1>) -> ()
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 43bef97f799e3..fd5c0c8ac67e3 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -3,14 +3,18 @@
// CHECK-LABEL: func @vector_transfer_ops(
func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%arg1 : memref<?x?xvector<4x3xf32>>,
- %arg2 : memref<?x?xvector<4x3xi32>>) {
+ %arg2 : memref<?x?xvector<4x3xi32>>,
+ %arg3 : memref<?x?xvector<4x3xindex>>) {
// CHECK: %[[C3:.*]] = constant 3 : index
%c3 = constant 3 : index
%cst = constant 3.0 : f32
%f0 = constant 0.0 : f32
%c0 = constant 0 : i32
+ %i0 = constant 0 : index
+
%vf0 = splat %f0 : vector<4x3xf32>
%v0 = splat %c0 : vector<4x3xi32>
+ %vi0 = splat %i0 : vector<4x3xindex>
%m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
//
@@ -28,8 +32,10 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
%6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xindex>>, vector<5x48xi8>
+ %7 = vector.transfer_read %arg3[%c3, %c3], %vi0 : memref<?x?xvector<4x3xindex>>, vector<5x48xi8>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<5xf32>
- %7 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>
+ %8 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref<?x?xf32>, vector<5xf32>
// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -41,8 +47,11 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x48xi8>, memref<?x?xvector<4x3xindex>>
+ vector.transfer_write %7, %arg3[%c3, %c3] : vector<5x48xi8>, memref<?x?xvector<4x3xindex>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : vector<5xf32>, memref<?x?xf32>
- vector.transfer_write %7, %arg0[%c3, %c3], %m : vector<5xf32>, memref<?x?xf32>
+ vector.transfer_write %8, %arg0[%c3, %c3], %m : vector<5xf32>, memref<?x?xf32>
+
return
}
@@ -50,16 +59,21 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
// CHECK-LABEL: func @vector_transfer_ops_tensor(
func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%arg1 : tensor<?x?xvector<4x3xf32>>,
- %arg2 : tensor<?x?xvector<4x3xi32>>) ->
+ %arg2 : tensor<?x?xvector<4x3xi32>>,
+ %arg3 : tensor<?x?xvector<4x3xindex>>) ->
(tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
- tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>){
+ tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>,
+ tensor<?x?xvector<4x3xindex>>){
// CHECK: %[[C3:.*]] = constant 3 : index
%c3 = constant 3 : index
%cst = constant 3.0 : f32
%f0 = constant 0.0 : f32
%c0 = constant 0 : i32
+ %i0 = constant 0 : index
+
%vf0 = splat %f0 : vector<4x3xf32>
%v0 = splat %c0 : vector<4x3xi32>
+ %vi0 = splat %i0 : vector<4x3xindex>
//
// CHECK: vector.transfer_read
@@ -76,22 +90,27 @@ func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
%6 = vector.transfer_read %arg2[%c3, %c3], %v0 : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xindex>>, vector<5x48xi8>
+ %7 = vector.transfer_read %arg3[%c3, %c3], %vi0 : tensor<?x?xvector<4x3xindex>>, vector<5x48xi8>
// CHECK: vector.transfer_write
- %7 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor<?x?xf32>
+ %8 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write
- %8 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor<?x?xf32>
+ %9 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor<?x?xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
- %9 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+ %10 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
- %10 = vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+ %11 = vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
- %11 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
+ %12 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x48xi8>, tensor<?x?xvector<4x3xindex>>
+ %13 = vector.transfer_write %7, %arg3[%c3, %c3] : vector<5x48xi8>, tensor<?x?xvector<4x3xindex>>
- return %7, %8, %9, %10, %11 :
+ return %8, %9, %10, %11, %12, %13 :
tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
- tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>
+ tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>,
+ tensor<?x?xvector<4x3xindex>>
}
// CHECK-LABEL: @vector_broadcast
@@ -381,8 +400,9 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
// CHECK-LABEL: @bitcast
func @bitcast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x1xi32>,
- %arg2 : vector<16x1x8xi8>)
- -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>) {
+ %arg2 : vector<16x1x8xi8>,
+ %arg3 : vector<8x2x1xindex>)
+ -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>) {
// CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
%0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16>
@@ -402,7 +422,13 @@ func @bitcast(%arg0 : vector<5x1x3x2xf32>,
// CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x4xi16>
%5 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x4xi16>
- return %0, %1, %2, %3, %4, %5 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>
+ // CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x1xindex>
+ %6 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x1xindex>
+
+ // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x2x1xindex> to vector<8x2x2xf32>
+ %7 = vector.bitcast %arg3 : vector<8x2x1xindex> to vector<8x2x2xf32>
+
+ return %0, %1, %2, %3, %4, %5, %6, %7 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>
}
// CHECK-LABEL: @vector_fma
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 3747039944b53..a3240704f4310 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -842,7 +842,7 @@ func @invalid_splat(%v : f32) {
func @invalid_splat(%v : vector<8xf32>) {
%w = splat %v : tensor<8xvector<8xf32>>
- // expected-error at -1 {{must be integer or float type}}
+ // expected-error at -1 {{must be integer/index/float type}}
return
}
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 39e72d2cb8daa..220f46d5b344b 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -23,10 +23,6 @@ func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{i
func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
-// -----
-
-func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
-
// -----
// Test no map in memref type.
func @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}}
@@ -387,7 +383,7 @@ func @succ_arg_type_mismatch() {
// Test no nested vector.
func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
-// expected-error at -1 {{vector elements must be int or float type}}
+// expected-error at -1 {{vector elements must be int/index/float type}}
// -----
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-index-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-index-vectors.mlir
new file mode 100644
index 0000000000000..2ae8a92c61f9f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-index-vectors.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @entry() {
+ %c0 = constant dense<[0, 1, 2, 3]>: vector<4xindex>
+ %c1 = constant dense<[0, 1]>: vector<2xindex>
+ %c2 = constant 2 : index
+
+ %v1 = vector.broadcast %c0 : vector<4xindex> to vector<2x4xindex>
+ %v2 = vector.broadcast %c1 : vector<2xindex> to vector<4x2xindex>
+ %v3 = vector.transpose %v2, [1, 0] : vector<4x2xindex> to vector<2x4xindex>
+ %v4 = vector.broadcast %c2 : index to vector<2x4xindex>
+
+ %v5 = addi %v1, %v3 : vector<2x4xindex>
+
+ vector.print %v1 : vector<2x4xindex>
+ vector.print %v3 : vector<2x4xindex>
+ vector.print %v4 : vector<2x4xindex>
+ vector.print %v5 : vector<2x4xindex>
+
+ //
+ // created index vectors:
+ //
+ // CHECK: ( ( 0, 1, 2, 3 ), ( 0, 1, 2, 3 ) )
+ // CHECK: ( ( 0, 0, 0, 0 ), ( 1, 1, 1, 1 ) )
+ // CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) )
+ // CHECK: ( ( 0, 1, 2, 3 ), ( 1, 2, 3, 4 ) )
+
+ return
+}