[Mlir-commits] [mlir] fd2b089 - [mlir][Vector] Lowering of transfer_read/write to vector.load/store

Sergei Grechanik llvmlistbot at llvm.org
Thu Mar 11 18:19:47 PST 2021


Author: Sergei Grechanik
Date: 2021-03-11T18:17:51-08:00
New Revision: fd2b08969b8a42945b3f79a027fb80582ff42411

URL: https://github.com/llvm/llvm-project/commit/fd2b08969b8a42945b3f79a027fb80582ff42411
DIFF: https://github.com/llvm/llvm-project/commit/fd2b08969b8a42945b3f79a027fb80582ff42411.diff

LOG: [mlir][Vector] Lowering of transfer_read/write to vector.load/store

This patch introduces progressive lowering patterns for rewriting
vector.transfer_read/write to vector.load/store and vector.broadcast
in certain supported cases: the memref has the default layout, masking is
not required, and the permutation map is a minor identity map (for reads,
a minor identity up to broadcast dimensions).
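
For illustration, a minimal before/after sketch of the simplest supported
case (mirroring the @transfer_to_load and @transfer_broadcasting tests added
below; the SSA value names are illustrative):

    // Before: unmasked transfers on a default-layout memref.
    %v = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]}
        : memref<8x8xf32>, vector<4xf32>
    vector.transfer_write %v, %mem[%i, %i] {masked = [false]}
        : vector<4xf32>, memref<8x8xf32>

    // After applying the patterns:
    %v = vector.load %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>
    vector.store %v, %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>

    // A broadcasting permutation map such as (d0, d1) -> (0) instead lowers
    // to a scalar load followed by vector.broadcast:
    %s = load %mem[%i, %i] : memref<8x8xf32>
    %b = vector.broadcast %s : f32 to vector<4xf32>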

Reviewed By: dcaballe, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D97822

Added: 
    mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index ee7ed62dcf01..9e486d038a48 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -85,6 +85,13 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context);
 
+/// Collect a set of transfer read/write lowering patterns.
+///
+/// These patterns lower transfer ops to simpler ops like `vector.load`,
+/// `vector.store` and `vector.broadcast`.
+void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *context);
+
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
 class CombiningKindAttr

diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 86480529aa05..e837fc070ab0 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -104,6 +104,15 @@ class AffineMap {
   /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
   bool isMinorIdentity() const;
 
+  /// Returns true if this affine map is a minor identity up to broadcasted
+  /// dimensions which are indicated by value 0 in the result. If
+  /// `broadcastedDims` is not null, it will be populated with the indices of
+  /// the broadcasted dimensions in the result array.
+  /// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
+  ///          (`broadcastedDims` will contain [0, 2])
+  bool isMinorIdentityWithBroadcasting(
+      SmallVectorImpl<unsigned> *broadcastedDims = nullptr) const;
+
   /// Returns true if this affine map is an empty map, i.e., () -> ().
   bool isEmpty() const;
 

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 200eb55076f7..090afda01fe4 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -37,6 +37,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -2729,6 +2730,116 @@ struct TransferWriteInsertPattern
   }
 };
 
+/// Progressive lowering of transfer_read. This pattern supports lowering of
+/// `vector.transfer_read` to a combination of `vector.load` and
+/// `vector.broadcast` if all of the following hold:
+/// - The op reads from a memref with the default layout.
+/// - Masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   result type.
+/// - The permutation map doesn't perform permutation (broadcasting is allowed).
+struct TransferReadToVectorLoadLowering
+    : public OpRewritePattern<vector::TransferReadOp> {
+  TransferReadToVectorLoadLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferReadOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned, 4> broadcastedDims;
+    // TODO: Support permutations.
+    if (!read.permutation_map().isMinorIdentityWithBroadcasting(
+            &broadcastedDims))
+      return failure();
+    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
+
+    // If there is broadcasting involved then we first load the unbroadcasted
+    // vector, and then broadcast it with `vector.broadcast`.
+    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
+    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
+                                                     vectorShape.end());
+    for (unsigned i : broadcastedDims)
+      unbroadcastedVectorShape[i] = 1;
+    VectorType unbroadcastedVectorType = VectorType::get(
+        unbroadcastedVectorShape, read.getVectorType().getElementType());
+
+    // `vector.load` supports vector types as memref's elements only when the
+    // resulting vector type is the same as the element type.
+    if (memRefType.getElementType().isa<VectorType>() &&
+        memRefType.getElementType() != unbroadcastedVectorType)
+      return failure();
+    // Only the default layout is supported by `vector.load`.
+    // TODO: Support non-default layouts.
+    if (!memRefType.getAffineMaps().empty())
+      return failure();
+    // TODO: When masking is required, we can create a MaskedLoadOp
+    if (read.hasMaskedDim())
+      return failure();
+
+    Operation *loadOp;
+    if (!broadcastedDims.empty() &&
+        unbroadcastedVectorType.getNumElements() == 1) {
+      // If broadcasting is required and the number of loaded elements is 1 then
+      // we can create `std.load` instead of `vector.load`.
+      loadOp = rewriter.create<mlir::LoadOp>(read.getLoc(), read.source(),
+                                             read.indices());
+    } else {
+      // Otherwise create `vector.load`.
+      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
+                                               unbroadcastedVectorType,
+                                               read.source(), read.indices());
+    }
+
+    // Insert a broadcasting op if required.
+    if (!broadcastedDims.empty()) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          read, read.getVectorType(), loadOp->getResult(0));
+    } else {
+      rewriter.replaceOp(read, loadOp->getResult(0));
+    }
+
+    return success();
+  }
+};
+
+/// Progressive lowering of transfer_write. This pattern supports lowering of
+/// `vector.transfer_write` to `vector.store` if all of the following hold:
+/// - The op writes to a memref with the default layout.
+/// - Masking is not required.
+/// - If the memref's element type is a vector type then it coincides with the
+///   type of the written value.
+/// - The permutation map is the minor identity map (neither permutation nor
+///   broadcasting is allowed).
+struct TransferWriteToVectorStoreLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteToVectorStoreLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Support non-minor-identity maps
+    if (!write.permutation_map().isMinorIdentity())
+      return failure();
+    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
+    // `vector.store` supports vector types as memref's elements only when the
+    // type of the vector value being written is the same as the element type.
+    if (memRefType.getElementType().isa<VectorType>() &&
+        memRefType.getElementType() != write.getVectorType())
+      return failure();
+    // Only the default layout is supported by `vector.store`.
+    // TODO: Support non-default layouts.
+    if (!memRefType.getAffineMaps().empty())
+      return failure();
+    // TODO: When masking is required, we can create a MaskedStoreOp
+    if (write.hasMaskedDim())
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        write, write.vector(), write.source(), write.indices());
+    return success();
+  }
+};
+
 // Trims leading one dimensions from `oldType` and returns the result type.
 // Returns `vector<1xT>` if `oldType` only has one element.
 static VectorType trimLeadingOneDims(VectorType oldType) {
@@ -3201,3 +3312,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
                   ContractionOpToOuterProductOpLowering>(parameters, context);
   // clang-format on
 }
+
+void mlir::vector::populateVectorTransferLoweringPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<TransferReadToVectorLoadLowering,
+                  TransferWriteToVectorStoreLowering>(context);
+}

diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 312e940d20b4..9de80e96d451 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -110,6 +110,35 @@ bool AffineMap::isMinorIdentity() const {
          getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 
+/// Returns true if this affine map is a minor identity up to broadcasted
+/// dimensions which are indicated by value 0 in the result.
+bool AffineMap::isMinorIdentityWithBroadcasting(
+    SmallVectorImpl<unsigned> *broadcastedDims) const {
+  if (broadcastedDims)
+    broadcastedDims->clear();
+  if (getNumDims() < getNumResults())
+    return false;
+  unsigned suffixStart = getNumDims() - getNumResults();
+  for (auto idxAndExpr : llvm::enumerate(getResults())) {
+    unsigned resIdx = idxAndExpr.index();
+    AffineExpr expr = idxAndExpr.value();
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      // Each result may be either a constant 0 (broadcasted dimension).
+      if (constExpr.getValue() != 0)
+        return false;
+      if (broadcastedDims)
+        broadcastedDims->push_back(resIdx);
+    } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+      // Or it may be the input dimension corresponding to this result position.
+      if (dimExpr.getPosition() != suffixStart + resIdx)
+        return false;
+    } else {
+      return false;
+    }
+  }
+  return true;
+}
+
 /// Returns an AffineMap representing a permutation.
 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
                                        MLIRContext *context) {

diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
new file mode 100644
index 000000000000..bc23821b856f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -0,0 +1,208 @@
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
+
+// transfer_read/write are lowered to vector.load/store
+// CHECK-LABEL:   func @transfer_to_load(
+// CHECK-SAME:                                %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME:                                %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT:      vector.store  %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<4xf32>
+// CHECK-NEXT:    }
+
+func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// n-D results are also supported.
+// CHECK-LABEL:   func @transfer_2D(
+// CHECK-SAME:                           %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME:                           %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT:      vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT:    }
+
+func @transfer_2D(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false]} : memref<8x8xf32>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false, false]} : vector<2x4xf32>, memref<8x8xf32>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// Vector element types are supported when the result has the same type.
+// CHECK-LABEL:   func @transfer_vector_element(
+// CHECK-SAME:                           %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
+// CHECK-SAME:                           %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT:      %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT:      vector.store %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT:    }
+
+func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant dense<0.0> : vector<2x4xf32>
+  %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] : vector<2x4xf32>, memref<8x8xvector<2x4xf32>>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// TODO: Vector element types are not supported yet when the result has a
+// different type.
+// CHECK-LABEL:   func @transfer_vector_element_different_types(
+// CHECK-SAME:                           %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
+// CHECK-SAME:                           %[[IDX:.*]]: index) -> vector<1x2x4xf32> {
+// CHECK-NEXT:      %[[CF0:.*]] = constant dense<0.000000e+00> : vector<2x4xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
+// CHECK-NEXT:      vector.transfer_write %[[RES:.*]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
+// CHECK-NEXT:      return %[[RES]] : vector<1x2x4xf32>
+// CHECK-NEXT:    }
+
+func @transfer_vector_element_different_types(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<1x2x4xf32> {
+  %cf0 = constant dense<0.0> : vector<2x4xf32>
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
+  return %res : vector<1x2x4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered because there is a masked
+// dimension.
+// CHECK-LABEL:   func @transfer_2D_masked(
+// CHECK-SAME:                                  %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME:                                  %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT:      %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT:      vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT:    }
+
+func @transfer_2D_masked(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered because they are masked.
+// CHECK-LABEL:   func @transfer_masked(
+// CHECK-SAME:                               %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME:                               %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT:      %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT:      vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<4xf32>
+// CHECK-NEXT:    }
+
+func @transfer_masked(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered to vector.load/store because the
+// memref has a non-default layout.
+// CHECK-LABEL:   func @transfer_nondefault_layout(
+// CHECK-SAME:                                          %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
+// CHECK-SAME:                                          %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT:      %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xf32, #{{.*}}>, vector<4xf32>
+// CHECK-NEXT:      vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #{{.*}}>
+// CHECK-NEXT:      return %[[RES]] : vector<4xf32>
+// CHECK-NEXT:    }
+
+#layout = affine_map<(d0, d1) -> (d0*16 + d1)>
+func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32, #layout>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #layout>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// TODO: transfer_read/write cannot be lowered to vector.load/store yet when the
+// permutation map is not the minor identity map (up to broadcasting).
+// CHECK-LABEL:   func @transfer_perm_map(
+// CHECK-SAME:                                 %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME:                                 %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT:      %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT:      vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<4xf32>
+// CHECK-NEXT:    }
+
+func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// Lowering of transfer_read with broadcasting is supported (note that a `load`
+// is generated instead of a `vector.load`).
+// CHECK-LABEL:   func @transfer_broadcasting(
+// CHECK-SAME:                                %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME:                                %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT:      %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<4xf32>
+// CHECK-NEXT:    }
+
+#broadcast = affine_map<(d0, d1) -> (0)>
+func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// An example with two broadcasted dimensions.
+// CHECK-LABEL:   func @transfer_broadcasting_2D(
+// CHECK-SAME:                                %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME:                                %[[IDX:.*]]: index) -> vector<4x4xf32> {
+// CHECK-NEXT:      %[[LOAD:.*]] = load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<4x4xf32>
+// CHECK-NEXT:    }
+
+#broadcast = affine_map<(d0, d1) -> (0, 0)>
+func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vector<4x4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false], permutation_map = #broadcast} : memref<8x8xf32>, vector<4x4xf32>
+  return %res : vector<4x4xf32>
+}
+
+// -----
+
+// More complex broadcasting case (here a `vector.load` is generated).
+// CHECK-LABEL:   func @transfer_broadcasting_complex(
+// CHECK-SAME:                                %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
+// CHECK-SAME:                                %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> {
+// CHECK-NEXT:      %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<3x2x4x5xf32>
+// CHECK-NEXT:    }
+
+#broadcast = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
+func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
+  return %res : vector<3x2x4x5xf32>
+}

diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index e673d46527d2..d45235043536 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -361,6 +361,15 @@ struct TestVectorTransferOpt
   void runOnFunction() override { transferOpflowOpt(getFunction()); }
 };
 
+struct TestVectorTransferLoweringPatterns
+    : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    populateVectorTransferLoweringPatterns(patterns, &getContext());
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -403,6 +412,10 @@ void registerTestVectorConversions() {
   PassRegistration<TestVectorTransferOpt> transferOpOpt(
       "test-vector-transferop-opt",
       "Test optimization transformations for transfer ops");
+
+  PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
+      "test-vector-transfer-lowering-patterns",
+      "Test conversion patterns to lower transfer ops to other vector ops");
 }
 } // namespace test
 } // namespace mlir
