[Mlir-commits] [mlir] f9190c8 - [mlir][vector] Support unrolling for transfer ops using tensors
Thomas Raoux
llvmlistbot at llvm.org
Wed Jan 6 13:53:49 PST 2021
Author: Thomas Raoux
Date: 2021-01-06T13:28:04-08:00
New Revision: f9190c868137dcf43833db2c8e1e00c7acca67bc
URL: https://github.com/llvm/llvm-project/commit/f9190c868137dcf43833db2c8e1e00c7acca67bc
DIFF: https://github.com/llvm/llvm-project/commit/f9190c868137dcf43833db2c8e1e00c7acca67bc.diff
LOG: [mlir][vector] Support unrolling for transfer ops using tensors
Differential Revision: https://reviews.llvm.org/D93904
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
mlir/test/Dialect/Vector/vector-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index c88aa7f5bc65..a258903d5a3a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -71,7 +71,8 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
/// Unroll a transfer_write op. Break up the vector source into a tuple of
/// vectors matching the given shape. Then store each element with its own
-/// transfer_write.
+/// transfer_write. If the transfer_write takes a tensor source, append the
+/// tensor produced by the last unrolled transfer_write to `result`.
///
/// Example:
/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -83,7 +84,8 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
- ArrayRef<int64_t> targetShape);
+ ArrayRef<int64_t> targetShape,
+ SmallVector<Value, 1> &result);
/// Options that control the vector unrolling.
struct UnrollVectorOptions {
@@ -143,9 +145,10 @@ struct UnrollVectorPattern : public RewritePattern {
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
return failure();
if (isa<TransferWriteOp>(op)) {
- if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
+ SmallVector<Value, 1> result;
+ if (failed(unrollTransferWriteOp(rewriter, op, *targetShape, result)))
return failure();
- rewriter.eraseOp(op);
+ rewriter.replaceOp(op, result);
return success();
}
if (op->getNumResults() != 1)
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index f1708db113d4..ca6e92d95ed0 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -515,7 +515,7 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
/// calls 'fn' with linear index and indices for each slice.
static void generateTransferOpSlices(
- Type memrefElementType, VectorType vectorType, TupleType tupleType,
+ Type shapedElementType, VectorType vectorType, TupleType tupleType,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
// Compute strides w.r.t. to slice counts in each dimension.
@@ -539,9 +539,9 @@ static void generateTransferOpSlices(
// vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
//
unsigned vectorRank = vectorType.getRank();
- if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
- assert(vectorRank >= memrefVectorElementType.getRank());
- vectorRank -= memrefVectorElementType.getRank();
+ if (auto sourceVectorElementType = shapedElementType.dyn_cast<VectorType>()) {
+ assert(vectorRank >= sourceVectorElementType.getRank());
+ vectorRank -= sourceVectorElementType.getRank();
}
unsigned indexOffset = numSliceIndices - vectorRank;
@@ -598,8 +598,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
SmallVector<int64_t, 4> strides(targetShape.size(), 1);
Location loc = readOp.getLoc();
- auto memrefElementType =
- readOp.source().getType().cast<MemRefType>().getElementType();
+ auto shapedElementType =
+ readOp.source().getType().cast<ShapedType>().getElementType();
auto tupleType = generateExtractSlicesOpResultType(
sourceVectorType, targetShape, strides, builder);
int64_t numSlices = tupleType.size();
@@ -618,7 +618,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
readOp.permutation_map(), readOp.padding(),
readOp.masked() ? *readOp.masked() : ArrayAttr());
};
- generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
+ generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
targetShape, strides, indices, builder, createSlice);
// Create tuple of splice transfer read operations.
@@ -634,7 +634,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
// Entry point for unrolling declarative pattern rewrite for transfer_write op.
LogicalResult
mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
- ArrayRef<int64_t> targetShape) {
+ ArrayRef<int64_t> targetShape,
+ SmallVector<Value, 1> &result) {
auto writeOp = cast<vector::TransferWriteOp>(op);
if (!isIdentitySuffix(writeOp.permutation_map()))
return failure();
@@ -645,20 +646,28 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
Location loc = writeOp.getLoc();
Value tuple = builder.create<vector::ExtractSlicesOp>(
loc, tupleType, writeOp.vector(), targetShape, strides);
- auto memrefElementType =
- writeOp.source().getType().cast<MemRefType>().getElementType();
+ auto shapedElementType =
+ writeOp.source().getType().cast<ShapedType>().getElementType();
SmallVector<Value, 4> indices(writeOp.indices().begin(),
writeOp.indices().end());
+ // If the TransferWrite returns a tensor, keep track of the last tensor
+ // created.
+ Value resultTensor;
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
auto element = builder.create<vector::TupleGetOp>(
loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
- builder.create<vector::TransferWriteOp>(
- loc, element.getResult(), writeOp.source(), sliceIndices,
+ Operation *write = builder.create<vector::TransferWriteOp>(
+ loc, element.getResult(),
+ resultTensor ? resultTensor : writeOp.source(), sliceIndices,
writeOp.permutation_map(),
writeOp.masked() ? *writeOp.masked() : ArrayAttr());
+ if (!write->getResults().empty())
+ resultTensor = write->getResult(0);
};
- generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
+ generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
targetShape, strides, indices, builder, createSlice);
+ if (resultTensor)
+ result.push_back(resultTensor);
return success();
}
@@ -761,25 +770,32 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
insertSlicesOp.getStrides(strides);
Location loc = xferWriteOp.getLoc();
- auto memrefElementType =
- xferWriteOp.source().getType().cast<MemRefType>().getElementType();
+ auto shapedElementType =
+ xferWriteOp.source().getType().cast<ShapedType>().getElementType();
SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
xferWriteOp.indices().end());
+ Value resultTensor;
auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
// Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
// `masked` attribute propagates conservatively: if the coarse op didn't
// need masking, the fine op doesn't either.
- rewriter.create<vector::TransferWriteOp>(
- loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices,
+ Operation *write = rewriter.create<vector::TransferWriteOp>(
+ loc, tupleOp.getOperand(index),
+ resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices,
xferWriteOp.permutation_map(),
xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
+ if (!write->getResults().empty())
+ resultTensor = write->getResult(0);
};
- generateTransferOpSlices(memrefElementType, resultVectorType,
+ generateTransferOpSlices(shapedElementType, resultVectorType,
sourceTupleType, sizes, strides, indices, rewriter,
createSlice);
// Erase old 'xferWriteOp'.
- rewriter.eraseOp(xferWriteOp);
+ if (resultTensor)
+ rewriter.replaceOp(xferWriteOp, ArrayRef<Value>(resultTensor));
+ else
+ rewriter.eraseOp(xferWriteOp);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index b676700dae06..d5e9535acb8e 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -58,3 +58,65 @@ func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) {
vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
return
}
+
+// CHECK-LABEL: func @transfer_read_unroll_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
+// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32>
+
+func @transfer_read_unroll_tensor(%arg0 : tensor<4x4xf32>) -> vector<4x4xf32> {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @transfer_write_unroll_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[T1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[T2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[T3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
+
+func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>,
+ %arg1 : vector<4x4xf32>) -> tensor<4x4xf32> {
+ %c0 = constant 0 : index
+ %r = vector.transfer_write %arg1, %arg0[%c0, %c0] :
+ vector<4x4xf32>, tensor<4x4xf32>
+ return %r: tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func @transfer_readwrite_unroll_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[VTR1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[VTR2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[VTR3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
+
+func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>) ->
+ tensor<4x4xf32> {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
+ %r = vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
+ return %r: tensor<4x4xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 43a83f04dd30..4a58261c9672 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -530,6 +530,14 @@ func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
// CHECK: %[[CMP1:.*]] = cmpf "ult", %[[VT1]], %[[VT5]] : vector<2x2xf32>
// CHECK: %[[CMP2:.*]] = cmpf "ult", %[[VT2]], %[[VT6]] : vector<2x2xf32>
// CHECK: %[[CMP3:.*]] = cmpf "ult", %[[VT3]], %[[VT7]] : vector<2x2xf32>
+// CHECK: %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK: %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
// CHECK: %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
@@ -544,7 +552,52 @@ func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
%0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
%1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
%cond = cmpf "ult", %0, %1 : vector<4x4xf32>
- %2 = select %cond, %0, %1 : vector<4x4xi1>, vector<4x4xf32>
- vector.transfer_write %2, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+ // The vector transfer split pattern only supports a single user right now.
+ %2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+ %3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+ %4 = select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
+ vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
return
}
+
+// Check that tensor vector.transfer_read/write ops are split per contract unrolling.
+// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
+
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
+
+// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+
+// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[R2]], %[[VTW1]][%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[R3]], %[[VTW2]][%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
+
+func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
+ %arg1 : tensor<2x4xf32>,
+ %arg2 : tensor<4x4xf32>) ->
+ tensor<4x4xf32> {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 :
+ tensor<4x2xf32>, vector<4x2xf32>
+ %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 :
+ tensor<2x4xf32>, vector<2x4xf32>
+ %2 = vector.transfer_read %arg2[%c0, %c0], %cf0 :
+ tensor<4x4xf32>, vector<4x4xf32>
+ %3 = vector.contract #contraction_trait1 %0, %1, %2
+ : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
+ %r = vector.transfer_write %3, %arg2[%c0, %c0]
+ : vector<4x4xf32>, tensor<4x4xf32>
+ return %r : tensor<4x4xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index f219ef04fce5..572cd1cd68f1 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -28,7 +28,9 @@ struct TestVectorToVectorConversion
OwningRewritePatternList patterns;
auto *ctx = &getContext();
patterns.insert<UnrollVectorPattern>(
- ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
+ ctx,
+ UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
+ filter));
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
@@ -39,13 +41,14 @@ struct TestVectorToVectorConversion
static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
if (isa<AddFOp, SelectOp, CmpFOp>(op))
return SmallVector<int64_t, 4>(2, 2);
- if (auto transferOp = dyn_cast<VectorTransferOpInterface>(op)) {
- return SmallVector<int64_t, 4>(transferOp.getVectorType().getRank(), 2);
- }
if (isa<vector::ContractionOp>(op))
return SmallVector<int64_t, 4>(3, 2);
return llvm::None;
}
+
+ static LogicalResult filter(Operation *op) {
+ return success(isa<AddFOp, SelectOp, CmpFOp, ContractionOp>(op));
+ }
};
struct TestVectorSlicesConversion
More information about the Mlir-commits
mailing list