[Mlir-commits] [mlir] dc82547 - [mlir][vector] Make write permutation lowering work with tensors.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 2 01:22:18 PST 2022


Author: gysit
Date: 2022-02-02T09:21:10Z
New Revision: dc82547b173ffed79c87e57339540baed096410f

URL: https://github.com/llvm/llvm-project/commit/dc82547b173ffed79c87e57339540baed096410f
DIFF: https://github.com/llvm/llvm-project/commit/dc82547b173ffed79c87e57339540baed096410f.diff

LOG: [mlir][vector] Make write permutation lowering work with tensors.

Use type inference when building the TransferWriteOp in the TransferWritePermutationLowering. Previously, the result type was set to Type(), which triggered an assertion when the pattern was used with tensors instead of memrefs.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D118758

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
index 533dc20da15b1..baf6973be12e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -185,8 +185,8 @@ struct TransferWritePermutationLowering
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        op, Type(), newVec, op.source(), op.indices(),
-        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
+        op, newVec, op.source(), op.indices(), AffineMapAttr::get(newMap),
+        newMask, newInBoundsAttr);
 
     return success();
   }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 562870c4d9fe6..7983a81f8d8c6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -327,21 +327,24 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
 // -----
 
 // CHECK-LABEL: func @transfer_write_permutations
-func @transfer_write_permutations(%arg0 : memref<?x?x?x?xf32>,
-    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> () {
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x?x?x?xf32>
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+func @transfer_write_permutations(
+    %arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> tensor<?x?x?x?xf32> {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
   %m = arith.constant 1 : i1
 
   %mask0 = splat %m : vector<7x14x8x16xi1>
-  vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, memref<?x?x?x?xf32>
+  %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
   // CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1>
   // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
-  // CHECK: vector.transfer_write %[[NEW_VEC0]], %arg0[%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
 
   vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
   // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
-  // CHECK: vector.transfer_write %[[NEW_VEC1]], %arg0[%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[ARG0]][%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
 
-  return
+  return %0 : tensor<?x?x?x?xf32>
 }


        


More information about the Mlir-commits mailing list