[Mlir-commits] [mlir] f44c76d - [mlir][vector] Extend vector transfer unrolling to support permutations and broadcast

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 3 10:48:59 PDT 2021


Author: thomasraoux
Date: 2021-05-03T10:47:02-07:00
New Revision: f44c76d6e919641655615d62ea8b432175571a0b

URL: https://github.com/llvm/llvm-project/commit/f44c76d6e919641655615d62ea8b432175571a0b
DIFF: https://github.com/llvm/llvm-project/commit/f44c76d6e919641655615d62ea8b432175571a0b.diff

LOG: [mlir][vector] Extend vector transfer unrolling to support permutations and broadcast

Differential Revision: https://reviews.llvm.org/D101637

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-unroll.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 75017973d30f..2c8b33795de4 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -516,10 +516,12 @@ static void getVectorElementwiseOpUnrollState(Operation *op,
 
 /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
 /// calls 'fn' with linear index and indices for each slice.
-static void generateTransferOpSlices(
-    Type shapedElementType, VectorType vectorType, TupleType tupleType,
-    ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
-    OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
+static void
+generateTransferOpSlices(Type shapedElementType, VectorType vectorType,
+                         TupleType tupleType, ArrayRef<int64_t> sizes,
+                         ArrayRef<int64_t> strides, ArrayRef<Value> indices,
+                         AffineMap permutationMap, OpBuilder &builder,
+                         function_ref<void(unsigned, ArrayRef<Value>)> fn) {
   // Compute strides w.r.t. to slice counts in each dimension.
   auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
   assert(maybeDimSliceCounts.hasValue());
@@ -527,7 +529,6 @@ static void generateTransferOpSlices(
   auto sliceStrides = computeStrides(sliceDimCounts);
 
   int64_t numSlices = tupleType.size();
-  unsigned numSliceIndices = indices.size();
   // Compute 'indexOffset' at which to update 'indices', which is equal
   // to the memref rank (indices.size) minus the effective 'vectorRank'.
   // The effective 'vectorRank', is equal to the rank of the vector type
@@ -545,57 +546,38 @@ static void generateTransferOpSlices(
     assert(vectorRank >= sourceVectorElementType.getRank());
     vectorRank -= sourceVectorElementType.getRank();
   }
-  unsigned indexOffset = numSliceIndices - vectorRank;
-
+  auto isBroadcast = [](AffineExpr expr) {
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+      return constExpr.getValue() == 0;
+    return false;
+  };
   auto *ctx = builder.getContext();
   for (unsigned i = 0; i < numSlices; ++i) {
     auto vectorOffsets = delinearize(sliceStrides, i);
     auto elementOffsets =
         computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
     // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
-    SmallVector<Value, 4> sliceIndices(numSliceIndices);
-    for (unsigned j = 0; j < numSliceIndices; ++j) {
-      if (j < indexOffset) {
-        sliceIndices[j] = indices[j];
-      } else {
-        auto expr = getAffineDimExpr(0, ctx) +
-                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
-        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
-        sliceIndices[j] = builder.create<AffineApplyOp>(
-            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
-      }
+    SmallVector<Value, 4> sliceIndices(indices.begin(), indices.end());
+    for (auto dim : llvm::enumerate(permutationMap.getResults())) {
+      if (isBroadcast(dim.value()))
+        continue;
+      unsigned pos = dim.value().cast<AffineDimExpr>().getPosition();
+      auto expr = getAffineDimExpr(0, ctx) +
+                  getAffineConstantExpr(elementOffsets[dim.index()], ctx);
+      auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+      sliceIndices[pos] = builder.create<AffineApplyOp>(
+          indices[pos].getLoc(), map, ArrayRef<Value>(indices[pos]));
     }
     // Call 'fn' to generate slice 'i' at 'sliceIndices'.
     fn(i, sliceIndices);
   }
 }
 
-/// Returns true if 'map' is a suffix of an identity affine map, false
-/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-static bool isIdentitySuffix(AffineMap map) {
-  if (map.getNumDims() < map.getNumResults())
-    return false;
-  ArrayRef<AffineExpr> results = map.getResults();
-  Optional<int> lastPos;
-  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
-    auto expr = results[i].dyn_cast<AffineDimExpr>();
-    if (!expr)
-      return false;
-    int currPos = static_cast<int>(expr.getPosition());
-    if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
-      return false;
-    lastPos = currPos;
-  }
-  return true;
-}
-
 /// Unroll transfer_read ops to the given shape and create an aggregate with all
 /// the chunks.
 static Value unrollTransferReadOp(vector::TransferReadOp readOp,
                                   ArrayRef<int64_t> targetShape,
                                   OpBuilder &builder) {
-  if (!isIdentitySuffix(readOp.permutation_map()))
-    return nullptr;
   if (readOp.mask())
     return nullptr;
   auto sourceVectorType = readOp.getVectorType();
@@ -623,7 +605,8 @@ static Value unrollTransferReadOp(vector::TransferReadOp readOp,
         readOp.in_bounds() ? *readOp.in_bounds() : ArrayAttr());
   };
   generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
-                           targetShape, strides, indices, builder, createSlice);
+                           targetShape, strides, indices,
+                           readOp.permutation_map(), builder, createSlice);
 
   // Create tuple of splice transfer read operations.
   Value tupleOp =
@@ -641,8 +624,6 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                     ArrayRef<int64_t> targetShape,
                                     SmallVector<Value, 1> &result) {
   auto writeOp = cast<vector::TransferWriteOp>(op);
-  if (!isIdentitySuffix(writeOp.permutation_map()))
-    return failure();
   if (writeOp.mask())
     return failure();
   VectorType sourceVectorType = writeOp.getVectorType();
@@ -671,7 +652,8 @@ mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op,
       resultTensor = write->getResult(0);
   };
   generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType,
-                           targetShape, strides, indices, builder, createSlice);
+                           targetShape, strides, indices,
+                           writeOp.permutation_map(), builder, createSlice);
   if (resultTensor)
     result.push_back(resultTensor);
   return success();
@@ -729,11 +711,6 @@ class SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
     if (readOp.mask())
       return failure();
 
-    // TODO: Support splitting TransferReadOp with non-identity permutation
-    // maps. Repurpose code from MaterializeVectors transformation.
-    if (!isIdentitySuffix(readOp.permutation_map()))
-      return failure();
-
     // Return unless there is only one user, and it is an ExtractSlicesOp.
     Value readResult = readOp.getResult();
     if (!readResult.hasOneUse())
@@ -778,11 +755,6 @@ class SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
     if (writeOp.mask())
       return failure();
 
-    // TODO: Support splitting TransferWriteOp with non-identity permutation
-    // maps. Repurpose code from MaterializeVectors transformation.
-    if (!isIdentitySuffix(writeOp.permutation_map()))
-      return failure();
-
     // Fail to match unless this is writing a vector resulting from an
     // InsertSlicesOp.
     auto insertSlicesOp =
@@ -821,8 +793,8 @@ class SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
         resultTensor = write->getResult(0);
     };
     generateTransferOpSlices(shapedElementType, resultVectorType,
-                             sourceTupleType, sizes, strides, indices, rewriter,
-                             createSlice);
+                             sourceTupleType, sizes, strides, indices,
+                             writeOp.permutation_map(), rewriter, createSlice);
 
     if (resultTensor)
       rewriter.replaceOp(writeOp, ArrayRef<Value>(resultTensor));

diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index d63809c3063b..0929031396ec 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @transfer_read_unroll
 //       CHECK-DAG:   %[[C2:.*]] = constant 2 : index
@@ -120,3 +120,94 @@ func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4
   %r = vector.transfer_write %0, %arg1[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
   return %r: tensor<4x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_unroll_permutation
+//       CHECK-DAG:   %[[C4:.*]] = constant 4 : index
+//       CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//       CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x6xf32>
+//  CHECK-NEXT:   return %[[VEC]] : vector<4x6xf32>
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+func @transfer_read_unroll_permutation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
+  return %0 : vector<4x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_unroll_broadcast
+//       CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//       CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<6x4xf32>
+//  CHECK-NEXT:   return %[[VEC]] : vector<6x4xf32>
+#map0 = affine_map<(d0, d1) -> (0, d1)>
+func @transfer_read_unroll_broadcast(%arg0 : memref<6x4xf32>) -> vector<6x4xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
+  return %0 : vector<6x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_unroll_broadcast_permuation
+//       CHECK-DAG:   %[[C4:.*]] = constant 4 : index
+//       CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//       CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x6xf32>
+//  CHECK-NEXT:   return %[[VEC]] : vector<4x6xf32>
+#map0 = affine_map<(d0, d1) -> (0, d0)>
+func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
+  return %0 : vector<4x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_unroll_different_rank
+//       CHECK-DAG:   %[[C4:.*]] = constant 4 : index
+//       CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//       CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//       CHECK:   %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C0]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C2]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]], %[[C4]]], %{{.*}} : memref<?x?x?xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]], %[[VTR4]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
+//  CHECK-NEXT:   %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<6x4xf32>
+//  CHECK-NEXT:   return %[[VEC]] : vector<6x4xf32>
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+func @transfer_read_unroll_different_rank(%arg0 : memref<?x?x?xf32>) -> vector<6x4xf32> {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?xf32>, vector<6x4xf32>
+  return %0 : vector<6x4xf32>
+}


        


More information about the Mlir-commits mailing list