[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (PR #80170)
Cullen Rhodes
llvmlistbot at llvm.org
Thu Feb 1 01:15:03 PST 2024
================
@@ -338,13 +341,166 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a scalable dimension followed (at any position) by a
+/// fixed-size dimension, e.g. vector<[8]x4xf32>, which has no valid lowering.
+///
+/// The memref.transpose is a metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %illegalRead = vector.transfer_read %memref[%a, %b]
+/// : memref<?x?xf32>, vector<[8]x4xf32>
+/// %legalType = vector.transpose %illegalRead, [1, 0]
+/// : vector<[8]x4xf32> to vector<4x[8]xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+/// : memref<?x?xf32> to memref<?x?xf32>
+/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+/// : memref<?x?xf32> to memref<?x?xf32>
+/// %legalType = vector.transfer_read %transpose[%c0, %c0]
+/// : memref<?x?xf32>, vector<4x[8]xf32>
+/// ```
+struct LiftIllegalVectorTransposeToMemory
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  /// Returns true if `vType` has a scalable dimension that appears before
+  /// (to the left of) some fixed-size dimension, e.g. vector<[8]x4xf32>.
+  /// Types with no such ordering (all fixed, all scalable, or only trailing
+  /// scalable dims such as vector<4x[8]xf32>) are reported as legal.
+  static bool isIllegalVectorType(VectorType vType) {
+    bool sawScalableDim = false;
+    for (bool isScalable : vType.getScalableDims()) {
+      if (isScalable) {
+        sawScalableDim = true;
+        continue;
+      }
+      // A fixed-size dim after any scalable dim makes the type illegal.
+      if (sawScalableDim)
+        return true;
+    }
+    return false;
+  }
+
+  /// If `op` is an arith.extsi, arith.extui or arith.extf, returns its input
+  /// value; otherwise returns a null Value.
+  ///
+  /// `op` may be null: callers pass `Value::getDefiningOp()`, which is null
+  /// for block arguments. `dyn_cast` asserts on null pointers, so that case
+  /// must be handled before casting.
+  static Value getExtensionSource(Operation *op) {
+    if (!op)
+      return {};
+    if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
+      return signExtend.getIn();
+    if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
+      return zeroExtend.getIn();
+    if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
+      return floatExtend.getIn();
+    return {};
+  }
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = transposeOp.getSourceVectorType();
+ auto resultType = transposeOp.getResultVectorType();
+ if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(
+ transposeOp, "expected transpose from illegal type to legal type");
+
+ Value maybeRead = transposeOp.getVector();
+ auto *transposeSourceOp = maybeRead.getDefiningOp();
+ Operation *extendOp = nullptr;
+ if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+ maybeRead = extendSource;
+ extendOp = transposeSourceOp;
+ }
+
+ auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+ if (!illegalRead)
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "expected source to be (possibility extended) transfer_read");
+
+ if (!illegalRead.getPermutationMap().isIdentity())
+ return rewriter.notifyMatchFailure(
+ illegalRead, "expected read to have identity permutation map");
+
+ auto loc = transposeOp.getLoc();
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+ // Create a subview that matches the size of the illegal read vector type.
+ auto readType = illegalRead.getVectorType();
+ auto readSizes = llvm::map_to_vector(
+ llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
+ [&](auto dim) -> Value {
+ auto [size, isScalable] = dim;
+ auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+ if (!isScalable)
+ return dimSize;
+ auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+ });
+ SmallVector<Value> strides(readType.getRank(), Value(one));
+ auto readSubview = rewriter.create<memref::SubViewOp>(
+ loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+ strides);
+
+ // Apply the transpose to all values/attributes of the transfer_read.
+ // The mask.
+ Value mask = illegalRead.getMask();
+ if (mask) {
+ // Note: The transpose for the mask should fold into the
+ // vector.create_mask/constant_mask op, which will then become legal.
+ mask = rewriter.create<vector::TransposeOp>(loc, mask,
+ transposeOp.getPermutation());
+ }
+ // The source memref.
+ mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
+ transposeOp.getPermutation(), getContext());
+ auto transposedSubview = rewriter.create<memref::TransposeOp>(
+ loc, readSubview, AffineMapAttr::get(transposeMap));
+ ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+ // The `in_bounds` attribute.
+ if (inBoundsAttr) {
+ SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
+ inBoundsAttr.end());
+ applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
+ inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+ }
----------------
c-rhodes wrote:
this isn't tested
https://github.com/llvm/llvm-project/pull/80170
More information about the Mlir-commits
mailing list