[Mlir-commits] [mlir] [mlir][ArmSME] Rewrite illegal `shape_casts` to `vector.transpose` ops (PR #82985)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Mar 7 08:33:05 PST 2024
================
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
}
};
+/// A rewrite to turn unit-dim, transpose-like vector.shape_casts into
+/// vector.transposes. The shape_cast has to be from an illegal vector type to a
+/// legal one (as defined by isLegalVectorType).
+///
+/// The reasoning is that if we've reached this pass and still have
+/// shape_casts of illegal types, they likely will not cancel out. Turning
+/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
+/// eliminate them.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// ```
+struct ConvertIllegalShapeCastOpsToTransposes
+ : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = shapeCastOp.getSourceVectorType();
+ auto resultType = shapeCastOp.getResultVectorType();
+ if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(shapeCastOp,
+ kMatchFailureNotIllegalToLegal);
+
+ // Note: If we know that the source is an illegal vector type (and 2D),
+ // then dim 0 is scalable and dim 1 is fixed.
----------------
MacDue wrote:
Fixed the typos :+1:, but the source vector is an illegal vector type.
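The match condition discussed above can be modelled outside MLIR. This is a minimal sketch, not the PR's code: `VecTypeModel`, `isLegalModel`, `isTransposeLikeIllegalToLegal` and `transposeRowMajor` are hypothetical names, not MLIR APIs. The legality rule used here (scalable dims must be trailing, so a scalable dim may not precede a fixed one) is an assumption. It is consistent with the note that an illegal 2D source has a scalable dim 0 and a fixed dim 1.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an MLIR VectorType (NOT the MLIR API): a shape
// plus per-dim scalable flags, e.g. vector<[4]x1xf32> -> {4, 1}, {true, false}.
struct VecTypeModel {
  std::vector<int64_t> shape;
  std::vector<bool> scalable;
};

// Assumed legality rule: walking from the trailing dim to the leading one,
// a scalable dim must not appear once a fixed dim has been seen.
// So vector<1x[4]xf32> is legal and vector<[4]x1xf32> is illegal.
bool isLegalModel(const VecTypeModel &t) {
  bool seenFixedDim = false;
  for (std::size_t i = t.scalable.size(); i-- > 0;) {
    seenFixedDim |= !t.scalable[i];
    if (seenFixedDim && t.scalable[i])
      return false;
  }
  return true;
}

// Hypothetical match check: an illegal -> legal 2D shape_cast that only moves
// a unit dim ([N]x1 -> 1x[N]), which is equivalent to a [1, 0] transpose.
bool isTransposeLikeIllegalToLegal(const VecTypeModel &src,
                                   const VecTypeModel &res) {
  if (isLegalModel(src) || !isLegalModel(res))
    return false; // Not an illegal -> legal cast.
  if (src.shape.size() != 2 || res.shape.size() != 2)
    return false;
  // An illegal 2D source has scalable dim 0 and fixed dim 1; require dim 1 == 1.
  if (src.shape[1] != 1)
    return false;
  // The result must be exactly the transposed type: 1x[N].
  return res.shape[0] == 1 && !res.scalable[0] &&
         res.shape[1] == src.shape[0] && res.scalable[1];
}

// Row-major transpose of a rows x cols buffer (element values only).
std::vector<float> transposeRowMajor(const std::vector<float> &in,
                                     std::size_t rows, std::size_t cols) {
  std::vector<float> out(in.size());
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      out[c * rows + r] = in[r * cols + c];
  return out;
}
```

For example, `isTransposeLikeIllegalToLegal({{4, 1}, {true, false}}, {{1, 4}, {false, true}})` (the `vector<[4]x1xf32>` to `vector<1x[4]xf32>` case) returns true. `transposeRowMajor` shows why the rewrite preserves semantics: with a unit dim, a transpose leaves the flat element order unchanged, exactly like the shape_cast.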
https://github.com/llvm/llvm-project/pull/82985