[Mlir-commits] [mlir] dc82547 - [mlir][vector] Make write permutation lowering work with tensors.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 2 01:22:18 PST 2022
Author: gysit
Date: 2022-02-02T09:21:10Z
New Revision: dc82547b173ffed79c87e57339540baed096410f
URL: https://github.com/llvm/llvm-project/commit/dc82547b173ffed79c87e57339540baed096410f
DIFF: https://github.com/llvm/llvm-project/commit/dc82547b173ffed79c87e57339540baed096410f.diff
LOG: [mlir][vector] Make write permutation lowering work with tensors.
Use type inference when building the TransferWriteOp in TransferWritePermutationLowering. Previously, the result type was set to Type(), which triggers an assertion when the pattern is applied to tensors instead of memrefs.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D118758
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index 533dc20da15b1..baf6973be12e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -185,8 +185,8 @@ struct TransferWritePermutationLowering
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, Type(), newVec, op.source(), op.indices(),
- AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
+ op, newVec, op.source(), op.indices(), AffineMapAttr::get(newMap),
+ newMask, newInBoundsAttr);
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 562870c4d9fe6..7983a81f8d8c6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -327,21 +327,24 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
// -----
// CHECK-LABEL: func @transfer_write_permutations
-func @transfer_write_permutations(%arg0 : memref<?x?x?x?xf32>,
- %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> () {
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+func @transfer_write_permutations(
+ %arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+ %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> tensor<?x?x?x?xf32> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
%m = arith.constant 1 : i1
%mask0 = splat %m : vector<7x14x8x16xi1>
- vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, memref<?x?x?x?xf32>
+ %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
// CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1>
// CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
- // CHECK: vector.transfer_write %[[NEW_VEC0]], %arg0[%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, memref<?x?x?x?xf32>
+ // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
// CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
- // CHECK: vector.transfer_write %[[NEW_VEC1]], %arg0[%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
+ // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[ARG0]][%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
- return
+ return %0 : tensor<?x?x?x?xf32>
}
More information about the Mlir-commits mailing list