[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (PR #80170)

Cullen Rhodes llvmlistbot at llvm.org
Thu Feb 1 01:15:02 PST 2024


================
@@ -338,13 +341,166 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a leading scalable dimension and a fixed trailing
+/// dimension, which has no valid lowering.
+///
+/// The memref.transpose is metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %illegalRead = vector.transfer_read %memref[%a, %b]
+///                  : memref<?x?xf32>, vector<[8]x4xf32>
+///  %legalType = vector.transpose %illegalRead, [1, 0]
+///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %legalType = vector.transfer_read %transpose[%c0, %c0]
+///                  : memref<?x?xf32>, vector<4x[8]xf32>
+///  ```
+struct LiftIllegalVectorTransposeToMemory
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  static bool isIllegalVectorType(VectorType vType) {
+    bool seenFixedDim = false;
+    for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+      seenFixedDim |= !scalableFlag;
+      if (seenFixedDim && scalableFlag)
+        return true;
+    }
+    return false;
+  }
+
+  static Value getExtensionSource(Operation *op) {
+    if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
+      return signExtend.getIn();
+    if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
+      return zeroExtend.getIn();
+    if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
+      return floatExtend.getIn();
+    return {};
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "expected transpose from illegal type to legal type");
+
+    Value maybeRead = transposeOp.getVector();
+    auto *transposeSourceOp = maybeRead.getDefiningOp();
+    Operation *extendOp = nullptr;
+    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+      maybeRead = extendSource;
+      extendOp = transposeSourceOp;
+    }
+
+    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+    if (!illegalRead)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "expected source to be (possibility extended) transfer_read");
----------------
c-rhodes wrote:

```suggestion
          "expected source to be (possibly extended) transfer_read");
```

https://github.com/llvm/llvm-project/pull/80170


More information about the Mlir-commits mailing list