[Mlir-commits] [mlir] 1244bca - [mlir][vector] Support distributing transfer op with permutation map
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 21 12:56:32 PDT 2021
Author: thomasraoux
Date: 2021-06-21T12:56:08-07:00
New Revision: 1244bca53fb2ff2e6061ae43b830a645bf93cc6d
URL: https://github.com/llvm/llvm-project/commit/1244bca53fb2ff2e6061ae43b830a645bf93cc6d
DIFF: https://github.com/llvm/llvm-project/commit/1244bca53fb2ff2e6061ae43b830a645bf93cc6d.diff
LOG: [mlir][vector] Support distributing transfer op with permutation map
Differential Revision: https://reviews.llvm.org/D104263
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-distribution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index baded89d9074b..6765fd4946d47 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2842,6 +2842,20 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
return ops;
}
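+/// Rewrites a TransferRead whose result is consumed by an ExtractMap into a
+/// TransferRead of the smaller, distributed vector type that reads only the
+/// slice selected by the ExtractMap ids. Distributed vector dimensions are
+/// mapped to source indices through the read's permutation map; dimensions the
+/// map broadcasts (constant results) get no index offset.
+/// Example:
+/// ```
+/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
+///   memref<64x64x64xf32>, vector<64x4x32xf32>
+/// %e = vector.extract_map %a[%id0, %id1] : vector<64x4x32xf32> to vector<2x4x1xf32>
+/// ```
+/// to:
+/// ```
+/// %i0 = affine.apply affine_map<()[s0] -> (s0 * 2)> ()[%id0]
+/// %e = vector.transfer_read %A[%i0, %c0, %id1], %cf0 :
+///   memref<64x64x64xf32>, vector<2x4x1xf32>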
+/// ```
struct TransferReadExtractPattern
: public OpRewritePattern<vector::TransferReadOp> {
TransferReadExtractPattern(MLIRContext *context)
@@ -2858,18 +2872,23 @@ struct TransferReadExtractPattern
return failure();
SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
- AffineMap map = extract.map();
+ AffineMap indexMap = extract.map().compose(read.permutation_map());
unsigned idCount = 0;
ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
- for (auto expr : map.getResults()) {
+ for (auto it :
+ llvm::zip(indexMap.getResults(), extract.map().getResults())) {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
- unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+ auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+ if (!indexExpr)
+ continue;
+ unsigned indexPos = indexExpr.getPosition();
+ unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
auto scale = getAffineConstantExpr(
- extract.getResultType().getDimSize(pos), read.getContext());
- indices[pos] =
- makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
- {indices[pos], extract.ids()[idCount++]});
+ extract.getResultType().getDimSize(vectorPos), read.getContext());
+ indices[indexPos] = makeComposedAffineApply(
+ rewriter, read.getLoc(), d0 + scale * d1,
+ {indices[indexPos], extract.ids()[idCount++]});
}
Value newRead = lb.create<vector::TransferReadOp>(
extract.getType(), read.source(), indices, read.permutation_map(),
@@ -2895,18 +2914,24 @@ struct TransferWriteInsertPattern
return failure();
SmallVector<Value, 4> indices(write.indices().begin(),
write.indices().end());
- AffineMap map = insert.map();
+ AffineMap indexMap = insert.map().compose(write.permutation_map());
unsigned idCount = 0;
Location loc = write.getLoc();
- for (auto expr : map.getResults()) {
+ for (auto it :
+ llvm::zip(indexMap.getResults(), insert.map().getResults())) {
AffineExpr d0, d1;
bindDims(write.getContext(), d0, d1);
- unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+ auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+ if (!indexExpr)
+ continue;
+ unsigned indexPos = indexExpr.getPosition();
+ unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
auto scale = getAffineConstantExpr(
- insert.getSourceVectorType().getDimSize(pos), write.getContext());
- indices[pos] =
+ insert.getSourceVectorType().getDimSize(vectorPos),
+ write.getContext());
+ indices[indexPos] =
makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
- {indices[pos], insert.ids()[idCount++]});
+ {indices[indexPos], insert.ids()[idCount++]});
}
rewriter.create<vector::TransferWriteOp>(
loc, insert.vector(), write.source(), indices, write.permutation_map(),
diff --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
index 950786e86caa2..0ad46d1b204e1 100644
--- a/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -123,4 +123,34 @@ func @vector_add_transfer_3d(%id0 : index, %id1 : index, %A: memref<64x64x64xf32
return
}
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK: func @vector_add_transfer_permutation
+// CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+// CHECK-NEXT: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[ID2]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID_0]], %[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP2]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
+// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]], %[[ID_1]], %[[C0]], %[[ID3]]] {permutation_map = #[[MAP3]]} : vector<2x4x1xf32>, memref<?x?x?x?xf32>
+// CHECK-NEXT: return
+// Distribution of transfer ops that carry permutation maps.
+// #map0 sends vector dim 0 to memref dim 3 and broadcasts vector dims 1 and 2.
+// #map1 broadcasts vector dim 0 and sends vector dims 1 and 2 to memref dims 3
+// and 0. #map2 is a pure permutation: vector dims 0, 1, 2 go to memref dims
+// 3, 2, 1.
+// NOTE(review): the expected output for EXB places ID_0 at memref index 0,
+// which #map1 ties to vector dim 2. The transfer_write instead places ID_1 at
+// memref index 1, which #map2 also ties to vector dim 2. The rewrite patterns
+// `continue` past constant (broadcast) results without advancing idCount, so
+// for %b the id presumably shifts from the second id to the first. Confirm
+// whether ID_1 is the intended index for EXB.
+func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?x?x?xf32>,
+ %B: memref<?x?x?x?xf32>, %C: memref<?x?x?x?xf32>) {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ // Read with trailing broadcast dims (#map0).
+ %a = vector.transfer_read %A[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<64x4x32xf32>
+ // Read with a leading broadcast dim (#map1).
+ %b = vector.transfer_read %B[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map1}: memref<?x?x?x?xf32>, vector<64x4x32xf32>
+ %acc = addf %a, %b: vector<64x4x32xf32>
+ // Write through a pure permutation (#map2).
+ vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
+ return
+}
More information about the Mlir-commits
mailing list