[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA (PR #98620)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Jul 25 08:09:05 PDT 2024


================
@@ -775,6 +774,149 @@ struct ConvertIllegalShapeCastOpsToTransposes
   }
 };
 
+/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
+/// the ZA state. This is a workaround rewrite to support these transposes when
+/// ZA is available.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %transpose = vector.transpose %vec, [1, 0]
+///     : vector<2x[4]xf32> to vector<[4]x2xf32>
+///  vector.transfer_write %transpose, %dest[%y, %x]
+///     : vector<[4]x2xf32>,  memref<?x?xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///   %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+///   %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
+///   %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+///   %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
+///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+///   %c4_vscale = arith.muli %vscale, %c4 : index
+///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
+///   vector.transfer_write %4, %dest[%y, %x], %mask
+///      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+///      : vector<[4]x[4]xf32>, memref<?x?xf32>
+///  ```
+///
+/// Values larger than a single tile are supported via decomposition.
+struct LowerIllegalTransposeStoreViaZA
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!isSupportedMaskOp(writeOp.getMask()))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isIdentity())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp)
+      return failure();
+
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+
+    if (resultType.getRank() != 2)
+      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");
+
+    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "not illegal/unsupported SVE transpose");
+
+    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
+    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
+
+    if (sourceType.getDimSize(0) <= 1 ||
+        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
+      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
+
+    auto loc = writeOp.getLoc();
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
+
+    auto transposeMap = AffineMapAttr::get(
+        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
+
+    // Note: We need to use `get_tile` as there's no vector-level `undef`.
+    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
+    Value destTensorOrMemref = writeOp.getSource();
+    auto numSlicesPerTile =
+        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
+    auto numSlices =
+        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
+      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
----------------
MacDue wrote:

We can't dynamically index the array-of-vectors input (or dynamically select an SME tile). These restrictions mean this lowering just targets the lowest common denominator (that is, vscale = 1).

https://github.com/llvm/llvm-project/pull/98620


More information about the Mlir-commits mailing list