[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (PR #80170)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Feb 1 02:52:52 PST 2024


================
@@ -338,13 +341,166 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a leading scalable dimension and a fixed trailing
+/// dimension, which has no valid lowering.
+///
+/// The memref.transpose is a metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %illegalRead = vector.transfer_read %memref[%a, %b]
+///                  : memref<?x?xf32>, vector<[8]x4xf32>
+///  %legalType = vector.transpose %illegalRead, [1, 0]
+///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %legalType = vector.transfer_read %transpose[%c0, %c0]
+///                  : memref<?x?xf32>, vector<4x[8]xf32>
+///  ```
+struct LiftIllegalVectorTransposeToMemory
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  /// Returns true if `vType` has a scalable dimension positioned before (to
+  /// the left of) a fixed-size dimension, e.g. vector<[8]x4xf32>. Such types
+  /// are treated as illegal here since they have no valid lowering.
+  static bool isIllegalVectorType(VectorType vType) {
+    bool sawScalableDim = false;
+    for (bool dimIsScalable : vType.getScalableDims()) {
+      // Any fixed dimension that follows a scalable one makes the type illegal.
+      if (sawScalableDim && !dimIsScalable)
+        return true;
+      sawScalableDim = sawScalableDim || dimIsScalable;
+    }
+    return false;
+  }
+
+  /// If `op` is a sign/zero integer extension (arith.extsi / arith.extui) or a
+  /// float extension (arith.extf), returns the value being extended.
+  /// Otherwise, returns a null Value.
+  ///
+  /// `op` may be null: the caller passes the defining op of the transpose
+  /// source, which is null when that source is a block argument.
+  static Value getExtensionSource(Operation *op) {
+    // `dyn_cast` asserts on a null pointer, so bail out before casting.
+    if (!op)
+      return {};
+    if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
+      return signExtend.getIn();
+    if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
+      return zeroExtend.getIn();
+    if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
+      return floatExtend.getIn();
+    return {};
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "expected transpose from illegal type to legal type");
+
+    Value maybeRead = transposeOp.getVector();
+    auto *transposeSourceOp = maybeRead.getDefiningOp();
+    Operation *extendOp = nullptr;
+    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+      maybeRead = extendSource;
+      extendOp = transposeSourceOp;
+    }
+
+    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+    if (!illegalRead)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "expected source to be (possibility extended) transfer_read");
+
+    if (!illegalRead.getPermutationMap().isIdentity())
+      return rewriter.notifyMatchFailure(
+          illegalRead, "expected read to have identity permutation map");
+
+    auto loc = transposeOp.getLoc();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Create a subview that matches the size of the illegal read vector type.
+    auto readType = illegalRead.getVectorType();
+    auto readSizes = llvm::map_to_vector(
+        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
+        [&](auto dim) -> Value {
+          auto [size, isScalable] = dim;
+          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+          if (!isScalable)
+            return dimSize;
+          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+        });
+    SmallVector<Value> strides(readType.getRank(), Value(one));
+    auto readSubview = rewriter.create<memref::SubViewOp>(
+        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+        strides);
+
+    // Apply the transpose to all values/attributes of the transfer_read.
+    // The mask.
+    Value mask = illegalRead.getMask();
+    if (mask) {
+      // Note: The transpose for the mask should fold into the
+      // vector.create_mask/constant_mask op, which will then become legal.
+      mask = rewriter.create<vector::TransposeOp>(loc, mask,
+                                                  transposeOp.getPermutation());
+    }
+    // The source memref.
+    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
+        transposeOp.getPermutation(), getContext());
+    auto transposedSubview = rewriter.create<memref::TransposeOp>(
+        loc, readSubview, AffineMapAttr::get(transposeMap));
+    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+    // The `in_bounds` attribute.
----------------
MacDue wrote:

This is meant to follow on from the comment at the top; I've added punctuation to make that clearer.

```
// Apply the transpose to all values/attributes of the transfer_read:
// - The mask
...
// - The source memref
...
```

https://github.com/llvm/llvm-project/pull/80170


More information about the Mlir-commits mailing list