[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA (PR #98620)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Jul 12 04:36:53 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/98620
This adds a workaround rewrite that allows stores of unsupported SVE transposes such as:
```mlir
%tr = vector.transpose %vec, [1, 0]
: vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
: vector<[4]x2xf32>, memref<?x?xf32>
```
to instead use SME tiles, which can be lowered (when SME is available):
```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %dest[%i, %j], %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
: vector<[4]x[4]xf32>, memref<?x?xf32>
```
From 4e689a555d092b44efd55362528b401980ac9526 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 12 Jul 2024 11:16:49 +0000
Subject: [PATCH] [mlir][ArmSME] Add rewrite to handle unsupported SVE
transposes via SME/ZA
This adds a workaround rewrite that allows stores of unsupported SVE
transposes such as:
```mlir
%tr = vector.transpose %vec, [1, 0]
: vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
: vector<[4]x2xf32>, memref<?x?xf32>
```
to instead use SME tiles, which can be lowered (when SME is available):
```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %dest[%i, %j], %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
: vector<[4]x[4]xf32>, memref<?x?xf32>
```
---
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 3 +-
.../Dialect/ArmSME/Transforms/CMakeLists.txt | 1 +
.../ArmSME/Transforms/VectorLegalization.cpp | 170 ++++++++++++++++--
.../Dialect/ArmSME/vector-legalization.mlir | 71 ++++++++
4 files changed, 233 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index dfd64f995546a..921234daad1f1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -202,7 +202,8 @@ def VectorLegalization
"func::FuncDialect",
"arm_sme::ArmSMEDialect",
"vector::VectorDialect",
- "arith::ArithDialect"
+ "arith::ArithDialect",
+ "index::IndexDialect"
];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 600f2ecdb51bc..8f9b5080e82db 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
+ MLIRIndexDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRFuncTransforms
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 96dad6518fec8..028e2327e2a4f 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -18,6 +18,8 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -140,11 +142,11 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
VectorType smeTileType,
bool transposeIndices = false) {
- assert(isMultipleOfSMETileVectorType(type) &&
- "`type` not multiple of SME tiles");
return llvm::map_range(
- StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
- smeTileType.getDimSize(1)}),
+ StaticTileOffsetRange(
+ type.getShape(),
+ {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
+ std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
[=](auto indices) {
int row = int(indices[0]);
int col = int(indices[1]);
@@ -374,6 +376,14 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
+ Value vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ return [loc, vscale, &rewriter](int64_t multiplier) {
+ return rewriter.create<arith::MulIOp>(
+ loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+ };
+}
+
/// Legalize a multi-tile transfer_write as a single store loop. This is done as
/// part of type decomposition as at this level we know each tile write is
/// disjoint, but that information is lost after decomposition (without analysis
@@ -440,12 +450,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
kMatchFailureUnsupportedMaskOp);
auto loc = writeOp.getLoc();
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
- auto createVscaleMultiple = [&](int64_t multiplier) {
- return rewriter.create<arith::MulIOp>(
- loc, vscale,
- rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
- };
+ auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
// Get SME tile and slice types.
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -775,6 +780,148 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};
+/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
+/// the ZA state. This is a workaround to support these transposes when ZA is
+/// available.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+/// : vector<2x[4]xf32> to vector<[4]x2xf32>
+/// vector.transfer_write %transpose, %dest[%y, %x]
+/// : vector<[4]x2xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
+/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
+/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
+/// vector.transfer_write %4, %dest[%y, %x], %mask
+/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+/// : vector<[4]x[4]xf32>, memref<?x?xf32>
+/// ```
+///
+/// Values larger than a single tile are supported via decomposition.
+struct LowerIllegalTransposeStoreViaZA
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ if (!isSupportedMaskOp(writeOp.getMask()))
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureUnsupportedMaskOp);
+
+ auto permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isIdentity())
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureNonPermutationMap);
+
+ auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp)
+ return failure();
+
+ auto sourceType = transposeOp.getSourceVectorType();
+ auto resultType = transposeOp.getResultVectorType();
+
+ if (resultType.getRank() != 2)
+ return rewriter.notifyMatchFailure(transposeOp, "not rank 2");
+
+ if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(
+ transposeOp, "not illegal/unsupported SVE transpose");
+
+ auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
+ VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
+
+ if (sourceType.getDimSize(0) <= 1 ||
+ sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
+ return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
+
+ auto loc = writeOp.getLoc();
+ auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
+
+ auto transposeMap = AffineMapAttr::get(
+ AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
+
+ // Note: We need to use `get_tile` as there's no vector-level `undef`.
+ Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
+ Value destTensorOrMemref = writeOp.getSource();
+ auto numSlicesPerTile =
+ std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
+ auto numSlices =
+ rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
+ for (auto [index, smeTile] : llvm::enumerate(
+ decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
+ // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
+ // of slices from the source type into the SME tile. Without checking
+ // vscale (and emitting multiple implementations) we can't make use of the
+ // rows of the tile after 1*vscale rows.
+ Value tile = undefTile;
+ for (int d = 0, e = numSlicesPerTile; d < e; ++d) {
+ Value vector = rewriter.create<vector::ExtractOp>(
+ loc, transposeOp.getVector(),
+ rewriter.getIndexAttr(d + smeTile.row));
+ if (vector.getType() != smeSliceType) {
+ vector = rewriter.create<vector::ScalableExtractOp>(
+ loc, smeSliceType, vector, smeTile.col);
+ }
+ tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
+ }
+
+ // 2. Transpose the tile position.
+ auto transposedRow = createVscaleMultiple(smeTile.col);
+ auto transposedCol =
+ rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
+
+ // 3. Compute mask for tile store.
+ Value maskRows;
+ Value maskCols;
+ if (auto mask = writeOp.getMask()) {
+ auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+ maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
+ transposedRow);
+ maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
+ transposedCol);
+ maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
+ } else {
+ maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
+ maskCols = numSlices;
+ }
+ auto subMask = rewriter.create<vector::CreateMaskOp>(
+ loc, smeTileType.clone(rewriter.getI1Type()),
+ ValueRange{maskRows, maskCols});
+
+ // 4. Emit a transposed tile write.
+ auto writeIndices = writeOp.getIndices();
+ Value destRow =
+ rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
+ Value destCol =
+ rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
+ auto smeWrite = rewriter.create<vector::TransferWriteOp>(
+ loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
+ transposeMap, subMask, writeOp.getInBounds().value_or(ArrayAttr{}));
+
+ if (writeOp.hasPureTensorSemantics())
+ destTensorOrMemref = smeWrite.getResult();
+ }
+
+ if (writeOp.hasPureTensorSemantics())
+ rewriter.replaceOp(writeOp, destTensorOrMemref);
+ else
+ rewriter.eraseOp(writeOp);
+
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
@@ -796,7 +943,8 @@ struct VectorLegalizationPass
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
- ConvertIllegalShapeCastOpsToTransposes>(context);
+ ConvertIllegalShapeCastOpsToTransposes,
+ LowerIllegalTransposeStoreViaZA>(context);
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 71d80bc16ea12..951b29b6e3805 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -544,3 +544,74 @@ func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
%0 = arith.constant dense<42> : vector<[8]x[8]xi32>
return %0 : vector<[8]x[8]xi32>
}
+
+// -----
+
+// CHECK: #[[$TRANSPOSE_MAP_0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_store_scalable_via_za(
+// CHECK-SAME: %[[VEC:.*]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[I:.*]]: index,
+// CHECK-SAME: %[[J:.*]]: index)
+func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+ // CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+ // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1>
+ // CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+ vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_store_scalable_via_za_masked(
+// CHECK-SAME: %[[A:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[B:[a-z0-9]+]]: index)
+func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]]
+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[A]], %[[MIN]] : vector<[4]x[4]xi1>
+ // CHECK: vector.transfer_write {{.*}} %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %mask = vector.create_mask %a, %b : vector<[4]x2xi1>
+ %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+ vector.transfer_write %tr, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile(
+// CHECK-SAME: %[[VEC:.*]]: vector<8x[4]xf32>
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[I:.*]]: index,
+// CHECK-SAME: %[[J:.*]]: index)
+func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[VSCALE:.*]] = vector.vscale
+
+ // <skip 3x other extract+insert chain>
+ // CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
+ // CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C4]] : vector<[4]x[4]xi1>
+ // CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+
+ // <skip 3x other extract+insert chain>
+ // CHECK: %[[V7:.*]] = vector.extract %[[VEC]][7] : vector<[4]xf32> from vector<8x[4]xf32>
+ // CHECK: %[[TILE_1:.*]] = vector.insert %[[V7]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[C4]] : index
+ // CHECK: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[I]], %[[J_OFFSET]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ %tr = vector.transpose %vec, [1, 0] : vector<8x[4]xf32> to vector<[4]x8xf32>
+ vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>, memref<?x?xf32>
+ return
+}