[Mlir-commits] [mlir] f4ac950 - Generalize the vector transfer flattening patterns (dyn shapes).
Benoit Jacob
llvmlistbot at llvm.org
Mon Jul 25 08:59:18 PDT 2022
Author: Benoit Jacob
Date: 2022-07-25T15:59:08Z
New Revision: f4ac950957f58c703c347474b358b7a8802d02fe
URL: https://github.com/llvm/llvm-project/commit/f4ac950957f58c703c347474b358b7a8802d02fe
DIFF: https://github.com/llvm/llvm-project/commit/f4ac950957f58c703c347474b358b7a8802d02fe.diff
LOG: Generalize the vector transfer flattening patterns (dyn shapes).
Differential Revision: https://reviews.llvm.org/D130284
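
For illustration, a sketch of the rewrite the generalized read pattern now
performs, adapted from the test added below (strided layout maps are elided
as "...", and the value names are illustrative):

  %v = vector.transfer_read %src[%i, %j, %c0, %c0], %pad
      {in_bounds = [true, true]} : memref<?x?x8x4xi8, ...>, vector<8x4xi8>

becomes a collapse of only the statically contiguous inner dimensions, a 1-D
transfer at the unchanged outer indices, and a shape_cast back to the
original vector type:

  %collapsed = memref.collapse_shape %src [[0], [1], [2, 3]]
      : memref<?x?x8x4xi8, ...> into memref<?x?x32xi8, ...>
  %flat = vector.transfer_read %collapsed[%i, %j, %c0], %pad
      {in_bounds = [true]} : memref<?x?x32xi8, ...>, vector<32xi8>
  %v = vector.shape_cast %flat : vector<32xi8> to vector<8x4xi8>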
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 6cddef218ca63..9125aae4ccb9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -339,23 +339,71 @@ class TransferWriteDropUnitDimsPattern
   }
 };
 
-/// Creates a memref.collapse_shape collapsing all of the dimensions of the
-/// input into a 1D shape.
-// TODO: move helper function
-static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
-                                                  mlir::Location loc,
-                                                  Value input) {
-  Value rankReducedInput =
-      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
-  ShapedType rankReducedInputType =
-      rankReducedInput.getType().cast<ShapedType>();
-  if (rankReducedInputType.getRank() == 1)
-    return rankReducedInput;
-  ReassociationIndices indices;
-  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
-    indices.push_back(i);
-  return rewriter.create<memref::CollapseShapeOp>(
-      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
+/// Returns a dimension `d` such that dimensions `d` .. `rank - 1` form a
+/// statically shaped, row-major contiguous suffix of `memrefType`. The
+/// suffix stops growing once it covers at least `requiredContiguousSize`
+/// elements (it may cover fewer if contiguity breaks first). Returns
+/// memrefType.getRank() if not even the innermost dimension is contiguous.
+static int64_t getContiguousInnerDim(MemRefType memrefType,
+                                     int64_t requiredContiguousSize) {
+  auto shape = memrefType.getShape();
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  int64_t innerDim = shape.size();
+  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
+    int64_t innerSize = 1;
+    while (true) {
+      if (innerDim == 0)
+        break;
+      const int64_t nextDim = innerDim - 1;
+      if (shape[nextDim] == ShapedType::kDynamicSize)
+        break;
+      if (strides[nextDim] != innerSize)
+        break;
+      innerSize *= shape[nextDim];
+      innerDim = nextDim;
+      if (innerSize >= requiredContiguousSize)
+        break;
+    }
+  }
+  return innerDim;
+}
+
+/// Creates a memref.collapse_shape collapsing all inner dimensions of the
+/// input starting at `firstDimToCollapse`.
+static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
+                               Value input, int64_t firstDimToCollapse) {
+  ShapedType inputType = input.getType().cast<ShapedType>();
+  if (inputType.getRank() == 1)
+    return input;
+  SmallVector<ReassociationIndices> reassociation;
+  for (int64_t i = 0; i < firstDimToCollapse; ++i)
+    reassociation.push_back(ReassociationIndices{i});
+  ReassociationIndices collapsedIndices;
+  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
+    collapsedIndices.push_back(i);
+  reassociation.push_back(collapsedIndices);
+  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
+}
+
+/// Checks that the indices corresponding to dimensions starting at
+/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
+/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+static LogicalResult
+checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
+                                 SmallVector<Value> &outIndices) {
+  int64_t rank = indices.size();
+  if (firstDimToCollapse >= rank)
+    return failure();
+  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
+    arith::ConstantIndexOp cst =
+        indices[i].getDefiningOp<arith::ConstantIndexOp>();
+    if (!cst || cst.value() != 0)
+      return failure();
+  }
+  outIndices = indices;
+  outIndices.resize(firstDimToCollapse + 1);
+  return success();
 }
 
 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
@@ -379,12 +427,9 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    int64_t firstContiguousInnerDim =
+        getContiguousInnerDim(sourceType, vectorType.getNumElements());
+    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
@@ -393,19 +438,28 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     if (transferReadOp.getMask())
       return failure();
-    if (llvm::any_of(transferReadOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    SmallVector<Value> collapsedIndices;
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstContiguousInnerDim,
+                                                collapsedIndices)))
       return failure();
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
-    Value source1d =
-        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value read1d = rewriter.create<vector::TransferReadOp>(
-        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    Value collapsedSource =
+        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    MemRefType collapsedSourceType =
+        collapsedSource.getType().dyn_cast<MemRefType>();
+    int64_t collapsedRank = collapsedSourceType.getRank();
+    assert(collapsedRank == firstContiguousInnerDim + 1);
+    SmallVector<AffineExpr, 1> dimExprs{
+        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+    auto collapsedMap =
+        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
+                                                vectorType.getElementType());
+    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
+        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
+    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        transferReadOp, vector.getType().cast<VectorType>(), read1d);
+        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
     return success();
   }
 };
@@ -431,12 +485,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    int64_t firstContiguousInnerDim =
+        getContiguousInnerDim(sourceType, vectorType.getNumElements());
+    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -445,19 +496,29 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (transferWriteOp.getMask())
       return failure();
-    if (llvm::any_of(transferWriteOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    SmallVector<Value> collapsedIndices;
+    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
+                                                firstContiguousInnerDim,
+                                                collapsedIndices)))
       return failure();
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
-    Value source1d =
-        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value vector1d =
-        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
-    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
-                                             ValueRange{c0}, identityMap1D);
+    Value collapsedSource =
+        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    MemRefType collapsedSourceType =
+        collapsedSource.getType().cast<MemRefType>();
+    int64_t collapsedRank = collapsedSourceType.getRank();
+    assert(collapsedRank == firstContiguousInnerDim + 1);
+    SmallVector<AffineExpr, 1> dimExprs{
+        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+    auto collapsedMap =
+        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
+                                                vectorType.getElementType());
+    Value flatVector =
+        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
+    vector::TransferWriteOp flatWrite =
+        rewriter.create<vector::TransferWriteOp>(
+            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
+    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
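
A worked example of the new helper (an illustrative sketch, not part of the
patch): for %m below, the type's strides are [s1, 32, 4, 1]. With
requiredContiguousSize = 8 * 4 = 32, getContiguousInnerDim walks from the
innermost dimension outwards and returns 2, so collapseInnerDims folds
dimensions [2, 3] and keeps the two dynamic outer dimensions, matching the
memref.collapse_shape checked in the tests below.

  // dim 3: static size 4, stride 1 == innerSize 1 -> innerSize becomes 4
  // dim 2: static size 8, stride 4 == innerSize 4 -> innerSize becomes 32,
  //        which reaches the required 32 elements, so the walk stops at 2.
  func.func @stride_walk_example(
      %m : memref<?x?x8x4xi8, affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>>) {
    return
  }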
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 8e15ab48c1750..cd55222dddcd9 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -59,3 +59,48 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
 // CHECK: %[[CST:.+]] = arith.constant 0 : i8
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
 // CHECK: return %[[READ]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<8x4xi8>
+  return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK: return %[[VEC2D]] : vector<8x4xi8>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, #map0>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
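
Conversely, a hypothetical negative example (not among the tests added here):
checkAndCollapseInnerZeroIndices requires every index into the collapsed
dimensions to be a constant 0, so a transfer like the following is left
unflattened.

  #map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>

  func.func @not_flattened(%arg0 : memref<?x?x8x4xi8, #map0>, %i : index) -> vector<8x4xi8> {
    %c0_i8 = arith.constant 0 : i8
    %c0 = arith.constant 0 : index
    // %i indexes dim 2, which lies inside the to-be-collapsed range [2, 3]
    // and is not a constant 0, so the pattern does not match.
    %v = vector.transfer_read %arg0[%c0, %c0, %i, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<8x4xi8>
    return %v : vector<8x4xi8>
  }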