[Mlir-commits] [mlir] f9190c8 - [mlir][vector] Support unrolling for transfer ops using tensors

Thomas Raoux llvmlistbot at llvm.org
Wed Jan 6 13:53:49 PST 2021


Author: Thomas Raoux
Date: 2021-01-06T13:28:04-08:00
New Revision: f9190c868137dcf43833db2c8e1e00c7acca67bc

URL: https://github.com/llvm/llvm-project/commit/f9190c868137dcf43833db2c8e1e00c7acca67bc
DIFF: https://github.com/llvm/llvm-project/commit/f9190c868137dcf43833db2c8e1e00c7acca67bc.diff

LOG: [mlir][vector] Support unrolling for transfer ops using tensors

Differential Revision: https://reviews.llvm.org/D93904

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
    mlir/test/Dialect/Vector/vector-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index c88aa7f5bc65..a258903d5a3a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -71,7 +71,8 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
 
 /// Unroll a transfer_write op. Break up the vector source into a tuple of
 /// vectors matching the given shape. Then store each element with its own
-/// transfer_write.
+/// transfer_write. If the transfer_write has a tensor source, the tensor
+/// produced by the last unrolled transfer_write is appended to `result`.
 ///
 /// Example:
 /// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
@@ -83,7 +84,8 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
 /// %2 = vector.tuple_get %0, 1 : tuple<vector<2x4xf32>, vector<2x4xf32>>
 /// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32>
 LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
-                                    ArrayRef<int64_t> targetShape);
+                                    ArrayRef<int64_t> targetShape,
+                                    SmallVector<Value, 1> &result);
 
 /// Options that control the vector unrolling.
 struct UnrollVectorOptions {
@@ -143,9 +145,10 @@ struct UnrollVectorPattern : public RewritePattern {
         llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
       return failure();
     if (isa<TransferWriteOp>(op)) {
-      if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
+      SmallVector<Value, 1> result;
+      if (failed(unrollTransferWriteOp(rewriter, op, *targetShape, result)))
         return failure();
-      rewriter.eraseOp(op);
+      rewriter.replaceOp(op, result);
       return success();
     }
     if (op->getNumResults() != 1)

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index f1708db113d4..ca6e92d95ed0 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -515,7 +515,7 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
 /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
 /// calls 'fn' with linear index and indices for each slice.
 static void generateTransferOpSlices(
-    Type memrefElementType, VectorType vectorType, TupleType tupleType,
+    Type shapedElementType, VectorType vectorType, TupleType tupleType,
     ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
     OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
   // Compute strides w.r.t. to slice counts in each dimension.
@@ -539,9 +539,9 @@ static void generateTransferOpSlices(
   //   vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
   //
   unsigned vectorRank = vectorType.getRank();
-  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
-    assert(vectorRank >= memrefVectorElementType.getRank());
-    vectorRank -= memrefVectorElementType.getRank();
+  if (auto sourceVectorElementType = shapedElementType.dyn_cast<VectorType>()) {
+    assert(vectorRank >= sourceVectorElementType.getRank());
+    vectorRank -= sourceVectorElementType.getRank();
   }
   unsigned indexOffset = numSliceIndices - vectorRank;
 
@@ -598,8 +598,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
   SmallVector<int64_t, 4> strides(targetShape.size(), 1);
 
   Location loc = readOp.getLoc();
-  auto memrefElementType =
-      readOp.source().getType().cast<MemRefType>().getElementType();
+  auto shapedElementType =
+      readOp.source().getType().cast<ShapedType>().getElementType();
   auto tupleType = generateExtractSlicesOpResultType(
       sourceVectorType, targetShape, strides, builder);
   int64_t numSlices = tupleType.size();
@@ -618,7 +618,7 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
         readOp.permutation_map(), readOp.padding(),
         readOp.masked() ? *readOp.masked() : ArrayAttr());
   };
-  generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
+  generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
                            targetShape, strides, indices, builder, createSlice);
 
   // Create tuple of splice transfer read operations.
@@ -634,7 +634,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
 // Entry point for unrolling declarative pattern rewrite for transfer_write op.
 LogicalResult
 mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
-                                    ArrayRef<int64_t> targetShape) {
+                                    ArrayRef<int64_t> targetShape,
+                                    SmallVector<Value, 1> &result) {
   auto writeOp = cast<vector::TransferWriteOp>(op);
   if (!isIdentitySuffix(writeOp.permutation_map()))
     return failure();
@@ -645,20 +646,28 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
   Location loc = writeOp.getLoc();
   Value tuple = builder.create<vector::ExtractSlicesOp>(
       loc, tupleType, writeOp.vector(), targetShape, strides);
-  auto memrefElementType =
-      writeOp.source().getType().cast<MemRefType>().getElementType();
+  auto shapedElementType =
+      writeOp.source().getType().cast<ShapedType>().getElementType();
   SmallVector<Value, 4> indices(writeOp.indices().begin(),
                                 writeOp.indices().end());
+  // If the TransferWrite returns a tensor, keep track of the last tensor
+  // created.
+  Value resultTensor;
   auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
     auto element = builder.create<vector::TupleGetOp>(
         loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index));
-    builder.create<vector::TransferWriteOp>(
-        loc, element.getResult(), writeOp.source(), sliceIndices,
+    Operation *write = builder.create<vector::TransferWriteOp>(
+        loc, element.getResult(),
+        resultTensor ? resultTensor : writeOp.source(), sliceIndices,
         writeOp.permutation_map(),
         writeOp.masked() ? *writeOp.masked() : ArrayAttr());
+    if (!write->getResults().empty())
+      resultTensor = write->getResult(0);
   };
-  generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType,
+  generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
                            targetShape, strides, indices, builder, createSlice);
+  if (resultTensor)
+    result.push_back(resultTensor);
   return success();
 }
 
@@ -761,25 +770,32 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
     insertSlicesOp.getStrides(strides);
 
     Location loc = xferWriteOp.getLoc();
-    auto memrefElementType =
-        xferWriteOp.source().getType().cast<MemRefType>().getElementType();
+    auto shapedElementType =
+        xferWriteOp.source().getType().cast<ShapedType>().getElementType();
     SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                   xferWriteOp.indices().end());
+    Value resultTensor;
     auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
       // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
       // `masked` attribute propagates conservatively: if the coarse op didn't
       // need masking, the fine op doesn't either.
-      rewriter.create<vector::TransferWriteOp>(
-          loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices,
+      Operation *write = rewriter.create<vector::TransferWriteOp>(
+          loc, tupleOp.getOperand(index),
+          resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices,
           xferWriteOp.permutation_map(),
           xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
+      if (!write->getResults().empty())
+        resultTensor = write->getResult(0);
     };
-    generateTransferOpSlices(memrefElementType, resultVectorType,
+    generateTransferOpSlices(shapedElementType, resultVectorType,
                              sourceTupleType, sizes, strides, indices, rewriter,
                              createSlice);
 
     // Erase old 'xferWriteOp'.
-    rewriter.eraseOp(xferWriteOp);
+    if (resultTensor)
+      rewriter.replaceOp(xferWriteOp, ArrayRef<Value>(resultTensor));
+    else
+      rewriter.eraseOp(xferWriteOp);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index b676700dae06..d5e9535acb8e 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -58,3 +58,65 @@ func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) {
   vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
   return
 }
+
+// CHECK-LABEL: func @transfer_read_unroll_tensor
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[C2:.*]] = constant 2 : index
+//       CHECK:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
+//  CHECK-NEXT:   return %[[VEC]] : vector<4x4xf32>
+
+func @transfer_read_unroll_tensor(%arg0 : tensor<4x4xf32>) -> vector<4x4xf32> {
+  %zero = constant 0 : index
+  %pad = constant 0.0 : f32
+  %vec = vector.transfer_read %arg0[%zero, %zero], %pad : tensor<4x4xf32>, vector<4x4xf32>
+  return %vec : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @transfer_write_unroll_tensor
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[C2:.*]] = constant 2 : index
+//       CHECK:   %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   %[[VTW0:.*]] = vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+//  CHECK-NEXT:   %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   %[[VTW1:.*]] = vector.transfer_write %[[T1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+//  CHECK-NEXT:   %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   %[[VTW2:.*]] = vector.transfer_write %[[T2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+//  CHECK-NEXT:   %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
+//  CHECK-NEXT:   %[[VTW3:.*]] = vector.transfer_write %[[T3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+//  CHECK-NEXT:   return %[[VTW3]] : tensor<4x4xf32>
+
+func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>,
+                                   %arg1 : vector<4x4xf32>) -> tensor<4x4xf32> {
+  %zero = constant 0 : index
+  %updated = vector.transfer_write %arg1, %arg0[%zero, %zero]
+      : vector<4x4xf32>, tensor<4x4xf32>
+  return %updated : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func @transfer_readwrite_unroll_tensor
+//       CHECK:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[C2:.*]] = constant 2 : index
+//       CHECK:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTW0:.*]] = vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+//  CHECK-NEXT:   %[[VTW1:.*]] = vector.transfer_write %[[VTR1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+//  CHECK-NEXT:   %[[VTW2:.*]] = vector.transfer_write %[[VTR2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+//  CHECK-NEXT:   %[[VTW3:.*]] = vector.transfer_write %[[VTR3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
+//  CHECK-NEXT:   return %[[VTW3]] : tensor<4x4xf32>
+
+func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>)
+    -> tensor<4x4xf32> {
+  %zero = constant 0 : index
+  %pad = constant 0.0 : f32
+  %vec = vector.transfer_read %arg0[%zero, %zero], %pad : tensor<4x4xf32>, vector<4x4xf32>
+  %out = vector.transfer_write %vec, %arg0[%zero, %zero] : vector<4x4xf32>, tensor<4x4xf32>
+  return %out : tensor<4x4xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 43a83f04dd30..4a58261c9672 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -530,6 +530,14 @@ func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
 //       CHECK:   %[[CMP1:.*]] = cmpf "ult", %[[VT1]], %[[VT5]] : vector<2x2xf32>
 //       CHECK:   %[[CMP2:.*]] = cmpf "ult", %[[VT2]], %[[VT6]] : vector<2x2xf32>
 //       CHECK:   %[[CMP3:.*]] = cmpf "ult", %[[VT3]], %[[VT7]] : vector<2x2xf32>
+//       CHECK:   %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
+//       CHECK:   %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32>
 //       CHECK:   %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32>
 //       CHECK:   %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32>
 //       CHECK:   %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32>
@@ -544,7 +552,52 @@ func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
   %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
   %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
   %cond = cmpf "ult", %0, %1 : vector<4x4xf32>
-  %2 = select %cond, %0, %1 : vector<4x4xi1>, vector<4x4xf32>
-  vector.transfer_write %2, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
+  // The vector transfer split pattern currently only supports a single user.
+  %2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+  %3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32>
+  %4 = select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32>
+  vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32>
   return
 }
+
+// Check that vector.transfer read/write are split based on contract unrolling.
+//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32>
+
+// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32>
+
+// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
+
+// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+
+// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[R2]], %[[VTW1]][%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[R3]], %[[VTW2]][%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
+// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32>
+
+func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
+                                          %arg1 : tensor<2x4xf32>,
+                                          %arg2 : tensor<4x4xf32>) ->
+  tensor<4x4xf32> {
+  %idx0 = constant 0 : index
+  %pad = constant 0.0 : f32
+  %lhs = vector.transfer_read %arg0[%idx0, %idx0], %pad
+      : tensor<4x2xf32>, vector<4x2xf32>
+  %rhs = vector.transfer_read %arg1[%idx0, %idx0], %pad
+      : tensor<2x4xf32>, vector<2x4xf32>
+  %acc = vector.transfer_read %arg2[%idx0, %idx0], %pad
+      : tensor<4x4xf32>, vector<4x4xf32>
+  %prod = vector.contract #contraction_trait1 %lhs, %rhs, %acc
+      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
+  %res = vector.transfer_write %prod, %arg2[%idx0, %idx0]
+      : vector<4x4xf32>, tensor<4x4xf32>
+  return %res : tensor<4x4xf32>
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index f219ef04fce5..572cd1cd68f1 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -28,7 +28,9 @@ struct TestVectorToVectorConversion
     OwningRewritePatternList patterns;
     auto *ctx = &getContext();
     patterns.insert<UnrollVectorPattern>(
-        ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
+        ctx,
+        UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
+            filter));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
@@ -39,13 +41,16 @@ struct TestVectorToVectorConversion
   static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
     if (isa<AddFOp, SelectOp, CmpFOp>(op))
       return SmallVector<int64_t, 4>(2, 2);
-    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(op)) {
-      return SmallVector<int64_t, 4>(transferOp.getVectorType().getRank(), 2);
-    }
     if (isa<vector::ContractionOp>(op))
       return SmallVector<int64_t, 4>(3, 2);
     return llvm::None;
   }
+
+  /// Unroll only the ops for which getShape provides a native shape, so the
+  /// filter and the shape function cannot drift apart.
+  static LogicalResult filter(Operation *op) {
+    return success(getShape(op).hasValue());
+  }
 };
 
 struct TestVectorSlicesConversion


        


More information about the Mlir-commits mailing list