[Mlir-commits] [mlir] f4ac950 - Generalize the vector transfer flattening patterns (dyn shapes).

Benoit Jacob llvmlistbot at llvm.org
Mon Jul 25 08:59:18 PDT 2022


Author: Benoit Jacob
Date: 2022-07-25T15:59:08Z
New Revision: f4ac950957f58c703c347474b358b7a8802d02fe

URL: https://github.com/llvm/llvm-project/commit/f4ac950957f58c703c347474b358b7a8802d02fe
DIFF: https://github.com/llvm/llvm-project/commit/f4ac950957f58c703c347474b358b7a8802d02fe.diff

LOG: Generalize the vector transfer flattening patterns (dyn shapes).

Differential Revision: https://reviews.llvm.org/D130284

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 6cddef218ca63..9125aae4ccb9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -339,23 +339,71 @@ class TransferWriteDropUnitDimsPattern
   }
 };
 
-/// Creates a memref.collapse_shape collapsing all of the dimensions of the
-/// input into a 1D shape.
-// TODO: move helper function
-static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
-                                                  mlir::Location loc,
-                                                  Value input) {
-  Value rankReducedInput =
-      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
-  ShapedType rankReducedInputType =
-      rankReducedInput.getType().cast<ShapedType>();
-  if (rankReducedInputType.getRank() == 1)
-    return rankReducedInput;
-  ReassociationIndices indices;
-  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
-    indices.push_back(i);
-  return rewriter.create<memref::CollapseShapeOp>(
-      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
+/// Returns the position of the first inner dimension that has contiguous layout
+/// with at least `requiredContiguousSize` contiguous elements.
+/// When such a dimension is found, the return value satisfies:
+///   0 <= return_value <= memrefType.getRank() - 1.
+/// When no such dimension is found, the return value is memrefType.getRank().
+static int64_t getContiguousInnerDim(MemRefType memrefType,
+                                     int64_t requiredContiguousSize) {
+  auto shape = memrefType.getShape();
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  const int64_t rank = shape.size();
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return rank;
+  // Walk from the innermost dim outwards while the layout stays row-major
+  // contiguous (static size, stride equal to the product of inner sizes).
+  int64_t innerSize = 1;
+  for (int64_t dim = rank - 1; dim >= 0; --dim) {
+    if (shape[dim] == ShapedType::kDynamicSize || strides[dim] != innerSize)
+      break;
+    innerSize *= shape[dim];
+    if (innerSize >= requiredContiguousSize)
+      return dim;
+  }
+  // Fewer than `requiredContiguousSize` contiguous inner elements: report "not
+  // found" rather than a too-small dim. NOTE(review): callers must also check
+  // that the vector's shape matches the memref's trailing dims (e.g. a 4x2
+  // read of an 8x4 tile is not one contiguous 8-element run).
+  return rank;
+}
+
+/// Emits a memref.collapse_shape that keeps each dimension before
+/// `firstDimToCollapse` as its own group and merges all trailing dimensions.
+static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
+                               Value input, int64_t firstDimToCollapse) {
+  const int64_t rank = input.getType().cast<ShapedType>().getRank();
+  if (rank == 1)
+    return input; // A 1-D memref has nothing left to collapse.
+  SmallVector<ReassociationIndices> groups;
+  for (int64_t dim = 0; dim < firstDimToCollapse; ++dim)
+    groups.push_back(ReassociationIndices{dim});
+  ReassociationIndices trailingGroup;
+  for (int64_t dim = firstDimToCollapse; dim < rank; ++dim)
+    trailingGroup.push_back(dim);
+  groups.push_back(trailingGroup);
+  return rewriter.create<memref::CollapseShapeOp>(loc, input, groups);
+}
+
+/// Succeeds iff every index at position >= `firstDimToCollapse` is defined by
+/// an `arith::ConstantIndexOp` equal to 0; on success `outIndices` receives the
+/// leading `firstDimToCollapse + 1` indices (the last addresses the new dim).
+static LogicalResult
+checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
+                                 SmallVector<Value> &outIndices) {
+  const int64_t numIndices = indices.size();
+  if (numIndices <= firstDimToCollapse)
+    return failure();
+  auto isConstantZero = [](Value index) {
+    auto cst = index.getDefiningOp<arith::ConstantIndexOp>();
+    return cst && cst.value() == 0;
+  };
+  if (!llvm::all_of(indices.drop_front(firstDimToCollapse), isConstantZero))
+    return failure();
+  // Keep the outer indices plus one (zero) index for the collapsed group.
+  outIndices.assign(indices.begin(), indices.begin() + firstDimToCollapse + 1);
+  return success();
 }
 
 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
@@ -379,12 +427,9 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    int64_t firstContiguousInnerDim =
+        getContiguousInnerDim(sourceType, vectorType.getNumElements());
+    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
@@ -393,19 +438,28 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     if (transferReadOp.getMask())
       return failure();
-    if (llvm::any_of(transferReadOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    SmallVector<Value> collapsedIndices;
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstContiguousInnerDim,
+                                                collapsedIndices)))
       return failure();
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
-    Value source1d =
-        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value read1d = rewriter.create<vector::TransferReadOp>(
-        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    Value collapsedSource =
+        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    MemRefType collapsedSourceType =
+        collapsedSource.getType().dyn_cast<MemRefType>();
+    int64_t collapsedRank = collapsedSourceType.getRank();
+    assert(collapsedRank == firstContiguousInnerDim + 1);
+    SmallVector<AffineExpr, 1> dimExprs{
+        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+    auto collapsedMap =
+        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
+                                                vectorType.getElementType());
+    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
+        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
+    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        transferReadOp, vector.getType().cast<VectorType>(), read1d);
+        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
     return success();
   }
 };
@@ -431,12 +485,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    int64_t firstContiguousInnerDim =
+        getContiguousInnerDim(sourceType, vectorType.getNumElements());
+    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -445,19 +496,29 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (transferWriteOp.getMask())
       return failure();
-    if (llvm::any_of(transferWriteOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    SmallVector<Value> collapsedIndices;
+    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
+                                                firstContiguousInnerDim,
+                                                collapsedIndices)))
       return failure();
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
-    Value source1d =
-        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value vector1d =
-        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
-    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
-                                             ValueRange{c0}, identityMap1D);
+    Value collapsedSource =
+        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    MemRefType collapsedSourceType =
+        collapsedSource.getType().cast<MemRefType>();
+    int64_t collapsedRank = collapsedSourceType.getRank();
+    assert(collapsedRank == firstContiguousInnerDim + 1);
+    SmallVector<AffineExpr, 1> dimExprs{
+        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+    auto collapsedMap =
+        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
+                                                vectorType.getElementType());
+    Value flatVector =
+        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
+    vector::TransferWriteOp flatWrite =
+        rewriter.create<vector::TransferWriteOp>(
+            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
+    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.eraseOp(transferWriteOp);
     return success();
   }

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 8e15ab48c1750..cd55222dddcd9 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -59,3 +59,48 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
 // CHECK:            %[[CST:.+]] = arith.constant 0 : i8
 // CHECK:            %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
 // CHECK:            return %[[READ]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
+    %c0_i8 = arith.constant 0 : i8
+    %c0 = arith.constant 0 : index
+    %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<8x4xi8>
+    return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK:       %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:       %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME:    [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    {in_bounds = [true]}
+// CHECK-SAME:    : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK:       return %[[VEC2D]] : vector<8x4xi8>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, #map0>
+    return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME:    %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:       %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK:       vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME:    [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME:    {in_bounds = [true]}
+// CHECK-SAME:    : vector<32xi8>, memref<?x?x32xi8, {{.+}}>


        


More information about the Mlir-commits mailing list