[Mlir-commits] [mlir] [mlir][ArmSME] Rewrite illegal `shape_casts` to `vector.transpose` ops (PR #82985)

Cullen Rhodes llvmlistbot at llvm.org
Tue Feb 27 07:23:46 PST 2024


================
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
+/// A rewrite to turn unit dim transpose-like vector.shape_casts into
+/// vector.transposes. The shape_cast has to be from an illegal vector type to a
+/// legal one (as defined by isLegalVectorType).
+///
+/// The reasoning for this is if we've got to this pass and we still have
+/// shape_casts of illegal types, then they likely will not cancel out. Turning
+/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
+/// eliminate them.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  ```
+struct ConvertIllegalShapeCastOpsToTransposes
+    : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = shapeCastOp.getSourceVectorType();
+    auto resultType = shapeCastOp.getResultVectorType();
+    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         kMatchFailureNotIllegalToLegal);
+
+    // Note: If we know the that is source is an illegal vector type (and 2D)
+    // then dim 0 is scalable and dim 1 is fixed.
----------------
c-rhodes wrote:

```suggestion
    // Note: If we know that source is a legal vector type (and 2D)
    // then dim 0 is scalable and dim 1 is fixed.
```
?
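To make the note concrete, here is a standalone C++ model (not MLIR code) of the legality rule and of the transpose-like check. `VecType` and `isTransposeLikeShapeCast` are hypothetical stand-ins. `isLegalVectorType` is modelled on the rule as assumed here: scan the dims from innermost to outermost, and once a fixed dim has been seen, reject any scalable dim. Treat that rule as an assumption to be checked against the real helper. Under it, the only illegal 2-D combination of scalable flags is [scalable, fixed]. That is exactly the "dim 0 scalable, dim 1 fixed" layout the note describes for an illegal source.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: a shape plus per-dim scalable flags.
struct VecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
};

// Assumed legality rule: walking dims innermost-first, once a fixed dim has
// been seen, no scalable dim may appear further out.
bool isLegalVectorType(const VecType &t) {
  bool seenFixedDim = false;
  for (auto it = t.scalableDims.rbegin(); it != t.scalableDims.rend(); ++it) {
    bool scalable = *it;
    seenFixedDim = seenFixedDim || !scalable;
    if (seenFixedDim && scalable)
      return false;
  }
  return true;
}

// Hypothetical helper: true if a 2-D shape_cast from an illegal `src` to a
// legal `dst` just swaps a unit dim with the other dim, i.e. it behaves like
// vector.transpose [1, 0].
bool isTransposeLikeShapeCast(const VecType &src, const VecType &dst) {
  assert(src.shape.size() == src.scalableDims.size());
  assert(dst.shape.size() == dst.scalableDims.size());
  if (src.shape.size() != 2 || dst.shape.size() != 2)
    return false;
  if (isLegalVectorType(src) || !isLegalVectorType(dst))
    return false;
  // An illegal 2-D source here has dim 0 scalable and dim 1 fixed.
  bool swapped = src.shape[0] == dst.shape[1] && src.shape[1] == dst.shape[0] &&
                 src.scalableDims[0] == dst.scalableDims[1] &&
                 src.scalableDims[1] == dst.scalableDims[0];
  // Only a unit fixed dim keeps the element order identical to a transpose.
  return swapped && src.shape[1] == 1;
}
```

With a unit trailing dim, `vector<[4]x1xf32>` and `vector<1x[4]xf32>` hold their elements in the same linear order, which is why the shape_cast and a `[1, 0]` transpose agree. The unit-dim check rules out non-unit casts such as 4x2 to 2x4, which reshape rather than transpose.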

https://github.com/llvm/llvm-project/pull/82985