[Mlir-commits] [mlir] da8778e - [mlir][vector] Add pattern to drop unit dims from vector.transpose (#102017)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 8 03:42:30 PDT 2024
Author: Benjamin Maxwell
Date: 2024-08-08T11:42:27+01:00
New Revision: da8778e499d8049ac68c2e152941a38ff2bc9fb2
URL: https://github.com/llvm/llvm-project/commit/da8778e499d8049ac68c2e152941a38ff2bc9fb2
DIFF: https://github.com/llvm/llvm-project/commit/da8778e499d8049ac68c2e152941a38ff2bc9fb2.diff
LOG: [mlir][vector] Add pattern to drop unit dims from vector.transpose (#102017)
Example:
BEFORE:
```mlir
%transpose = vector.transpose %vector, [3, 0, 1, 2]
: vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
```
AFTER:
```mlir
%dropDims = vector.shape_cast %vector
: vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%transpose = vector.transpose %dropDims, [1, 0]
: vector<4x[4]xf32> to vector<[4]x4xf32>
%restoreDims = vector.shape_cast %transpose
: vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
```
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 40e04b76593a0e..5f32aca88a2734 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -120,6 +120,11 @@ inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
};
}
+/// Returns a range of (size, isScalable) tuples, one per dim of `vType`.
+inline auto getDims(VectorType vType) {
+ return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6777e589795c8e..55c1c6bad9f2a4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// A pattern to drop non-scalable unit dims from vector.transpose.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
+///     : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %dropDims = vector.shape_cast %vector
+///     : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+/// %transpose = vector.transpose %dropDims, [1, 0]
+///     : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// %restoreDims = vector.shape_cast %transpose
+///     : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+/// ```
+struct DropUnitDimsFromTransposeOp final
+    : OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = op.getSourceVectorType();
+    VectorType sourceTypeWithoutUnitDims =
+        dropNonScalableUnitDimFromType(sourceType);
+
+    if (sourceType == sourceTypeWithoutUnitDims)
+      return failure();
+
+    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
+    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
+    int64_t droppedDims = 0;
+    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
+      droppedDimsBefore[i] = droppedDims;
+      if (dim == std::make_tuple(1, false))
+        ++droppedDims;
+    }
+
+    // Drop unit dims from the permutation and renumber the surviving indices.
+    ArrayRef<int64_t> perm = op.getPermutation();
+    SmallVector<int64_t> newPerm;
+    for (int64_t idx : perm) {
+      if (sourceDims[idx] == std::make_tuple(1, false))
+        continue;
+      newPerm.push_back(idx - droppedDimsBefore[idx]);
+    }
+
+    Location loc = op.getLoc();
+    // Drop the unit dims via shape_cast.
+    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
+        loc, sourceTypeWithoutUnitDims, op.getVector());
+    // Create the new transpose.
+    auto transposeWithoutUnitDims =
+        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
+    // Restore the unit dims via shape cast.
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), transposeWithoutUnitDims);
+    // The IR was modified above, so the rewrite must be reported as applied.
+    return success();
+  }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1924,8 +1990,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
- patterns.getContext(), benefit);
+ patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
+ ShapeCastOpFolder>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9d16aa46a9f2ad..937dbf22bb713f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -700,3 +700,47 @@ func.func @negative_out_of_bound_transfer_write(
}
// CHECK: func.func @negative_out_of_bound_transfer_write
// CHECK-NOT: memref.collapse_shape
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromTransposeOp]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
+///----------------------------------------------------------------------------------------
+
+func.func @transpose_with_internal_unit_dims(%vec: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
+  %res = vector.transpose %vec, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  return %res : vector<[4]x1x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>
+
+// -----
+
+func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> vector<1x1x4x2x[1]xf32>
+{
+ %res = vector.transpose %vec, [4, 1, 3, 2, 0] : vector<[1]x1x2x4x1xf32> to vector<1x1x4x2x[1]xf32>
+ return %res: vector<1x1x4x2x[1]xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_scalable_unit_dims(
+// CHECK-SAME: %[[VEC:.*]]: vector<[1]x1x2x4x1xf32>)
+// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<[1]x1x2x4x1xf32> to vector<[1]x2x4xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [2, 1, 0] : vector<[1]x2x4xf32> to vector<4x2x[1]xf32>
+// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<4x2x[1]xf32> to vector<1x1x4x2x[1]xf32>
+// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<1x1x4x2x[1]xf32>
+
+// -----
+
+func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
+ %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
+ return %res : vector<4x3x2xf32>
+}
+
+// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
+// CHECK-NOT: vector.shape_cast
More information about the Mlir-commits
mailing list