[Mlir-commits] [mlir] edbdea7 - [mlir][vector] Add unrolling patterns for Transfer read/write
Thomas Raoux
llvmlistbot at llvm.org
Thu Oct 15 15:18:37 PDT 2020
Author: Thomas Raoux
Date: 2020-10-15T15:17:36-07:00
New Revision: edbdea7466d25c5e4d9f73e3043ac87efe433193
URL: https://github.com/llvm/llvm-project/commit/edbdea7466d25c5e4d9f73e3043ac87efe433193
DIFF: https://github.com/llvm/llvm-project/commit/edbdea7466d25c5e4d9f73e3043ac87efe433193.diff
LOG: [mlir][vector] Add unrolling patterns for Transfer read/write
Adding unroll support for the transfer read and transfer write operations. This
allows picking the ideal size for the memory access for a given target.
Differential Revision: https://reviews.llvm.org/D89289
Added:
mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index da9650c67efb..157084a2bff1 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -69,6 +69,22 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
Operation *op,
ArrayRef<int64_t> targetShape);
+/// Unroll a transfer_write op. Break up the vector source into a tuple of
+/// vectors matching the given shape. Then store each element with its own
+/// transfer_write.
+///
+/// Example:
+/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+/// ->
+/// %0 = vector.extract_slices %A, [2, 4], [1, 1] :
+/// vector<4x4xf32> into tuple<vector<2x4xf32>, vector<2x4xf32>>
+/// %1 = vector.tuple_get %0, 0 : tuple<vector<2x4xf32>, vector<2x4xf32>>
+/// vector.transfer_write %1, %M[%c0, %c0] : vector<2x4xf32>, memref<4x4xf32>
+/// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
+/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
+LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
+ ArrayRef<int64_t> targetShape);
+
/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
/// declaratively.
template <typename OpTy>
@@ -95,6 +111,12 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
if (!maybeShapeRatio ||
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
return failure();
+ if (std::is_same<OpTy, TransferWriteOp>::value) {
+ if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
+ return failure();
+ rewriter.eraseOp(op);
+ return success();
+ }
if (op.getOperation()->getNumResults() != 1)
return failure();
auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 08ee5c64af09..5f34b6caf65a 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -511,35 +511,6 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
resultIndex = numVectors - 1;
}
-// Entry point for unrolling declarative pattern rewrites.
-SmallVector<Value, 1>
-mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
- ArrayRef<int64_t> targetShape) {
- assert(op->getNumResults() == 1 && "Expected single result operation");
-
- // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
- SmallVector<int64_t, 6> iterationBounds;
- auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
- auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
- assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
-
- std::vector<VectorState> vectors;
- unsigned resultIndex;
-
- if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
- // Populate state for vector ContractionOp.
- getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
- resultIndex);
- } else {
- // Populate state for vector elementwise op.
- getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
- }
-
- // Unroll 'op' with 'iterationBounds' to 'targetShape'.
- return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
- op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
-}
-
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
/// calls 'fn' with linear index and indices for each slice.
static void generateTransferOpSlices(
@@ -615,6 +586,114 @@ static bool isIdentitySuffix(AffineMap map) {
return true;
}
+/// Unroll transfer_read ops to the given shape and create an aggregate with all
+/// the chunks.
+static Value unrollTransferReadOp(vector::TransferReadOp readOp,
+ ArrayRef<int64_t> targetShape,
+ OpBuilder &builder) {
+ if (!isIdentitySuffix(readOp.permutation_map()))
+ return nullptr;
+ auto sourceVectorType = readOp.getVectorType();
+ SmallVector<int64_t, 4> strides(targetShape.size(), 1);
+
+ Location loc = readOp.getLoc();
+ auto memrefElementType =
+ readOp.memref().getType().cast<MemRefType>().getElementType();
+ auto tupleType = generateExtractSlicesOpResultType(
+ sourceVectorType, targetShape, strides, builder);
+ int64_t numSlices = tupleType.size();
+
+ SmallVector<Value, 4> vectorTupleValues(numSlices);
+ SmallVector<Value, 4> indices(readOp.indices().begin(),
+ readOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
+ // Get VectorType for slice 'i'.
+ auto sliceVectorType = tupleType.getType(index);
+ // Create split TransferReadOp for 'sliceUser'.
+ // `masked` attribute propagates conservatively: if the coarse op didn't
+ // need masking, the fine op doesn't either.
+ vectorTupleValues[index] = builder.create<vector::TransferReadOp>(
+ loc, sliceVectorType, readOp.memref(), sliceIndices,
+ readOp.permutation_map(), readOp.padding(),
+ readOp.masked() ? *readOp.masked() : ArrayAttr());
+ };
+ generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
+ targetShape, strides, indices, builder, createSlice);
+
+ // Create tuple of splice transfer read operations.
+ Value tupleOp =
+ builder.create<vector::TupleOp>(loc, tupleType, vectorTupleValues);
+ // Replace 'readOp' with result 'insertSlicesResult'.
+ Value newVec = builder.create<vector::InsertSlicesOp>(
+ loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape),
+ builder.getI64ArrayAttr(strides));
+ return newVec;
+}
+
+// Entry point for unrolling declarative pattern rewrite for transfer_write op.
+LogicalResult
+mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
+ ArrayRef<int64_t> targetShape) {
+ auto writeOp = cast<vector::TransferWriteOp>(op);
+ if (!isIdentitySuffix(writeOp.permutation_map()))
+ return failure();
+ VectorType sourceVectorType = writeOp.getVectorType();
+ SmallVector<int64_t, 4> strides(targetShape.size(), 1);
+ TupleType tupleType = generateExtractSlicesOpResultType(
+ sourceVectorType, targetShape, strides, builder);
+ Location loc = writeOp.getLoc();
+ Value tuple = builder.create<vector::ExtractSlicesOp>(
+ loc, tupleType, writeOp.vector(), targetShape, strides);
+ auto memrefElementType =
+ writeOp.memref().getType().cast<MemRefType>().getElementType();
+ SmallVector<Value, 4> indices(writeOp.indices().begin(),
+ writeOp.indices().end());
+ auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
+ auto element = builder.create<vector::TupleGetOp>(
+ loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
+ builder.create<vector::TransferWriteOp>(
+ loc, element.getResult(), writeOp.memref(), sliceIndices,
+ writeOp.permutation_map(),
+ writeOp.masked() ? *writeOp.masked() : ArrayAttr());
+ };
+ generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
+ targetShape, strides, indices, builder, createSlice);
+ return success();
+}
+
+// Entry point for unrolling declarative pattern rewrites.
+SmallVector<Value, 1>
+mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
+ ArrayRef<int64_t> targetShape) {
+ assert(op->getNumResults() == 1 && "Expected single result operation");
+
+ // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
+ SmallVector<int64_t, 6> iterationBounds;
+ auto unrollableVectorOp = cast<VectorUnrollOpInterface>(op);
+ auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
+ assert(maybeUnrollShape && "Trying to unroll an incorrect vector op");
+
+ std::vector<VectorState> vectors;
+ unsigned resultIndex;
+
+ if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
+ return SmallVector<Value, 1>{
+ unrollTransferReadOp(readOp, targetShape, builder)};
+
+ if (auto contractionOp = dyn_cast<vector::ContractionOp>(op)) {
+ // Populate state for vector ContractionOp.
+ getVectorContractionOpUnrollState(contractionOp, targetShape, vectors,
+ resultIndex);
+ } else {
+ // Populate state for vector elementwise op.
+ getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex);
+ }
+
+ // Unroll 'op' with 'iterationBounds' to 'targetShape'.
+ return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
+ op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)};
+}
+
namespace {
// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
@@ -636,43 +715,16 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
return failure();
// Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
- auto sourceVectorType = extractSlicesOp.getSourceVectorType();
- auto resultTupleType = extractSlicesOp.getResultTupleType();
SmallVector<int64_t, 4> sizes;
extractSlicesOp.getSizes(sizes);
SmallVector<int64_t, 4> strides;
extractSlicesOp.getStrides(strides);
assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
- Location loc = xferReadOp.getLoc();
- auto memrefElementType =
- xferReadOp.memref().getType().cast<MemRefType>().getElementType();
- int64_t numSlices = resultTupleType.size();
- SmallVector<Value, 4> vectorTupleValues(numSlices);
- SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
- xferReadOp.indices().end());
- auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
- // Get VectorType for slice 'i'.
- auto sliceVectorType = resultTupleType.getType(index);
- // Create split TransferReadOp for 'sliceUser'.
- // `masked` attribute propagates conservatively: if the coarse op didn't
- // need masking, the fine op doesn't either.
- vectorTupleValues[index] = rewriter.create<vector::TransferReadOp>(
- loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
- xferReadOp.permutation_map(), xferReadOp.padding(),
- xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr());
- };
- generateTransferOpSlices(memrefElementType, sourceVectorType,
- resultTupleType, sizes, strides, indices, rewriter,
- createSlice);
-
- // Create tuple of splice xfer read operations.
- Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
- vectorTupleValues);
- // Replace 'xferReadOp' with result 'insertSlicesResult'.
- rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
- xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
- extractSlicesOp.strides());
+ Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
+ if (!newVec)
+ return failure();
+ rewriter.replaceOp(xferReadOp, newVec);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
new file mode 100644
index 000000000000..b676700dae06
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_unroll
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
+// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32>
+
+func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @transfer_write_unroll
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: vector.transfer_write %[[T1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: vector.transfer_write %[[T2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+// CHECK-NEXT: vector.transfer_write %[[T3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: return
+
+func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) {
+ %c0 = constant 0 : index
+ vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+ return
+}
+
+// CHECK-LABEL: func @transfer_readwrite_unroll
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[VTR1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[VTR2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[VTR3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32>
+// CHECK-NEXT: return
+
+func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+ vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+ return
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index fe0947d0ac30..52d0f7b2bb5e 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -156,6 +156,24 @@ struct TestVectorDistributePatterns
}
};
+struct TestVectorTransferUnrollingPatterns
+ : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<AffineDialect>();
+ }
+ void runOnFunction() override {
+ MLIRContext *ctx = &getContext();
+ OwningRewritePatternList patterns;
+ patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
+ ArrayRef<int64_t>{2, 2}, ctx);
+ patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
+ ArrayRef<int64_t>{2, 2}, ctx);
+ populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
+ populateVectorToVectorTransformationPatterns(patterns, ctx);
+ applyPatternsAndFoldGreedily(getFunction(), patterns);
+ }
+};
+
struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
FunctionPass> {
@@ -205,6 +223,10 @@ void registerTestVectorConversions() {
"test-vector-unrolling-patterns",
"Test conversion patterns to unroll contract ops in the vector dialect");
+ PassRegistration<TestVectorTransferUnrollingPatterns> transferOpUnrollingPass(
+ "test-vector-transfer-unrolling-patterns",
+ "Test conversion patterns to unroll transfer ops in the vector dialect");
+
PassRegistration<TestVectorTransferFullPartialSplitPatterns>
vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
"Test conversion patterns to split "
More information about the Mlir-commits
mailing list