[Mlir-commits] [mlir] 26c8f90 - [mlir][vector] Extend Transfer read/write ops to support tensor types.
Thomas Raoux
llvmlistbot at llvm.org
Mon Dec 21 08:55:41 PST 2020
Author: Thomas Raoux
Date: 2020-12-21T08:55:04-08:00
New Revision: 26c8f9081b6b1ca9358ac2ca38e8e603fb6f7d64
URL: https://github.com/llvm/llvm-project/commit/26c8f9081b6b1ca9358ac2ca38e8e603fb6f7d64
DIFF: https://github.com/llvm/llvm-project/commit/26c8f9081b6b1ca9358ac2ca38e8e603fb6f7d64.diff
LOG: [mlir][vector] Extend Transfer read/write ops to support tensor types.
Transfer ops can now work on both buffers (memrefs) and tensors. Lowering of
the tensor case is not supported yet.
Differential Revision: https://reviews.llvm.org/D93500
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/include/mlir/Dialect/Vector/VectorUtils.h
mlir/include/mlir/Interfaces/VectorInterfaces.td
mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Dialect/Vector/VectorUtils.cpp
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 95964665ced6..5540a56a4043 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -126,7 +126,7 @@ namespace impl {
/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
-AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType);
} // namespace impl
} // end namespace vector
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index de77e3b03483..13aba2076ee9 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1056,7 +1056,7 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
- Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
+ Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices,
AffineMapAttr:$permutation_map, AnyType:$padding,
OptionalAttr<BoolArrayAttr>:$masked)>,
Results<(outs AnyVector:$vector)> {
@@ -1065,15 +1065,16 @@ def Vector_TransferReadOp :
let description = [{
The `vector.transfer_read` op performs a read from a slice within a
- [MemRef](../LangRef.md#memref-type) supplied as its first operand
- into a [vector](../LangRef.md#vector-type) of the same base elemental type.
+ [MemRef](../LangRef.md#memref-type) or a Ranked
+ [Tensor](../LangRef.md#tensor-type) supplied as its first operand into a
+ [vector](../LangRef.md#vector-type) of the same base elemental type.
- A memref operand with vector element type, must have its vector element
- type match a suffix (shape and element type) of the vector (e.g.
+ A memref/tensor operand with vector element type, must have its vector
+ element type match a suffix (shape and element type) of the vector (e.g.
memref<3x2x6x4x3xf32>, vector<1x1x4x3xf32>).
- The slice is further defined by a full-rank index within the MemRef,
- supplied as the operands `2 .. 1 + rank(memref)`.
+ The slice is further defined by a full-rank index within the MemRef/Tensor,
+ supplied as the operands `2 .. 1 + rank(memref/tensor)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1084,8 +1085,9 @@ def Vector_TransferReadOp :
The size of the slice is specified by the size of the vector, given as the
return type.
- An `ssa-value` of the same elemental type as the MemRef is provided as the
- last operand to specify padding in the case of out-of-bounds accesses.
+ An `ssa-value` of the same elemental type as the MemRef/Tensor is provided
+ as the last operand to specify padding in the case of out-of-bounds
+ accesses.
An optional boolean array attribute is provided to specify which dimensions
of the transfer need masking. When a dimension is specified as not requiring
@@ -1196,17 +1198,22 @@ def Vector_TransferReadOp :
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0
{permutation_map = (d0, d1)->(d0, d1)}
: memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+
+ // Read from a tensor with vector element type.
+ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0
+ {permutation_map = (d0, d1)->(d0, d1)}
+ : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
```
}];
let builders = [
// Builder that sets padding to zero.
- OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref,
+ OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, "AffineMap":$permutationMap,
CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
// Builder that sets permutation map (resp. padding) to
// 'getMinorIdentityMap' (resp. zero).
- OpBuilderDAG<(ins "VectorType":$vector, "Value":$memref,
+ OpBuilderDAG<(ins "VectorType":$vector, "Value":$source,
"ValueRange":$indices, CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>
];
@@ -1217,26 +1224,29 @@ def Vector_TransferWriteOp :
Vector_Op<"transfer_write", [
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
- ]>,
- Arguments<(ins AnyVector:$vector, AnyMemRef:$memref,
+ ]>,
+ Arguments<(ins AnyVector:$vector, AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
- OptionalAttr<BoolArrayAttr>:$masked)> {
+ OptionalAttr<BoolArrayAttr>:$masked)>,
+ Results<(outs Optional<AnyRankedTensor>:$result)> {
let summary = "The vector.transfer_write op writes a supervector to memory.";
let description = [{
The `vector.transfer_write` op performs a write from a
[vector](../LangRef.md#vector-type), supplied as its first operand, into a
- slice within a [MemRef](../LangRef.md#memref-type) of the same base
- elemental type, supplied as its second operand.
+ slice within a [MemRef](../LangRef.md#memref-type) or a Ranked
+ [Tensor](../LangRef.md#tensor-type) of the same base elemental type,
+ supplied as its second operand.
- A vector memref operand must have its vector element type match a suffix
- (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
- vector<1x1x4x3xf32>).
+ A vector memref/tensor operand must have its vector element type match a
+ suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
+ vector<1x1x4x3xf32>). If the operand is a tensor, the operation returns a
+ new tensor of the same type.
- The slice is further defined by a full-rank index within the MemRef,
- supplied as the operands `3 .. 2 + rank(memref)`.
+ The slice is further defined by a full-rank index within the MemRef/Tensor,
+ supplied as the operands `3 .. 2 + rank(memref/tensor)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
@@ -1280,15 +1290,24 @@ def Vector_TransferWriteOp :
vector.transfer_write %4, %arg1[%c3, %c3]
{permutation_map = (d0, d1)->(d0, d1)}
: vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+
+ // return a tensor where the vector is inserted into the source tensor.
+ %5 = vector.transfer_write %4, %arg1[%c3, %c3]
+ {permutation_map = (d0, d1)->(d0, d1)}
+ : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
```
}];
let builders = [
// Builder that sets permutation map to 'getMinorIdentityMap'.
- OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices,
+ OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
CArg<"ArrayRef<bool>", "{}">:$maybeMasked)>,
- OpBuilderDAG<(ins "Value":$vector, "Value":$memref, "ValueRange":$indices,
+ OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
"AffineMap":$permutationMap)>,
+ OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+ "AffineMapAttr":$permutationMap, "ArrayAttr":$masked)>,
+ OpBuilderDAG<(ins "Value":$vector, "Value":$source, "ValueRange":$indices,
+ "AffineMap":$permutationMap, "ArrayAttr":$masked)>,
];
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index f70fba819b66..a06bc8cf6562 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -20,9 +20,9 @@ class AffineApplyOp;
class AffineForOp;
class AffineMap;
class Location;
-class MemRefType;
class OpBuilder;
class Operation;
+class ShapedType;
class Value;
class VectorType;
class VectorTransferOpInterface;
@@ -157,7 +157,7 @@ makePermutationMap(Operation *op, ArrayRef<Value> indices,
/// Build the default minor identity map suitable for a vector transfer. This
/// also handles the case memref<... x vector<...>> -> vector<...> in which the
/// rank of the identity map must take the vector element type into account.
-AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
+AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType);
/// Return true if we can prove that the transfer operations access disjoint
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 73332afd8825..3f60de5831c9 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -47,7 +47,7 @@ def VectorUnrollOpInterface : OpInterface<"VectorUnrollOpInterface"> {
def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
let description = [{
- Encodes properties of an operation on vectors that can be unrolled.
+ Encodes properties of a transfer read or write operation.
}];
let cppNamespace = "::mlir";
@@ -83,11 +83,11 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
}]
>,
InterfaceMethod<
- /*desc=*/"Return the memref operand.",
+ /*desc=*/"Return the memref or ranked tensor operand.",
/*retTy=*/"Value",
- /*methodName=*/"memref",
+ /*methodName=*/"source",
/*args=*/(ins),
- /*methodBody=*/"return $_op.memref();"
+ /*methodBody=*/"return $_op.source();"
/*defaultImplementation=*/
>,
InterfaceMethod<
@@ -123,13 +123,13 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*defaultImplementation=*/
>,
InterfaceMethod<
- /*desc=*/"Return the MemRefType.",
- /*retTy=*/"MemRefType",
- /*methodName=*/"getMemRefType",
+ /*desc=*/"Return the ShapedType.",
+ /*retTy=*/"ShapedType",
+ /*methodName=*/"getShapedType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/
- "return $_op.memref().getType().template cast<MemRefType>();"
+ "return $_op.source().getType().template cast<ShapedType>();"
>,
InterfaceMethod<
/*desc=*/"Return the VectorType.",
@@ -152,14 +152,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
"return $_op.permutation_map().getNumResults();"
>,
InterfaceMethod<
- /*desc=*/[{ Return the number of leading memref dimensions that do not
+ /*desc=*/[{ Return the number of leading shaped dimensions that do not
participate in the permutation map.}],
/*retTy=*/"unsigned",
- /*methodName=*/"getLeadingMemRefRank",
+ /*methodName=*/"getLeadingShapedRank",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/
- "return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
+ "return $_op.getShapedType().getRank() - $_op.getTransferRank();"
>,
InterfaceMethod<
/*desc=*/[{ Returns true if at least one of the dimensions is masked.}],
@@ -178,8 +178,8 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*desc=*/[{
Helper function to account for the fact that `permutationMap` results and
`op.indices` sizes may not match and may not be aligned. The first
- `getLeadingMemRefRank()` indices may just be indexed and not transferred
- from/into the vector.
+ `getLeadingShapedRank()` indices may just be indexed and not
+ transferred from/into the vector.
For example:
```
vector.transfer %0[%i, %j, %k, %c0] :
@@ -195,7 +195,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
for (int64_t resultIdx = 0,
- indicesIdx = $_op.getLeadingMemRefRank(),
+ indicesIdx = $_op.getLeadingShapedRank(),
eResult = $_op.getTransferRank();
resultIdx < eResult;
++resultIdx, ++indicesIdx)
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index ea483aa6abae..10d727df701b 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -22,6 +22,17 @@
using namespace mlir;
+/// Helpers to access the memref operand for each op.
+static Value getMemRefOperand(LoadOp op) { return op.memref(); }
+
+static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
+
+static Value getMemRefOperand(StoreOp op) { return op.memref(); }
+
+static Value getMemRefOperand(vector::TransferWriteOp op) {
+ return op.source();
+}
+
namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
@@ -141,7 +152,7 @@ template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const {
- auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
+ auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp<SubViewOp>();
if (!subViewOp) {
return failure();
}
@@ -162,7 +173,8 @@ template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const {
- auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
+ auto subViewOp =
+ getMemRefOperand(storeOp).template getDefiningOp<SubViewOp>();
if (!subViewOp) {
return failure();
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ebe07366f6ec..a982b90e0e93 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -141,12 +141,10 @@ static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}
-// Helper that returns data layout alignment of an operation with memref.
-template <typename T>
-LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
- unsigned &align) {
- Type elementTy =
- typeConverter.convertType(op.getMemRefType().getElementType());
+// Helper that returns data layout alignment of a memref.
+LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
+ MemRefType memrefType, unsigned &align) {
+ Type elementTy = typeConverter.convertType(memrefType.getElementType());
if (!elementTy)
return failure();
@@ -222,7 +220,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
TransferReadOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
unsigned align;
- if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+ if (failed(getMemRefAlignment(
+ typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
return success();
@@ -243,7 +242,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
return failure();
unsigned align;
- if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+ if (failed(getMemRefAlignment(
+ typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
@@ -258,7 +258,8 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
TransferWriteOp xferOp,
ArrayRef<Value> operands, Value dataPtr) {
unsigned align;
- if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+ if (failed(getMemRefAlignment(
+ typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@@ -272,7 +273,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
TransferWriteOp xferOp, ArrayRef<Value> operands,
Value dataPtr, Value mask) {
unsigned align;
- if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
+ if (failed(getMemRefAlignment(
+ typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
return failure();
auto adaptor = TransferWriteOpAdaptor(operands);
@@ -345,7 +347,8 @@ class VectorMaskedLoadOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), load, align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
+ align)))
return failure();
auto vtype = typeConverter->convertType(load.getResultVectorType());
@@ -375,7 +378,8 @@ class VectorMaskedStoreOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), store, align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
+ align)))
return failure();
auto vtype = typeConverter->convertType(store.getValueVectorType());
@@ -405,7 +409,8 @@ class VectorGatherOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), gather, align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
+ align)))
return failure();
// Get index ptrs.
@@ -438,7 +443,8 @@ class VectorScatterOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
+ align)))
return failure();
// Get index ptrs.
@@ -1182,8 +1188,11 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
xferOp.getVectorType().getRank(),
xferOp->getContext()))
return failure();
+ auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+ if (!memRefType)
+ return failure();
// Only contiguous source tensors supported atm.
- auto strides = computeContiguousStrides(xferOp.getMemRefType());
+ auto strides = computeContiguousStrides(memRefType);
if (!strides)
return failure();
@@ -1192,10 +1201,9 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
};
Location loc = xferOp->getLoc();
- MemRefType memRefType = xferOp.getMemRefType();
if (auto memrefVectorElementType =
- memRefType.getElementType().dyn_cast<VectorType>()) {
+ memRefType.getElementType().template dyn_cast<VectorType>()) {
// Memref has vector element type.
if (memrefVectorElementType.getElementType() !=
xferOp.getVectorType().getElementType())
@@ -1222,7 +1230,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
// address space 0.
// TODO: support alignment when possible.
Value dataPtr = this->getStridedElementPtr(
- loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
+ loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr;
@@ -1248,7 +1256,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
unsigned vecWidth = vecTy.getVectorNumElements();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
- Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
+ Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
Value mask = buildVectorComparison(
rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index e5474abfd3e3..973b116ef498 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -89,7 +89,9 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
return failure();
// Obtain dataPtr and elementType from the memref.
- MemRefType memRefType = xferOp.getMemRefType();
+ auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+ if (!memRefType)
+ return failure();
// MUBUF instruction operate only on addresspace 0(unified) or 1(global)
// In case of 3(LDS): fall back to vector->llvm pass
// In case of 5(VGPR): wrong
@@ -101,7 +103,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
// indices, so no need to calculate offset size in bytes again in
// the MUBUF instruction.
Value dataPtr = this->getStridedElementPtr(
- loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter);
+ loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index c05e9da2c949..b0f1b46b2459 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -107,7 +107,7 @@ class NDTransferOpHelper {
// TODO: when we go to k > 1-D vectors adapt minorRank.
minorRank = 1;
majorRank = vectorType.getRank() - minorRank;
- leadingRank = xferOp.getLeadingMemRefRank();
+ leadingRank = xferOp.getLeadingShapedRank();
majorVectorType =
VectorType::get(vectorType.getShape().take_front(majorRank),
vectorType.getElementType());
@@ -115,9 +115,9 @@ class NDTransferOpHelper {
VectorType::get(vectorType.getShape().take_back(minorRank),
vectorType.getElementType());
/// Memref of minor vector type is used for individual transfers.
- memRefMinorVectorType =
- MemRefType::get(majorVectorType.getShape(), minorVectorType, {},
- xferOp.getMemRefType().getMemorySpace());
+ memRefMinorVectorType = MemRefType::get(
+ majorVectorType.getShape(), minorVectorType, {},
+ xferOp.getShapedType().template cast<MemRefType>().getMemorySpace());
}
LogicalResult doReplace();
@@ -155,7 +155,7 @@ void NDTransferOpHelper<ConcreteOp>::emitLoops(
const MemRefBoundsCapture &)>
loopBodyBuilder) {
/// Loop nest operates on the major dimensions
- MemRefBoundsCapture memrefBoundsCapture(xferOp.memref());
+ MemRefBoundsCapture memrefBoundsCapture(xferOp.source());
if (options.unroll) {
auto shape = majorVectorType.getShape();
@@ -272,9 +272,9 @@ LogicalResult NDTransferOpHelper<TransferReadOp>::doReplace() {
indexing.append(leadingOffsets.begin(), leadingOffsets.end());
indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end());
indexing.append(minorOffsets.begin(), minorOffsets.end());
- Value memref = xferOp.memref();
+ Value memref = xferOp.source();
auto map =
- getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
+ getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
ArrayAttr masked;
if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
@@ -379,13 +379,13 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
else
result = std_load(alloc, majorIvs);
auto map =
- getTransferMinorIdentityMap(xferOp.getMemRefType(), minorVectorType);
+ getTransferMinorIdentityMap(xferOp.getShapedType(), minorVectorType);
ArrayAttr masked;
if (!xferOp.isMaskedDim(xferOp.getVectorType().getRank() - 1)) {
OpBuilder &b = ScopedContext::getBuilderRef();
masked = b.getBoolArrayAttr({false});
}
- vector_transfer_write(result, xferOp.memref(), indexing,
+ vector_transfer_write(result, xferOp.source(), indexing,
AffineMapAttr::get(map), masked);
};
@@ -422,7 +422,7 @@ template <typename TransferOpTy>
static int computeCoalescedIndex(TransferOpTy transfer) {
// rank of the remote memory access, coalescing behavior occurs on the
// innermost memory dimension.
- auto remoteRank = transfer.getMemRefType().getRank();
+ auto remoteRank = transfer.getShapedType().getRank();
// Iterate over the results expressions of the permutation map to determine
// the loop order for creating pointwise copies between remote and local
// memories.
@@ -536,13 +536,14 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
using namespace mlir::edsc::op;
TransferReadOp transfer = cast<TransferReadOp>(op);
-
+ auto memRefType = transfer.getShapedType().dyn_cast<MemRefType>();
+ if (!memRefType)
+ return failure();
// Fall back to a loop if the fastest varying stride is not 1 or it is
// permuted.
int64_t offset;
SmallVector<int64_t, 4> strides;
- auto successStrides =
- getStridesAndOffset(transfer.getMemRefType(), strides, offset);
+ auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (succeeded(successStrides) && strides.back() == 1 &&
transfer.permutation_map().isMinorIdentity()) {
// If > 1D, emit a bunch of loops around 1-D vector transfers.
@@ -557,8 +558,8 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
// Conservative lowering to scalar load / stores.
// 1. Setup all the captures.
ScopedContext scope(rewriter, transfer.getLoc());
- StdIndexedValue remote(transfer.memref());
- MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
+ StdIndexedValue remote(transfer.source());
+ MemRefBoundsCapture memRefBoundsCapture(transfer.source());
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
int coalescedIdx = computeCoalescedIndex(transfer);
// Swap the vectorBoundsCapture which will reorder loop bounds.
@@ -621,13 +622,15 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
using namespace edsc::op;
TransferWriteOp transfer = cast<TransferWriteOp>(op);
+ auto memRefType = transfer.getShapedType().template dyn_cast<MemRefType>();
+ if (!memRefType)
+ return failure();
// Fall back to a loop if the fastest varying stride is not 1 or it is
// permuted.
int64_t offset;
SmallVector<int64_t, 4> strides;
- auto successStrides =
- getStridesAndOffset(transfer.getMemRefType(), strides, offset);
+ auto successStrides = getStridesAndOffset(memRefType, strides, offset);
if (succeeded(successStrides) && strides.back() == 1 &&
transfer.permutation_map().isMinorIdentity()) {
// If > 1D, emit a bunch of loops around 1-D vector transfers.
@@ -641,8 +644,8 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
// 1. Setup all the captures.
ScopedContext scope(rewriter, transfer.getLoc());
- StdIndexedValue remote(transfer.memref());
- MemRefBoundsCapture memRefBoundsCapture(transfer.memref());
+ StdIndexedValue remote(transfer.source());
+ MemRefBoundsCapture memRefBoundsCapture(transfer.source());
Value vectorValue(transfer.vector());
VectorBoundsCapture vectorBoundsCapture(transfer.vector());
int coalescedIdx = computeCoalescedIndex(transfer);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 9e7e7efdd136..a1797fde7da6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -111,7 +111,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
vector::TransferWriteOp transferWrite;
for (auto *sliceOp : llvm::reverse(forwardSlice)) {
auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
- if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
+ if (!candidateWrite || candidateWrite.source() != transferRead.source())
continue;
transferWrite = candidateWrite;
}
@@ -142,7 +142,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
DominanceInfo dom(loop);
if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
return WalkResult::advance();
- for (auto &use : transferRead.memref().getUses()) {
+ for (auto &use : transferRead.source().getUses()) {
if (!dom.properlyDominates(loop, use.getOwner()))
continue;
if (use.getOwner() == transferRead.getOperation() ||
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2df1a9469eab..7165ee775e9c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -411,7 +411,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
// Transfer into `view`.
- Value viewOrAlloc = xferOp.memref();
+ Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
return failure();
@@ -487,7 +487,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
// Transfer into `viewOrAlloc`.
- Value viewOrAlloc = xferOp.memref();
+ Value viewOrAlloc = xferOp.source();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
return failure();
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5c1f377e589e..a3ad355d30b2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1890,41 +1890,43 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
return success();
}
-static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
+static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
VectorType vectorType,
AffineMap permutationMap,
ArrayAttr optionalMasked) {
- auto memrefElementType = memrefType.getElementType();
- if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
- // Memref has vector element type.
-
- unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
- memrefVectorElementType.getShape().back();
+ if (!shapedType.isa<MemRefType, RankedTensorType>())
+ return op->emitOpError(
+ "requires source to be a memref or ranked tensor type");
+ auto elementType = shapedType.getElementType();
+ if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
+ // Memref or tensor has vector element type.
+ unsigned sourceVecSize = vectorElementType.getElementTypeBitWidth() *
+ vectorElementType.getShape().back();
unsigned resultVecSize =
vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
- if (resultVecSize % memrefVecSize != 0)
+ if (resultVecSize % sourceVecSize != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
- "multiple of the bitwidth of the minor 1-D vector of the memref");
+ "multiple of the bitwidth of the minor 1-D vector of the source");
- unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+ unsigned sourceVecEltRank = vectorElementType.getRank();
unsigned resultVecRank = vectorType.getRank();
- if (memrefVecEltRank > resultVecRank)
+ if (sourceVecEltRank > resultVecRank)
return op->emitOpError(
- "requires memref vector element and vector result ranks to match.");
- unsigned rankOffset = resultVecRank - memrefVecEltRank;
+ "requires source vector element and vector result ranks to match.");
+ unsigned rankOffset = resultVecRank - sourceVecEltRank;
// Check that permutation map results match 'rankOffset' of vector type.
if (permutationMap.getNumResults() != rankOffset)
return op->emitOpError("requires a permutation_map with result dims of "
"the same rank as the vector type");
} else {
- // Memref has scalar element type.
+ // Memref or tensor has scalar element type.
unsigned resultVecSize =
vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
- if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
+ if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
- "multiple of the bitwidth of the memref element type");
+ "multiple of the bitwidth of the source element type");
// Check that permutation map results match rank of vector type.
if (permutationMap.getNumResults() != vectorType.getRank())
@@ -1934,9 +1936,9 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
if (permutationMap.getNumSymbols() != 0)
return op->emitOpError("requires permutation_map without symbols");
- if (permutationMap.getNumInputs() != memrefType.getRank())
+ if (permutationMap.getNumInputs() != shapedType.getRank())
return op->emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type");
+ "same rank as the source type");
if (optionalMasked) {
if (permutationMap.getNumResults() !=
@@ -1978,7 +1980,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
SmallVector<StringRef, 2> elidedAttrs;
if (op.permutation_map() ==
- getTransferMinorIdentityMap(op.getMemRefType(), op.getVectorType()))
+ getTransferMinorIdentityMap(op.getShapedType(), op.getVectorType()))
elidedAttrs.push_back(op.getPermutationMapAttrName());
bool elideMasked = true;
if (auto maybeMasked = op.masked()) {
@@ -1995,21 +1997,21 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
}
static void print(OpAsmPrinter &p, TransferReadOp op) {
- p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
+ p << op.getOperationName() << " " << op.source() << "[" << op.indices()
<< "], " << op.padding();
printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
- p << " : " << op.getMemRefType() << ", " << op.getVectorType();
+ p << " : " << op.getShapedType() << ", " << op.getVectorType();
}
static ParseResult parseTransferReadOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc typesLoc;
- OpAsmParser::OperandType memrefInfo;
+ OpAsmParser::OperandType sourceInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
OpAsmParser::OperandType paddingInfo;
SmallVector<Type, 2> types;
// Parsing with support for paddingValue.
- if (parser.parseOperand(memrefInfo) ||
+ if (parser.parseOperand(sourceInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseComma() || parser.parseOperand(paddingInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
@@ -2018,48 +2020,48 @@ static ParseResult parseTransferReadOp(OpAsmParser &parser,
if (types.size() != 2)
return parser.emitError(typesLoc, "requires two types");
auto indexType = parser.getBuilder().getIndexType();
- MemRefType memRefType = types[0].dyn_cast<MemRefType>();
- if (!memRefType)
- return parser.emitError(typesLoc, "requires memref type");
+ auto shapedType = types[0].dyn_cast<ShapedType>();
+ if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
+ return parser.emitError(typesLoc, "requires memref or ranked tensor type");
VectorType vectorType = types[1].dyn_cast<VectorType>();
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
auto permutationAttrName = TransferReadOp::getPermutationMapAttrName();
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
- auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
+ auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
return failure(
- parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
+ parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
- parser.resolveOperand(paddingInfo, memRefType.getElementType(),
+ parser.resolveOperand(paddingInfo, shapedType.getElementType(),
result.operands) ||
parser.addTypeToList(vectorType, result.types));
}
static LogicalResult verify(TransferReadOp op) {
- // Consistency of elemental types in memref and vector.
- MemRefType memrefType = op.getMemRefType();
+ // Consistency of elemental types in source and vector.
+ ShapedType shapedType = op.getShapedType();
VectorType vectorType = op.getVectorType();
auto paddingType = op.padding().getType();
auto permutationMap = op.permutation_map();
- auto memrefElementType = memrefType.getElementType();
+ auto sourceElementType = shapedType.getElementType();
- if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
- return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+ if (static_cast<int64_t>(op.indices().size()) != shapedType.getRank())
+ return op.emitOpError("requires ") << shapedType.getRank() << " indices";
- if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+ if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
permutationMap,
op.masked() ? *op.masked() : ArrayAttr())))
return failure();
- if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
- // Memref has vector element type.
- // Check that 'memrefVectorElementType' and 'paddingType' types match.
- if (memrefVectorElementType != paddingType)
+ if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
+ // Source has vector element type.
+ // Check that 'sourceVectorElementType' and 'paddingType' types match.
+ if (sourceVectorElementType != paddingType)
return op.emitOpError(
- "requires memref element type and padding type to match.");
+ "requires source element type and padding type to match.");
} else {
// Check that 'paddingType' is valid to store in a vector type.
@@ -2067,9 +2069,9 @@ static LogicalResult verify(TransferReadOp op) {
return op.emitOpError("requires valid padding vector elemental type");
// Check that padding type and vector element types match.
- if (paddingType != memrefElementType)
+ if (paddingType != sourceElementType)
return op.emitOpError(
- "requires formal padding and memref of the same elemental type");
+ "requires formal padding and source of the same elemental type");
}
return verifyPermutationMap(permutationMap,
@@ -2096,18 +2098,18 @@ static LogicalResult foldMemRefCast(Operation *op) {
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
// TODO: support more aggressive createOrFold on:
- // `op.indices()[indicesIdx] + vectorType < dim(op.memref(), indicesIdx)`
- if (op.getMemRefType().isDynamicDim(indicesIdx))
+ // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
+ if (op.getShapedType().isDynamicDim(indicesIdx))
return false;
Value index = op.indices()[indicesIdx];
auto cstOp = index.getDefiningOp<ConstantIndexOp>();
if (!cstOp)
return false;
- int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx);
+ int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
- return cstOp.getValue() + vectorSize <= memrefSize;
+ return cstOp.getValue() + vectorSize <= sourceSize;
}
template <typename TransferOp>
@@ -2159,33 +2161,51 @@ Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
/// Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value memref, ValueRange indices,
+ Value vector, Value source, ValueRange indices,
ArrayRef<bool> maybeMasked) {
auto vectorType = vector.getType().cast<VectorType>();
auto permMap = getTransferMinorIdentityMap(
- memref.getType().cast<MemRefType>(), vectorType);
+ source.getType().cast<MemRefType>(), vectorType);
if (maybeMasked.empty())
- return build(builder, result, vector, memref, indices, permMap,
+ return build(builder, result, vector, source, indices, permMap,
ArrayAttr());
ArrayAttr maskedArrayAttr = builder.getBoolArrayAttr(maybeMasked);
- build(builder, result, vector, memref, indices, permMap, maskedArrayAttr);
+ build(builder, result, vector, source, indices, permMap, maskedArrayAttr);
}
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value memref, ValueRange indices,
+ Value vector, Value source, ValueRange indices,
AffineMap permutationMap) {
- build(builder, result, vector, memref, indices, permutationMap,
+ build(builder, result, vector, source, indices, permutationMap,
/*maybeMasked=*/ArrayAttr());
}
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+ Value vector, Value source, ValueRange indices,
+ AffineMapAttr permutationMap,
+ /*optional*/ ArrayAttr masked) {
+ Type resultType = source.getType().dyn_cast<RankedTensorType>();
+ build(builder, result, resultType, vector, source, indices, permutationMap,
+ masked);
+}
+
+void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
+ Value vector, Value source, ValueRange indices,
+ AffineMap permutationMap,
+ /*optional*/ ArrayAttr masked) {
+ Type resultType = source.getType().dyn_cast<RankedTensorType>();
+ build(builder, result, resultType, vector, source, indices, permutationMap,
+ masked);
+}
+
static ParseResult parseTransferWriteOp(OpAsmParser &parser,
OperationState &result) {
llvm::SMLoc typesLoc;
- OpAsmParser::OperandType vectorInfo, memrefInfo;
+ OpAsmParser::OperandType vectorInfo, sourceInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
SmallVector<Type, 2> types;
if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
- parser.parseOperand(memrefInfo) ||
+ parser.parseOperand(sourceInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
@@ -2196,38 +2216,40 @@ static ParseResult parseTransferWriteOp(OpAsmParser &parser,
VectorType vectorType = types[0].dyn_cast<VectorType>();
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
- MemRefType memRefType = types[1].dyn_cast<MemRefType>();
- if (!memRefType)
- return parser.emitError(typesLoc, "requires memref type");
+ ShapedType shapedType = types[1].dyn_cast<ShapedType>();
+ if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
+ return parser.emitError(typesLoc, "requires memref or ranked tensor type");
auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName();
auto attr = result.attributes.get(permutationAttrName);
if (!attr) {
- auto permMap = getTransferMinorIdentityMap(memRefType, vectorType);
+ auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
}
return failure(
parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
- parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
- parser.resolveOperands(indexInfo, indexType, result.operands));
+ parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
+ parser.resolveOperands(indexInfo, indexType, result.operands) ||
+ (shapedType.isa<RankedTensorType>() &&
+ parser.addTypeToList(shapedType, result.types)));
}
static void print(OpAsmPrinter &p, TransferWriteOp op) {
- p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
+ p << op.getOperationName() << " " << op.vector() << ", " << op.source() << "["
<< op.indices() << "]";
printTransferAttrs(p, cast<VectorTransferOpInterface>(op.getOperation()));
- p << " : " << op.getVectorType() << ", " << op.getMemRefType();
+ p << " : " << op.getVectorType() << ", " << op.getShapedType();
}
static LogicalResult verify(TransferWriteOp op) {
// Consistency of elemental types in memref and vector.
- MemRefType memrefType = op.getMemRefType();
+ ShapedType shapedType = op.getShapedType();
VectorType vectorType = op.getVectorType();
auto permutationMap = op.permutation_map();
- if (llvm::size(op.indices()) != memrefType.getRank())
- return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+ if (llvm::size(op.indices()) != shapedType.getRank())
+ return op.emitOpError("requires ") << shapedType.getRank() << " indices";
- if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+ if (failed(verifyTransferOp(op.getOperation(), shapedType, vectorType,
permutationMap,
op.masked() ? *op.masked() : ArrayAttr())))
return failure();
diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
index b7de983dd3b1..ea1189d53b31 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -94,7 +94,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
<< "\n");
llvm::SmallVector<Operation *, 8> reads;
Operation *firstOverwriteCandidate = nullptr;
- for (auto *user : write.memref().getUsers()) {
+ for (auto *user : write.source().getUsers()) {
if (user == write.getOperation())
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
@@ -163,7 +163,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
<< "\n");
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
- for (Operation *user : read.memref().getUsers()) {
+ for (Operation *user : read.source().getUsers()) {
if (isa<vector::TransferReadOp>(user))
continue;
if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 664426960beb..1e58a759d305 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -597,7 +597,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
Location loc = readOp.getLoc();
auto memrefElementType =
- readOp.memref().getType().cast<MemRefType>().getElementType();
+ readOp.source().getType().cast<MemRefType>().getElementType();
auto tupleType = generateExtractSlicesOpResultType(
sourceVectorType, targetShape, strides, builder);
int64_t numSlices = tupleType.size();
@@ -612,7 +612,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
- loc, sliceVectorType, readOp.memref(), sliceIndices,
+ loc, sliceVectorType, readOp.source(), sliceIndices,
readOp.permutation_map(), readOp.padding(),
readOp.masked() ? *readOp.masked() : ArrayAttr());
};
@@ -644,14 +644,14 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
Value tuple = builder.create<vector::ExtractSlicesOp>(
loc, tupleType, writeOp.vector(), targetShape, strides);
auto memrefElementType =
- writeOp.memref().getType().cast<MemRefType>().getElementType();
+ writeOp.source().getType().cast<MemRefType>().getElementType();
SmallVector<Value, 4> indices(writeOp.indices().begin(),
writeOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
auto element = builder.create<vector::TupleGetOp>(
loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
builder.create<vector::TransferWriteOp>(
- loc, element.getResult(), writeOp.memref(), sliceIndices,
+ loc, element.getResult(), writeOp.source(), sliceIndices,
writeOp.permutation_map(),
writeOp.masked() ? *writeOp.masked() : ArrayAttr());
};
@@ -760,7 +760,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
Location loc = xferWriteOp.getLoc();
auto memrefElementType =
- xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
+ xferWriteOp.source().getType().cast<MemRefType>().getElementType();
SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
xferWriteOp.indices().end());
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
@@ -768,7 +768,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
rewriter.create<vector::TransferWriteOp>(
- loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
+ loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices,
xferWriteOp.permutation_map(),
xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
};
@@ -2142,7 +2142,7 @@ static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
// Fold or create the check that `index + vector_size` <= `memref_size`.
Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
Value cond =
- createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx));
+ createScopedFoldedSLE(sum, std_dim(xferOp.source(), indicesIdx));
if (!cond)
return;
// Conjunction over all dims for which we are in-bounds.
@@ -2207,23 +2207,23 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
}
/// Operates under a scoped context to build the intersection between the
-/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
+/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
Value alloc) {
using namespace edsc::intrinsics;
- int64_t memrefRank = xferOp.getMemRefType().getRank();
+ int64_t memrefRank = xferOp.getShapedType().getRank();
// TODO: relax this precondition, will require rank-reducing subviews.
assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
"Expected memref rank to match the alloc rank");
Value one = std_constant_index(1);
ValueRange leadingIndices =
- xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
+ xferOp.indices().take_front(xferOp.getLeadingShapedRank());
SmallVector<Value, 4> sizes;
sizes.append(leadingIndices.begin(), leadingIndices.end());
xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
+ Value dimMemRef = std_dim(xferOp.source(), indicesIdx);
Value dimAlloc = std_dim(alloc, resultIdx);
Value index = xferOp.indices()[indicesIdx];
AffineExpr i, j, k;
@@ -2235,7 +2235,7 @@ static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
ValueRange{dimMemRef, index, dimAlloc});
sizes.push_back(affineMin);
});
- return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
+ return std_sub_view(xferOp.source(), xferOp.indices(), sizes,
SmallVector<Value, 4>(memrefRank, one));
}
@@ -2263,12 +2263,12 @@ static scf::IfOp createScopedFullPartialLinalgCopy(
using namespace edsc::intrinsics;
scf::IfOp fullPartialIfOp;
Value zero = std_constant_index(0);
- Value memref = xferOp.memref();
+ Value memref = xferOp.source();
conditionBuilder(
returnTypes, inBoundsCond,
[&]() -> scf::ValueVector {
Value res = memref;
- if (compatibleMemRefType != xferOp.getMemRefType())
+ if (compatibleMemRefType != xferOp.getShapedType())
res = std_memref_cast(memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
@@ -2317,12 +2317,12 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
using namespace edsc::intrinsics;
scf::IfOp fullPartialIfOp;
Value zero = std_constant_index(0);
- Value memref = xferOp.memref();
+ Value memref = xferOp.source();
conditionBuilder(
returnTypes, inBoundsCond,
[&]() -> scf::ValueVector {
Value res = memref;
- if (compatibleMemRefType != xferOp.getMemRefType())
+ if (compatibleMemRefType != xferOp.getShapedType())
res = std_memref_cast(memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
@@ -2376,7 +2376,7 @@ static scf::IfOp createScopedFullPartialVectorTransferRead(
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
-/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+/// 2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
@@ -2404,8 +2404,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
return failure();
OpBuilder::InsertionGuard guard(b);
- if (xferOp.memref().getDefiningOp())
- b.setInsertionPointAfter(xferOp.memref().getDefiningOp());
+ if (Operation *sourceOp = xferOp.source().getDefiningOp())
+ b.setInsertionPointAfter(sourceOp);
else
b.setInsertionPoint(xferOp);
ScopedContext scope(b, xferOp.getLoc());
@@ -2426,8 +2426,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
b.getI64IntegerAttr(32));
}
- MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
- xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
+ MemRefType compatibleMemRefType =
+ getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
+ alloc.getType().cast<MemRefType>());
// Read case: full fill + partial copy -> unmasked vector.xfer_read.
SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
@@ -2543,7 +2544,7 @@ struct TransferReadExtractPattern
extract.ids()[idCount++] *
std_constant_index(extract.getResultType().getDimSize(pos));
}
- Value newRead = vector_transfer_read(extract.getType(), read.memref(),
+ Value newRead = vector_transfer_read(extract.getType(), read.source(),
indices, read.permutation_map(),
read.padding(), read.maskedAttr());
Value dest = rewriter.create<ConstantOp>(
@@ -2579,7 +2580,7 @@ struct TransferWriteInsertPattern
insert.ids()[idCount++] *
std_constant_index(insert.getSourceVectorType().getDimSize(pos));
}
- vector_transfer_write(insert.vector(), write.memref(), indices,
+ vector_transfer_write(insert.vector(), write.source(), indices,
write.permutation_map(), write.maskedAttr());
rewriter.eraseOp(write);
return success();
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 3ab1f500f5d1..fc08d21b27a5 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -243,16 +243,16 @@ AffineMap mlir::makePermutationMap(
return ::makePermutationMap(indices, enclosingLoopToVectorDim);
}
-AffineMap mlir::getTransferMinorIdentityMap(MemRefType memRefType,
+AffineMap mlir::getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
- memRefType.getElementType().dyn_cast<VectorType>();
+ shapedType.getElementType().dyn_cast<VectorType>();
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
return AffineMap::getMinorIdentityMap(
- memRefType.getRank(), vectorType.getRank() - elementVectorRank,
- memRefType.getContext());
+ shapedType.getRank(), vectorType.getRank() - elementVectorRank,
+ shapedType.getContext());
}
bool matcher::operatesOnSuperVectorsOf(Operation &op,
@@ -314,12 +314,12 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB) {
- if (transferA.memref() != transferB.memref())
+ if (transferA.source() != transferB.source())
return false;
// For simplicity only look at transfer of same type.
if (transferA.getVectorType() != transferB.getVectorType())
return false;
- unsigned rankOffset = transferA.getLeadingMemRefRank();
+ unsigned rankOffset = transferA.getLeadingShapedRank();
for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 73b1f9e1e06e..62eaa4e3a14e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -269,7 +269,7 @@ func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
- // expected-error@+1 {{ requires memref type}}
+ // expected-error@+1 {{ requires memref or ranked tensor type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
}
@@ -297,7 +297,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant 3.0 : f32
- // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}}
+ // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the source type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0)->(d0)>} : memref<?x?xf32>, vector<128xf32>
}
@@ -343,7 +343,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
- // expected-error@+1 {{requires memref vector element and vector result ranks to match}}
+ // expected-error@+1 {{requires source vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}
@@ -353,7 +353,7 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<6xf32>
- // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
+ // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
@@ -392,7 +392,7 @@ func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
- // expected-error@+1 {{ requires memref type}}
+ // expected-error@+1 {{ requires memref or ranked tensor type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
}
@@ -419,7 +419,7 @@ func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant dense<3.0> : vector<128 x f32>
- // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}}
+ // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the source type}}
vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = affine_map<(d0)->(d0)>} : vector<128xf32>, memref<?x?xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index aab6cabf759d..07e9d8de3f49 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -43,6 +43,54 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
return
}
+
+// CHECK-LABEL: func @vector_transfer_ops_tensor(
+func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
+ %arg1 : tensor<?x?xvector<4x3xf32>>,
+ %arg2 : tensor<?x?xvector<4x3xi32>>) ->
+ (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
+ tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>){
+ // CHECK: %[[C3:.*]] = constant 3 : index
+ %c3 = constant 3 : index
+ %cst = constant 3.0 : f32
+ %f0 = constant 0.0 : f32
+ %c0 = constant 0 : i32
+ %vf0 = splat %f0 : vector<4x3xf32>
+ %v0 = splat %c0 : vector<4x3xi32>
+
+ //
+ // CHECK: vector.transfer_read
+ %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor<?x?xf32>, vector<128xf32>
+ // CHECK: vector.transfer_read
+ %1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : tensor<?x?xf32>, vector<3x7xf32>
+ // CHECK: vector.transfer_read
+ %2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d0)>} : tensor<?x?xf32>, vector<128xf32>
+ // CHECK: vector.transfer_read
+ %3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d1)>} : tensor<?x?xf32>, vector<128xf32>
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : tensor<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
+ %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : tensor<?x?xvector<4x3xi32>>, vector<5x24xi8>
+
+
+ // CHECK: vector.transfer_write
+ %7 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor<?x?xf32>
+ // CHECK: vector.transfer_write
+ %8 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor<?x?xf32>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+ %9 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+ %10 = vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, tensor<?x?xvector<4x3xf32>>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
+ %11 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor<?x?xvector<4x3xi32>>
+
+ return %7, %8, %9, %10, %11 :
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xvector<4x3xf32>>,
+ tensor<?x?xvector<4x3xf32>>, tensor<?x?xvector<4x3xi32>>
+}
+
// CHECK-LABEL: @vector_broadcast
func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
More information about the Mlir-commits
mailing list