[Mlir-commits] [mlir] [mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (PR #73951)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 30 07:12:10 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
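This folds transpose(shape_cast) into a single new shape_cast when the transpose only moves unit dims of the shape_cast result around (i.e. the non-unit dims keep their relative order), so that the transpose does not reorder any elements.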
Example:
```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
```
Folds to:
```
vector.shape_cast %vec : vector<[4]xf32> to vector<1x[4]xf32>
```
This is an alternative fix for an issue when lowering matmuls to ArmSME.
---
Full diff: https://github.com/llvm/llvm-project/pull/73951.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+44-1)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c462b23e1133fc9..cf006adaee72a25 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5548,12 +5548,55 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
+/// Folds transpose(shape_cast) into a new shape_cast when the transpose only
+/// moves unit dims around, i.e. the non-unit dims of the transpose source keep
+/// their relative order in the result. Such a transpose does not reorder any
+/// elements, so it is itself equivalent to a shape_cast and can be merged with
+/// the producing shape_cast.
+///
+/// Example:
+///   %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
+///   %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// becomes:
+///   %1 = vector.shape_cast %vec : vector<[4]xf32> to vector<1x[4]xf32>
+///
+/// Scalable dims are never treated as unit dims, even when their base size is
+/// 1 (e.g. `[1]`), because their runtime size may be larger than 1.
+class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
+                                PatternRewriter &rewriter) const override {
+    auto shapeCastOp =
+        transpOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+    if (!shapeCastOp)
+      return failure();
+
+    VectorType sourceType = transpOp.getSourceVectorType();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<bool> sourceScalableDims = sourceType.getScalableDims();
+
+    // Only fixed-size dims of extent 1 count as unit dims; moving them never
+    // changes the linearized order of the elements.
+    auto isUnitDim = [&](int64_t dim) {
+      return sourceShape[dim] == 1 && !sourceScalableDims[dim];
+    };
+
+    // The transpose preserves element order (and is therefore a shape_cast)
+    // iff the non-unit source dims appear in increasing order in the
+    // permutation. Comparing dim sizes alone is not sufficient: transposing a
+    // vector<4x4xf32> with [1, 0] keeps all sizes but reorders the elements.
+    int64_t lastNonUnitDim = -1;
+    for (int64_t dim : transpOp.getPermutation()) {
+      if (isUnitDim(dim))
+        continue;
+      if (dim < lastNonUnitDim)
+        return failure();
+      lastNonUnitDim = dim;
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpOp, transpOp.getResultVectorType(), shapeCastOp.getSource());
+    return success();
+  }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
+  // Register every vector.transpose canonicalization pattern, including the
+  // new FoldTransposeShapeCast (transpose(shape_cast) -> shape_cast).
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
- TransposeFolder, FoldTransposeSplat>(context);
+ TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d341..6bfb477ecf97285 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,6 +67,18 @@ func.func @create_mask_transpose_to_transposed_create_mask(
// -----
+// A transpose that only moves a unit dim is folded into the shape_cast.
+// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
+//  CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
+  //     CHECK: %[[CAST:.*]] = vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
+  // CHECK-NOT: vector.transpose
+  //     CHECK: return %[[CAST]] : vector<1x[4]xf32>
+  %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %1 : vector<1x[4]xf32>
+}
+
+// -----
+
+// A transpose that reorders non-unit dims moves elements, so it must not be
+// folded into a shape_cast even though all dim sizes match.
+// CHECK-LABEL: negative_transposed_non_unit_dims_shape_cast
+//  CHECK-SAME: %[[VEC:.*]]: vector<16xf32>
+//       CHECK: %[[CAST:.*]] = vector.shape_cast %[[VEC]] : vector<16xf32> to vector<4x4xf32>
+//       CHECK: vector.transpose %[[CAST]], [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+func.func @negative_transposed_non_unit_dims_shape_cast(%vec: vector<16xf32>) -> vector<4x4xf32> {
+  %0 = vector.shape_cast %vec : vector<16xf32> to vector<4x4xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: extract_from_create_mask
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
``````````
</details>
https://github.com/llvm/llvm-project/pull/73951
More information about the Mlir-commits
mailing list