[Mlir-commits] [mlir] fc4485b - Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (#100731)" (#102457)
llvmlistbot at llvm.org
Fri Aug 9 05:30:01 PDT 2024
Author: Benjamin Maxwell
Date: 2024-08-09T13:29:57+01:00
New Revision: fc4485bf98132c99edf4a0e5612d4309de9b9393
URL: https://github.com/llvm/llvm-project/commit/fc4485bf98132c99edf4a0e5612d4309de9b9393
DIFF: https://github.com/llvm/llvm-project/commit/fc4485bf98132c99edf4a0e5612d4309de9b9393.diff
LOG: Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (#100731)" (#102457)
This reverts commit 88accd9aaa20c6a30661c48cc2ca6dbbdf991ec0.
This change can be dropped in favor of just #102017.
Added:
Modified:
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
mlir/test/Dialect/ArmSME/vector-legalization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 53df7af00aee88..4968c4fc463d04 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};
-/// Returns an iterator over the dims (inc scalability) of a VectorType.
-static auto getDims(VectorType vType) {
- return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
-}
-
-/// Helper to drop (fixed-size) unit dims from a VectorType.
-static VectorType dropUnitDims(VectorType vType) {
- SmallVector<bool> scalableFlags;
- SmallVector<int64_t> dimSizes;
- for (auto dim : getDims(vType)) {
- if (dim == std::make_tuple(1, false))
- continue;
- auto [size, scalableFlag] = dim;
- dimSizes.push_back(size);
- scalableFlags.push_back(scalableFlag);
- }
- return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
-}
-
-/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
-/// shape_cast only drops unit dimensions.
-///
-/// This simplifies the transpose making it possible for other legalization
-/// rewrites to handle it.
-///
-/// Example:
-///
-/// BEFORE:
-/// ```mlir
-/// %0 = vector.transpose %vector, [3, 0, 1, 2]
-/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-/// ```
-struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- auto transposeOp =
- shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
- if (!transposeOp)
- return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
-
- auto resultType = shapeCastOp.getResultVectorType();
- if (resultType.getRank() <= 1)
- return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
-
- if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
- return rewriter.notifyMatchFailure(
- shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
-
- auto transposeSourceVectorType = transposeOp.getSourceVectorType();
- auto transposeSourceDims =
- llvm::to_vector(getDims(transposeSourceVectorType));
-
- // Construct a map from dimIdx -> number of dims dropped before dimIdx.
- SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
- int64_t droppedDims = 0;
- for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
- droppedDimsBefore[i] = droppedDims;
- if (dim == std::make_tuple(1, false))
- ++droppedDims;
- }
-
- // Drop unit dims from transpose permutation.
- auto perm = transposeOp.getPermutation();
- SmallVector<int64_t> newPerm;
- for (int64_t idx : perm) {
- if (transposeSourceDims[idx] == std::make_tuple(1, false))
- continue;
- newPerm.push_back(idx - droppedDimsBefore[idx]);
- }
-
- auto loc = shapeCastOp.getLoc();
- auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
- loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
- newShapeCastOp, newPerm);
- return success();
- }
-};
-
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This workaround rewrite to support these transposes when ZA is
/// available.
@@ -1027,8 +939,7 @@ struct VectorLegalizationPass
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes,
- SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
- context);
+ LowerIllegalTransposeStoreViaZA>(context);
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index adc02adb6e974c..458906a1879829 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,29 +646,3 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
return
}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose(
-// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
-func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
- // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
- // CHECK: return %[[TRANSPOSE]]
- %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
- %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
- return %1 : vector<[4]x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
-// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
-func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
- // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
- // CHECK: return %[[TRANSPOSE]]
- %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
- %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
- return %1 : vector<[4]x4xf32>
-}
More information about the Mlir-commits mailing list