[Mlir-commits] [mlir] 435905e - [mlir][vector] Add extra lowering for more transfer_write maps
Thomas Raoux
llvmlistbot at llvm.org
Tue Jan 17 09:08:27 PST 2023
Author: Thomas Raoux
Date: 2023-01-17T17:06:00Z
New Revision: 435905ecf25ab9da0753931358414164352810f5
URL: https://github.com/llvm/llvm-project/commit/435905ecf25ab9da0753931358414164352810f5
DIFF: https://github.com/llvm/llvm-project/commit/435905ecf25ab9da0753931358414164352810f5.diff
LOG: [mlir][vector] Add extra lowering for more transfer_write maps
Add a pattern to lower transfer_write ops whose permutation maps are not
permutations of the minor identity map.
Differential Revision: https://reviews.llvm.org/D141815
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index df8ba7b85534..68d9a349478b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -33,6 +33,19 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
return builder.getBoolArrayAttr(newInBoundsValues);
}
+/// Extend the rank of a vector Value by `addedRanks` by adding outer unit
+/// dimensions.
+static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
+ int64_t addedRank) {
+ auto originalVecType = vec.getType().cast<VectorType>();
+ SmallVector<int64_t> newShape(addedRank, 1);
+ newShape.append(originalVecType.getShape().begin(),
+ originalVecType.getShape().end());
+ VectorType newVecType =
+ VectorType::get(newShape, originalVecType.getElementType());
+ return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
+}
+
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identiy +
/// vector.transpose op.
@@ -170,6 +183,77 @@ struct TransferWritePermutationLowering
}
};
+/// Convert a transfer.write op with a map which isn't the permutation of a
+/// minor identity into a vector.broadcast + transfer_write with permutation of
+/// minor identity map by adding unit dim on inner dimension. Ex:
+/// ```
+/// vector.transfer_write %v
+/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
+/// vector<8x16xf32>
+/// ```
+/// into:
+/// ```
+/// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
+/// vector.transfer_write %v1
+/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
+/// vector<1x8x16xf32>
+/// ```
+struct TransferWriteNonPermutationLowering
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getTransferRank() == 0)
+ return failure();
+ SmallVector<unsigned> permutation;
+ AffineMap map = op.getPermutationMap();
+ if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+ return failure();
+
+ // Missing outer dimensions are allowed, find the most outer existing
+ // dimension then deduce the missing inner dimensions.
+ SmallVector<bool> foundDim(map.getNumDims(), false);
+ for (AffineExpr exp : map.getResults()) {
+ foundDim[exp.cast<AffineDimExpr>().getPosition()] = true;
+ }
+ SmallVector<AffineExpr> exprs;
+ bool foundFirstDim = false;
+ SmallVector<int64_t> missingInnerDim;
+ for (size_t i = 0; i < foundDim.size(); i++) {
+ if (foundDim[i]) {
+ foundFirstDim = true;
+ continue;
+ }
+ if (!foundFirstDim)
+ continue;
+ // Once we found one outer dimension existing in the map keep track of all
+ // the missing dimensions after that.
+ missingInnerDim.push_back(i);
+ exprs.push_back(rewriter.getAffineDimExpr(i));
+ }
+ // Add unit dims at the beginning of the shape.
+ Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
+ missingInnerDim.size());
+ exprs.append(map.getResults().begin(), map.getResults().end());
+ AffineMap newMap =
+ AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
+ ArrayAttr newInBoundsAttr;
+ if (op.getInBounds()) {
+ // All the new dimensions added are inbound.
+ SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
+ for (Attribute attr : op.getInBounds().value().getValue()) {
+ newInBoundsValues.push_back(attr.cast<BoolAttr>().getValue());
+ }
+ newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
+ }
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
+ op.getMask(), newInBoundsAttr);
+ return success();
+ }
+};
+
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
@@ -250,7 +334,8 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<TransferReadPermutationLowering,
- TransferWritePermutationLowering, TransferOpReduceRank>(
- patterns.getContext(), benefit);
+ patterns
+ .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
+ TransferOpReduceRank, TransferWriteNonPermutationLowering>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 0d5678117dfb..ca353a07ad76 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -149,34 +149,32 @@ func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK-LABEL:func @materialize_write(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
func.func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
- // CHECK-DAG: %{{.*}} = arith.constant dense<1.000000e+00> : vector<5x4x3xf32>
+ // CHECK-DAG: %{{.*}} = arith.constant dense<1.000000e+00> : vector<3x4x1x5xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
- // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
// CHECK: %{{.*}} = memref.alloc(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?x?x?xf32>
// CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %{{.*}} step 3 {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %{{.*}} step 4 {
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
- // CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<5x4x3xf32>>
- // CHECK: memref.store %{{.*}}, %[[ALLOC]][] : memref<vector<5x4x3xf32>>
- // CHECK: %[[VECTOR_VIEW1:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<5x4x3xf32>> to memref<5xvector<4x3xf32>>
- // CHECK: scf.for %[[I4:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
+ // CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<3x4x1x5xf32>>
+ // CHECK: memref.store %{{.*}}, %[[ALLOC]][] : memref<vector<3x4x1x5xf32>>
+ // CHECK: %[[VECTOR_VIEW1:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<3x4x1x5xf32>> to memref<3xvector<4x1x5xf32>>
+ // CHECK: scf.for %[[I4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK: scf.if
- // CHECK: %[[S3:.*]] = affine.apply #[[$ADD]](%[[I3]], %[[I4]])
- // CHECK: %[[VECTOR_VIEW2:.*]] = vector.type_cast %[[VECTOR_VIEW1]] : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
+ // CHECK: %[[S3:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I4]])
+ // CHECK: %[[VECTOR_VIEW2:.*]] = vector.type_cast %[[VECTOR_VIEW1]] : memref<3xvector<4x1x5xf32>> to memref<3x4xvector<1x5xf32>>
// CHECK: scf.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK: scf.if
// CHECK: %[[S1:.*]] = affine.apply #[[$ADD]](%[[I1]], %[[I5]])
- // CHECK: %[[VEC:.*]] = memref.load %[[VECTOR_VIEW2]][%[[I4]], %[[I5]]] : memref<5x4xvector<3xf32>>
- // CHECK: scf.for %[[I6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
- // CHECK: %[[S0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
+ // CHECK: %[[VECTOR_VIEW3:.*]] = vector.type_cast %[[VECTOR_VIEW2]] : memref<3x4xvector<1x5xf32>> to memref<3x4x1xvector<5xf32>>
+ // CHECK: scf.for %[[I6:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
// CHECK: scf.if
- // CHECK: %[[SCAL:.*]] = vector.extractelement %[[VEC]][%[[I6]] : index] : vector<3xf32>
- // CHECK: memref.store %[[SCAL]], {{.*}}[%[[S0]], %[[S1]], %[[I2]], %[[S3]]] : memref<?x?x?x?xf32>
- // CHECK: }
+ // CHECK: %[[S0:.*]] = affine.apply #[[$ADD]](%[[I2]], %[[I6]])
+ // CHECK: %[[VEC:.*]] = memref.load %[[VECTOR_VIEW3]][%[[I4]], %[[I5]], %[[I6]]] : memref<3x4x1xvector<5xf32>>
+ // CHECK: vector.transfer_write %[[VEC]], %{{.*}}[%[[S3]], %[[S1]], %[[S0]], %[[I3]]] : vector<5xf32>, memref<?x?x?x?xf32>
// CHECK: }
// CHECK: }
// CHECK: }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 779b84f96a57..6911cd599c1f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -178,14 +178,12 @@ func.func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : inde
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {in_bounds = [true], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
-// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {in_bounds = [true], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }
func.func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
%cf0 = arith.constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
- vector.transfer_write %res, %mem[%i, %i] {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
return %res : vector<4xf32>
}
@@ -349,3 +347,30 @@ func.func @transfer_write_permutations(
return %0 : tensor<?x?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_broadcast_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:.*]]: vector<14x8x16xf32>
+// CHECK-SAME: %[[ARG3:.*]]: vector<8x16xf32>
+// CHECK-SAME: %[[M:.*]]: i1
+func.func @transfer_write_broadcast_unit_dim(
+ %arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+ %v1 : vector<14x8x16xf32>, %v2 : vector<8x16xf32>, %m: i1) -> tensor<?x?x?x?xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+
+ %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
+ // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
+ // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+
+ vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC2:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
+ // CHECK: %[[NEW_VEC3:.*]] = vector.transpose %[[NEW_VEC2]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
+ // CHECK: vector.transfer_write %[[NEW_VEC3]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+
+ return %0 : tensor<?x?x?x?xf32>
+}
More information about the Mlir-commits
mailing list