[Mlir-commits] [mlir] fd2b089 - [mlir][Vector] Lowering of transfer_read/write to vector.load/store
Sergei Grechanik
llvmlistbot at llvm.org
Thu Mar 11 18:19:47 PST 2021
Author: Sergei Grechanik
Date: 2021-03-11T18:17:51-08:00
New Revision: fd2b08969b8a42945b3f79a027fb80582ff42411
URL: https://github.com/llvm/llvm-project/commit/fd2b08969b8a42945b3f79a027fb80582ff42411
DIFF: https://github.com/llvm/llvm-project/commit/fd2b08969b8a42945b3f79a027fb80582ff42411.diff
LOG: [mlir][Vector] Lowering of transfer_read/write to vector.load/store
This patch introduces progressive lowering patterns for rewriting
vector.transfer_read/write to vector.load/store and vector.broadcast
in certain supported cases.
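For reference, here is a minimal before/after sketch of the basic rewrite,
adapted from the new test file added below (the precise conditions under which
the patterns fire are documented on the pattern classes in
VectorTransforms.cpp):

  // Unmasked, minor-identity transfers on a default-layout memref...
  func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
    %cf0 = constant 0.0 : f32
    %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
    vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>
    return %res : vector<4xf32>
  }

  // ...become plain vector loads/stores:
  func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
    %res = vector.load %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>
    vector.store %res, %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>
    return %res : vector<4xf32>
  }

When the permutation map broadcasts dimensions (e.g.
affine_map<(d0, d1) -> (0)>), the unbroadcasted value is loaded first and then
expanded with vector.broadcast, as in the @transfer_broadcasting tests below.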
Reviewed By: dcaballe, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D97822
Added:
mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index ee7ed62dcf01..9e486d038a48 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -85,6 +85,13 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
MLIRContext *context);
+/// Collect a set of transfer read/write lowering patterns.
+///
+/// These patterns lower transfer ops to simpler ops like `vector.load`,
+/// `vector.store` and `vector.broadcast`.
+void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *context);
+
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 86480529aa05..e837fc070ab0 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -104,6 +104,15 @@ class AffineMap {
/// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const;
+ /// Returns true if this affine map is a minor identity up to broadcasted
+ /// dimensions which are indicated by value 0 in the result. If
+ /// `broadcastedDims` is not null, it will be populated with the indices of
+ /// the broadcasted dimensions in the result array.
+ /// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
+ /// (`broadcastedDims` will contain [0, 2])
+ bool isMinorIdentityWithBroadcasting(
+ SmallVectorImpl<unsigned> *broadcastedDims = nullptr) const;
+
/// Returns true if this affine map is an empty map, i.e., () -> ().
bool isEmpty() const;
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 200eb55076f7..090afda01fe4 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -37,6 +37,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/VectorInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
@@ -2729,6 +2730,116 @@ struct TransferWriteInsertPattern
}
};
+/// Progressive lowering of transfer_read. This pattern supports lowering of
+/// `vector.transfer_read` to a combination of `vector.load` and
+/// `vector.broadcast` if all of the following hold:
+/// - The op reads from a memref with the default layout.
+/// - Masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+/// result type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+struct TransferReadToVectorLoadLowering
+ : public OpRewritePattern<vector::TransferReadOp> {
+ TransferReadToVectorLoadLowering(MLIRContext *context)
+ : OpRewritePattern<vector::TransferReadOp>(context) {}
+ LogicalResult matchAndRewrite(vector::TransferReadOp read,
+ PatternRewriter &rewriter) const override {
+ SmallVector<unsigned, 4> broadcastedDims;
+ // TODO: Support permutations.
+ if (!read.permutation_map().isMinorIdentityWithBroadcasting(
+ &broadcastedDims))
+ return failure();
+ auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
+ if (!memRefType)
+ return failure();
+
+ // If there is broadcasting involved then we first load the unbroadcasted
+ // vector, and then broadcast it with `vector.broadcast`.
+ ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
+ SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
+ vectorShape.end());
+ for (unsigned i : broadcastedDims)
+ unbroadcastedVectorShape[i] = 1;
+ VectorType unbroadcastedVectorType = VectorType::get(
+ unbroadcastedVectorShape, read.getVectorType().getElementType());
+
+ // `vector.load` supports vector types as memref's elements only when the
+ // resulting vector type is the same as the element type.
+ if (memRefType.getElementType().isa<VectorType>() &&
+ memRefType.getElementType() != unbroadcastedVectorType)
+ return failure();
+ // Only the default layout is supported by `vector.load`.
+ // TODO: Support non-default layouts.
+ if (!memRefType.getAffineMaps().empty())
+ return failure();
+ // TODO: When masking is required, we can create a MaskedLoadOp
+ if (read.hasMaskedDim())
+ return failure();
+
+ Operation *loadOp;
+ if (!broadcastedDims.empty() &&
+ unbroadcastedVectorType.getNumElements() == 1) {
+ // If broadcasting is required and the number of loaded elements is 1 then
+ // we can create `std.load` instead of `vector.load`.
+ loadOp = rewriter.create<mlir::LoadOp>(read.getLoc(), read.source(),
+ read.indices());
+ } else {
+ // Otherwise create `vector.load`.
+ loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
+ unbroadcastedVectorType,
+ read.source(), read.indices());
+ }
+
+ // Insert a broadcasting op if required.
+ if (!broadcastedDims.empty()) {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ read, read.getVectorType(), loadOp->getResult(0));
+ } else {
+ rewriter.replaceOp(read, loadOp->getResult(0));
+ }
+
+ return success();
+ }
+};
+
+/// Progressive lowering of transfer_write. This pattern supports lowering of
+/// `vector.transfer_write` to `vector.store` if all of the following hold:
+/// - The op writes to a memref with the default layout.
+/// - Masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+/// type of the written value.
+/// - The permutation map is the minor identity map (neither permutation nor
+/// broadcasting is allowed).
+struct TransferWriteToVectorStoreLowering
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ TransferWriteToVectorStoreLowering(MLIRContext *context)
+ : OpRewritePattern<vector::TransferWriteOp>(context) {}
+ LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+ PatternRewriter &rewriter) const override {
+ // TODO: Support non-minor-identity maps
+ if (!write.permutation_map().isMinorIdentity())
+ return failure();
+ auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
+ if (!memRefType)
+ return failure();
+ // `vector.store` supports vector types as memref's elements only when the
+ // type of the vector value being written is the same as the element type.
+ if (memRefType.getElementType().isa<VectorType>() &&
+ memRefType.getElementType() != write.getVectorType())
+ return failure();
+ // Only the default layout is supported by `vector.store`.
+ // TODO: Support non-default layouts.
+ if (!memRefType.getAffineMaps().empty())
+ return failure();
+ // TODO: When masking is required, we can create a MaskedStoreOp
+ if (write.hasMaskedDim())
+ return failure();
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ write, write.vector(), write.source(), write.indices());
+ return success();
+ }
+};
+
// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
static VectorType trimLeadingOneDims(VectorType oldType) {
@@ -3201,3 +3312,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
ContractionOpToOuterProductOpLowering>(parameters, context);
// clang-format on
}
+
+void mlir::vector::populateVectorTransferLoweringPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<TransferReadToVectorLoadLowering,
+ TransferWriteToVectorStoreLowering>(context);
+}
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 312e940d20b4..9de80e96d451 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -110,6 +110,35 @@ bool AffineMap::isMinorIdentity() const {
getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
}
+/// Returns true if this affine map is a minor identity up to broadcasted
+/// dimensions which are indicated by value 0 in the result.
+bool AffineMap::isMinorIdentityWithBroadcasting(
+ SmallVectorImpl<unsigned> *broadcastedDims) const {
+ if (broadcastedDims)
+ broadcastedDims->clear();
+ if (getNumDims() < getNumResults())
+ return false;
+ unsigned suffixStart = getNumDims() - getNumResults();
+ for (auto idxAndExpr : llvm::enumerate(getResults())) {
+ unsigned resIdx = idxAndExpr.index();
+ AffineExpr expr = idxAndExpr.value();
+ if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ // Each result may be either a constant 0 (broadcasted dimension).
+ if (constExpr.getValue() != 0)
+ return false;
+ if (broadcastedDims)
+ broadcastedDims->push_back(resIdx);
+ } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+ // Or it may be the input dimension corresponding to this result position.
+ if (dimExpr.getPosition() != suffixStart + resIdx)
+ return false;
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
/// Returns an AffineMap representing a permutation.
AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context) {
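To make the new predicate concrete, a few illustrative maps (the first is the
example from the header comment, the last is the map used in the
@transfer_perm_map test below; the middle one is an extra example, not taken
from the patch):

  // Minor identity up to broadcasting; broadcastedDims is populated with [0, 2].
  affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
  // Plain minor identity; returns true and leaves broadcastedDims empty.
  affine_map<(d0, d1, d2) -> (d1, d2)>
  // Not a minor identity: the result uses d0 instead of the suffix dimension d1.
  affine_map<(d0, d1) -> (d0)>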
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
new file mode 100644
index 000000000000..bc23821b856f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -0,0 +1,208 @@
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
+
+// transfer_read/write are lowered to vector.load/store
+// CHECK-LABEL: func @transfer_to_load(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
+ vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>
+ return %res : vector<4xf32>
+}
+
+// -----
+
+// n-D results are also supported.
+// CHECK-LABEL: func @transfer_2D(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_2D(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false]} : memref<8x8xf32>, vector<2x4xf32>
+ vector.transfer_write %res, %mem[%i, %i] {masked = [false, false]} : vector<2x4xf32>, memref<8x8xf32>
+ return %res : vector<2x4xf32>
+}
+
+// -----
+
+// Vector element types are supported when the result has the same type.
+// CHECK-LABEL: func @transfer_vector_element(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT: vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<2x4xf32> {
+ %cf0 = constant dense<0.0> : vector<2x4xf32>
+ %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+ vector.transfer_write %res, %mem[%i, %i] : vector<2x4xf32>, memref<8x8xvector<2x4xf32>>
+ return %res : vector<2x4xf32>
+}
+
+// -----
+
+// TODO: Vector element types are not supported yet when the result has a
+// different type.
+// CHECK-LABEL: func @transfer_vector_element_different_types(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1x2x4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant dense<0.000000e+00> : vector<2x4xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
+// CHECK-NEXT: return %[[RES]] : vector<1x2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_vector_element_different_types(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<1x2x4xf32> {
+ %cf0 = constant dense<0.0> : vector<2x4xf32>
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
+ vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
+ return %res : vector<1x2x4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered because one of the dimensions
+// is masked.
+// CHECK-LABEL: func @transfer_2D_masked(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_2D_masked(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
+ vector.transfer_write %res, %mem[%i, %i] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
+ return %res : vector<2x4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered because they are masked.
+// CHECK-LABEL: func @transfer_masked(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_masked(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xf32>, vector<4xf32>
+ vector.transfer_write %res, %mem[%i, %i] : vector<4xf32>, memref<8x8xf32>
+ return %res : vector<4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered to vector.load/store because the
+// memref has a non-default layout.
+// CHECK-LABEL: func @transfer_nondefault_layout(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xf32, #{{.*}}>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #{{.*}}>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+#layout = affine_map<(d0, d1) -> (d0*16 + d1)>
+func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : index) -> vector<4xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32, #layout>, vector<4xf32>
+ vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #layout>
+ return %res : vector<4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered to vector.load/store yet when the
+// permutation map is not the minor identity map (up to broadcasting).
+// CHECK-LABEL: func @transfer_perm_map(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
+ vector.transfer_write %res, %mem[%i, %i] {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
+ return %res : vector<4xf32>
+}
+
+// -----
+
+// Lowering of transfer_read with broadcasting is supported (note that a `load`
+// is generated instead of a `vector.load`).
+// CHECK-LABEL: func @transfer_broadcasting(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+#broadcast = affine_map<(d0, d1) -> (0)>
+func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// -----
+
+// An example with two broadcasted dimensions.
+// CHECK-LABEL: func @transfer_broadcasting_2D(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
+// CHECK-NEXT: }
+
+#broadcast = affine_map<(d0, d1) -> (0, 0)>
+func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vector<4x4xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4x4xf32>
+ return %res : vector<4x4xf32>
+}
+
+// -----
+
+// More complex broadcasting case (here a `vector.load` is generated).
+// CHECK-LABEL: func @transfer_broadcasting_complex(
+// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32>
+// CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32>
+// CHECK-NEXT: }
+
+#broadcast = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
+func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> {
+ %cf0 = constant 0.0 : f32
+ %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
+ return %res : vector<3x2x4x5xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index e673d46527d2..d45235043536 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -361,6 +361,15 @@ struct TestVectorTransferOpt
void runOnFunction() override { transferOpflowOpt(getFunction()); }
};
+struct TestVectorTransferLoweringPatterns
+ : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ populateVectorTransferLoweringPatterns(patterns, &getContext());
+ (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+ }
+};
+
} // end anonymous namespace
namespace mlir {
@@ -403,6 +412,10 @@ void registerTestVectorConversions() {
PassRegistration<TestVectorTransferOpt> transferOpOpt(
"test-vector-transferop-opt",
"Test optimization transformations for transfer ops");
+
+ PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
+ "test-vector-transfer-lowering-patterns",
+ "Test conversion patterns to lower transfer ops to other vector ops");
}
} // namespace test
} // namespace mlir