[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA (PR #98620)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Jul 25 08:37:41 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/98620

>From 6475f49f32d71375abf23924f72277508a7caa8f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 12 Jul 2024 11:16:49 +0000
Subject: [PATCH 1/5] [mlir][ArmSME] Add rewrite to handle unsupported SVE
 transposes via SME/ZA

This adds a workaround rewrite that allows stores of unsupported SVE
transposes such as:

```mlir
%tr = vector.transpose %vec, [1, 0]
  : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
  : vector<[4]x2xf32>,  memref<?x?xf32>
```

These stores are rewritten to use SME tiles, which can be lowered (when SME is available):

```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
   : vector<[4]x[4]xf32>, memref<?x?xf32>
```
---
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |   3 +-
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |   1 +
 .../ArmSME/Transforms/VectorLegalization.cpp  | 170 ++++++++++++++++--
 .../Dialect/ArmSME/vector-legalization.mlir   |  71 ++++++++
 4 files changed, 233 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index dfd64f995546a..921234daad1f1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -202,7 +202,8 @@ def VectorLegalization
     "func::FuncDialect",
     "arm_sme::ArmSMEDialect",
     "vector::VectorDialect",
-    "arith::ArithDialect"
+    "arith::ArithDialect",
+    "index::IndexDialect"
   ];
 }
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 600f2ecdb51bc..8f9b5080e82db 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect
+  MLIRIndexDialect
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRFuncTransforms
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 96dad6518fec8..028e2327e2a4f 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -18,6 +18,8 @@
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -140,11 +142,11 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
 auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                          VectorType smeTileType,
                          bool transposeIndices = false) {
-  assert(isMultipleOfSMETileVectorType(type) &&
-         "`type` not multiple of SME tiles");
   return llvm::map_range(
-      StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
-                                              smeTileType.getDimSize(1)}),
+      StaticTileOffsetRange(
+          type.getShape(),
+          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
+           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
       [=](auto indices) {
         int row = int(indices[0]);
         int col = int(indices[1]);
@@ -374,6 +376,14 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
+  Value vscale = rewriter.create<vector::VectorScaleOp>(loc);
+  return [loc, vscale, &rewriter](int64_t multiplier) {
+    return rewriter.create<arith::MulIOp>(
+        loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+  };
+}
+
 /// Legalize a multi-tile transfer_write as a single store loop. This is done as
 /// part of type decomposition as at this level we know each tile write is
 /// disjoint, but that information is lost after decomposition (without analysis
@@ -440,12 +450,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
                                          kMatchFailureUnsupportedMaskOp);
 
     auto loc = writeOp.getLoc();
-    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
-    auto createVscaleMultiple = [&](int64_t multiplier) {
-      return rewriter.create<arith::MulIOp>(
-          loc, vscale,
-          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
-    };
+    auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
 
     // Get SME tile and slice types.
     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -775,6 +780,148 @@ struct ConvertIllegalShapeCastOpsToTransposes
   }
 };
 
+/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
+/// the ZA state. This is a workaround rewrite to support these transposes when
+/// ZA is available.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %transpose = vector.transpose %vec, [1, 0]
+///     : vector<2x[4]xf32> to vector<[4]x2xf32>
+///  vector.transfer_write %transpose, %dest[%y, %x]
+///     : vector<[4]x2xf32>,  memref<?x?xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///   %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+///   %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
+///   %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+///   %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
+///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+///   %c4_vscale = arith.muli %vscale, %c4 : index
+///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
+///   vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
+///      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+///      : vector<[4]x[4]xf32>, memref<?x?xf32>
+///  ```
+///
+/// Values larger than a single tile are supported via decomposition.
+struct LowerIllegalTransposeStoreViaZA
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!isSupportedMaskOp(writeOp.getMask()))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isIdentity())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp)
+      return failure();
+
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+
+    if (resultType.getRank() != 2)
+      return rewriter.notifyMatchFailure(transposeOp, "not rank 2");
+
+    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "not illegal/unsupported SVE transpose");
+
+    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
+    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
+
+    if (sourceType.getDimSize(0) <= 1 ||
+        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
+      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
+
+    auto loc = writeOp.getLoc();
+    auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
+
+    auto transposeMap = AffineMapAttr::get(
+        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
+
+    // Note: We need to use `get_tile` as there's no vector-level `undef`.
+    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
+    Value destTensorOrMemref = writeOp.getSource();
+    auto numSlicesPerTile =
+        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
+    auto numSlices =
+        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
+      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
+      // of slices from the source type into the SME tile. Without checking
+      // vscale (and emitting multiple implementations) we can't make use of the
+      // rows of the tile after 1*vscale rows.
+      Value tile = undefTile;
+      for (int d = 0, e = numSlicesPerTile; d < e; ++d) {
+        Value vector = rewriter.create<vector::ExtractOp>(
+            loc, transposeOp.getVector(),
+            rewriter.getIndexAttr(d + smeTile.row));
+        if (vector.getType() != smeSliceType) {
+          vector = rewriter.create<vector::ScalableExtractOp>(
+              loc, smeSliceType, vector, smeTile.col);
+        }
+        tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
+      }
+
+      // 2. Transpose the tile position.
+      auto transposedRow = createVscaleMultiple(smeTile.col);
+      auto transposedCol =
+          rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
+
+      // 3. Compute mask for tile store.
+      Value maskRows;
+      Value maskCols;
+      if (auto mask = writeOp.getMask()) {
+        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+        maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
+                                                  transposedRow);
+        maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
+                                                  transposedCol);
+        maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
+      } else {
+        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
+        maskCols = numSlices;
+      }
+      auto subMask = rewriter.create<vector::CreateMaskOp>(
+          loc, smeTileType.clone(rewriter.getI1Type()),
+          ValueRange{maskRows, maskCols});
+
+      // 4. Emit a transposed tile write.
+      auto writeIndices = writeOp.getIndices();
+      Value destRow =
+          rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
+      Value destCol =
+          rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
+      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
+          transposeMap, subMask, writeOp.getInBounds().value_or(ArrayAttr{}));
+
+      if (writeOp.hasPureTensorSemantics())
+        destTensorOrMemref = smeWrite.getResult();
+    }
+
+    if (writeOp.hasPureTensorSemantics())
+      rewriter.replaceOp(writeOp, destTensorOrMemref);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
@@ -796,7 +943,8 @@ struct VectorLegalizationPass
 
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
-                 ConvertIllegalShapeCastOpsToTransposes>(context);
+                 ConvertIllegalShapeCastOpsToTransposes,
+                 LowerIllegalTransposeStoreViaZA>(context);
     // Note: These two patterns are added with a high benefit to ensure:
     //  - Masked outer products are handled before unmasked ones
     //  - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 71d80bc16ea12..951b29b6e3805 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -544,3 +544,74 @@ func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
   %0 = arith.constant dense<42> : vector<[8]x[8]xi32>
   return %0 : vector<[8]x[8]xi32>
 }
+
+// -----
+
+// CHECK: #[[$TRANSPOSE_MAP_0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_store_scalable_via_za(
+// CHECK-SAME:                                   %[[VEC:.*]]: vector<2x[4]xf32>
+// CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                   %[[I:.*]]: index,
+// CHECK-SAME:                                   %[[J:.*]]: index)
+func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+  // CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+  // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1>
+  // CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x2xf32>,  memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK: @transpose_store_scalable_via_za_masked(
+// CHECK-SAME:                                    %[[A:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[B:[a-z0-9]+]]: index)
+func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]]
+  // CHECK: %[[MASK:.*]] = vector.create_mask %[[A]], %[[MIN]] : vector<[4]x[4]xi1>
+  // CHECK: vector.transfer_write {{.*}} %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %a, %b : vector<[4]x2xi1>
+  %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+  vector.transfer_write %tr, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x2xf32>,  memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK: @transpose_store_scalable_via_za_multi_tile(
+// CHECK-SAME:                                       %[[VEC:.*]]: vector<8x[4]xf32>
+// CHECK-SAME:                                       %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                       %[[I:.*]]: index,
+// CHECK-SAME:                                       %[[J:.*]]: index)
+func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[VSCALE:.*]] = vector.vscale
+
+  // <skip 3x other extract+insert chain>
+  // CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
+  // CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK: %[[MASK:.*]] = vector.create_mask %c4_vscale, %c4 : vector<[4]x[4]xi1>
+  // CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+
+  // <skip 3x other extract+insert chain>
+  // CHECK: %[[V7:.*]] = vector.extract %arg0[7] : vector<[4]xf32> from vector<8x[4]xf32>
+  // CHECK: %[[TILE_1:.*]] = vector.insert %[[V7]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[C4]] : index
+  // CHECK: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[I]], %[[J_OFFSET]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  %tr = vector.transpose %vec, [1, 0] : vector<8x[4]xf32> to vector<[4]x8xf32>
+  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>,  memref<?x?xf32>
+  return
+}

>From f847f4576035c2c1c75d58d0b3fcaea896bbdec0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 25 Jul 2024 11:26:34 +0000
Subject: [PATCH 2/5] Review fixups

---
 .../ArmSME/Transforms/VectorLegalization.cpp  |  6 +--
 .../Dialect/ArmSME/vector-legalization.mlir   | 47 +++++++++++++++----
 2 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 028e2327e2a4f..6f2118ad4359b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -803,7 +803,7 @@ struct ConvertIllegalShapeCastOpsToTransposes
 ///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
 ///   %c4_vscale = arith.muli %vscale, %c4 : index
 ///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
-///   vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
+///   vector.transfer_write %4, %dest[%y, %x], %mask
 ///      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
 ///      : vector<[4]x[4]xf32>, memref<?x?xf32>
 ///  ```
@@ -832,7 +832,7 @@ struct LowerIllegalTransposeStoreViaZA
     auto resultType = transposeOp.getResultVectorType();
 
     if (resultType.getRank() != 2)
-      return rewriter.notifyMatchFailure(transposeOp, "not rank 2");
+      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");
 
     if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
       return rewriter.notifyMatchFailure(
@@ -865,7 +865,7 @@ struct LowerIllegalTransposeStoreViaZA
       // vscale (and emitting multiple implementations) we can't make use of the
       // rows of the tile after 1*vscale rows.
       Value tile = undefTile;
-      for (int d = 0, e = numSlicesPerTile; d < e; ++d) {
+      for (int d = 0; d < numSlicesPerTile; ++d) {
         Value vector = rewriter.create<vector::ExtractOp>(
             loc, transposeOp.getVector(),
             rewriter.getIndexAttr(d + smeTile.row));
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 951b29b6e3805..40511ec100f0a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -573,9 +573,9 @@ func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memre
 
 // -----
 
-// CHECK: @transpose_store_scalable_via_za_masked(
-// CHECK-SAME:                                    %[[A:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[B:[a-z0-9]+]]: index)
+// CHECK-LABEL: @transpose_store_scalable_via_za_masked(
+// CHECK-SAME:                                          %[[A:[a-z0-9]+]]: index,
+// CHECK-SAME:                                          %[[B:[a-z0-9]+]]: index)
 func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
   // CHECK: %[[C2:.*]] = arith.constant 2 : index
   // CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]]
@@ -590,11 +590,11 @@ func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest
 
 // -----
 
-// CHECK: @transpose_store_scalable_via_za_multi_tile(
-// CHECK-SAME:                                       %[[VEC:.*]]: vector<8x[4]xf32>
-// CHECK-SAME:                                       %[[DEST:.*]]: memref<?x?xf32>,
-// CHECK-SAME:                                       %[[I:.*]]: index,
-// CHECK-SAME:                                       %[[J:.*]]: index)
+// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile(
+// CHECK-SAME:                                              %[[VEC:.*]]: vector<8x[4]xf32>
+// CHECK-SAME:                                              %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                              %[[I:.*]]: index,
+// CHECK-SAME:                                              %[[J:.*]]: index)
 func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
   // CHECK: %[[C4:.*]] = arith.constant 4 : index
   // CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -615,3 +615,34 @@ func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %
   vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>,  memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile_with_scalable_extracts
+func.func @transpose_store_scalable_via_za_multi_tile_with_scalable_extracts(%vec: vector<2x[8]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  // <check extracts from lower 4 x vscale of %vec>
+  // CHECK: vector.scalable.extract
+  // CHECK: %[[ROW_2_LOWER:.*]] = vector.scalable.extract %{{.*}}[0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_LOWER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I:.[a-z0-9]+]], %[[J:[a-z0-9]+]]]
+
+  // <check extracts from upper 4 x vscale of %vec>
+  // CHECK: vector.scalable.extract
+  // CHECK: %[[ROW_2_UPPER:.*]] = vector.scalable.extract %{{.*}}[4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_UPPER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: %[[I_OFFSET:.*]] = arith.addi %c4_vscale, %[[I]] : index
+  // CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I_OFFSET]], %[[J]]]
+  %tr = vector.transpose %vec, [1, 0] : vector<2x[8]xf32> to vector<[8]x2xf32>
+  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[8]x2xf32>,  memref<?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_transpose_store_scalable_via_za__bad_source_shape
+// CHECK-NOT: arm_sme.get_tile
+func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vector<2x[7]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  %tr = vector.transpose %vec, [1, 0] : vector<2x[7]xf32> to vector<[7]x2xf32>
+  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>,  memref<?x?xf32>
+  return
+}

>From 81e8c079073e5b986324649e6e0efd85790011c1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 25 Jul 2024 13:27:38 +0000
Subject: [PATCH 3/5] Move builder to utils

---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 19 +++++++++++++++++++
 .../ArmSME/Transforms/VectorLegalization.cpp  | 15 +++++----------
 .../Dialect/ArmSME/vector-legalization.mlir   |  4 ++--
 3 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 9c83acc76e77a..40e04b76593a0 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -101,6 +102,24 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 std::optional<StaticTileOffsetRange>
 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
 
+/// Returns a functor (int64_t -> Value) which returns a constant vscale
+/// multiple.
+///
+/// Example:
+/// ```c++
+/// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
+/// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
+/// ```
+inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
+  Value vscale = nullptr;
+  return [loc, vscale, &rewriter](int64_t multiplier) mutable {
+    if (!vscale)
+      vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    return rewriter.create<arith::MulIOp>(
+        loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+  };
+}
+
 /// A wrapper for getMixedSizes for vector.transfer_read and
 /// vector.transfer_write Ops (for source and destination, respectively).
 ///
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 6f2118ad4359b..ba7bda099b9a3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
 
 #define DEBUG_TYPE "arm-sme-vector-legalization"
@@ -376,14 +377,6 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
-auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
-  Value vscale = rewriter.create<vector::VectorScaleOp>(loc);
-  return [loc, vscale, &rewriter](int64_t multiplier) {
-    return rewriter.create<arith::MulIOp>(
-        loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
-  };
-}
-
 /// Legalize a multi-tile transfer_write as a single store loop. This is done as
 /// part of type decomposition as at this level we know each tile write is
 /// disjoint, but that information is lost after decomposition (without analysis
@@ -450,7 +443,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
                                          kMatchFailureUnsupportedMaskOp);
 
     auto loc = writeOp.getLoc();
-    auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
 
     // Get SME tile and slice types.
     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -846,7 +840,8 @@ struct LowerIllegalTransposeStoreViaZA
       return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
 
     auto loc = writeOp.getLoc();
-    auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
 
     auto transposeMap = AffineMapAttr::get(
         AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 40511ec100f0a..7651f75d54620 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -557,12 +557,12 @@ func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
 func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
   // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
   // CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
   // CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
   // CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
   // CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
   // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[VSCALE:.*]] = vector.vscale
   // CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
   // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1>
   // CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
@@ -597,11 +597,11 @@ func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest
 // CHECK-SAME:                                              %[[J:.*]]: index)
 func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
   // CHECK: %[[C4:.*]] = arith.constant 4 : index
-  // CHECK: %[[VSCALE:.*]] = vector.vscale
 
   // <skip 3x other extract+insert chain>
   // CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
   // CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: %[[VSCALE:.*]] = vector.vscale
   // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
   // CHECK: %[[MASK:.*]] = vector.create_mask %c4_vscale, %c4 : vector<[4]x[4]xi1>
   // CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>

>From 9ed32f9c9c775553e82efa4d9157680f51c61360 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 25 Jul 2024 15:12:48 +0000
Subject: [PATCH 4/5] Fixups

---
 mlir/test/Dialect/ArmSME/vector-legalization.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 7651f75d54620..458906a187982 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -618,8 +618,8 @@ func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %
 
 // -----
 
-// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile_with_scalable_extracts
-func.func @transpose_store_scalable_via_za_multi_tile_with_scalable_extracts(%vec: vector<2x[8]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile_wide
+func.func @transpose_store_scalable_via_za_multi_tile_wide(%vec: vector<2x[8]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
   // <check extracts from lower 4 x vscale of %vec>
   // CHECK: vector.scalable.extract
   // CHECK: %[[ROW_2_LOWER:.*]] = vector.scalable.extract %{{.*}}[0] : vector<[4]xf32> from vector<[8]xf32>

>From a37d6daae5956e0236db2cee28c971cca232b8e1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 25 Jul 2024 15:37:22 +0000
Subject: [PATCH 5/5] Rebase

---
 mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index ba7bda099b9a3..004428bd343e2 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -902,7 +902,7 @@ struct LowerIllegalTransposeStoreViaZA
           rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
       auto smeWrite = rewriter.create<vector::TransferWriteOp>(
           loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
-          transposeMap, subMask, writeOp.getInBounds().value_or(ArrayAttr{}));
+          transposeMap, subMask, writeOp.getInBounds());
 
       if (writeOp.hasPureTensorSemantics())
         destTensorOrMemref = smeWrite.getResult();



More information about the Mlir-commits mailing list