[Mlir-commits] [mlir] [mlir][vector] Add scalable lowering for `transfer_write(transpose)` (PR #101353)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Aug 12 05:35:26 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/101353

>From cad7ec3baa7f206a1be4573ada74b90319caa88b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 31 Jul 2024 15:44:36 +0000
Subject: [PATCH 1/6] [mlir][vector] Add scalable lowering for
 `transfer_write(transpose)`

This specifically handles the case of a transpose from a vector type
like `vector<8x[4]xf32>` to `vector<[4]x8xf32>`. Such transposes occur
fairly frequently when scalably vectorizing `linalg.generic`s. There is
no direct lowering for these (as types like `vector<[4]x8xf32>` cannot
be represented in LLVM-IR). However, if the only use of the transpose is
a write, then it is possible to lower the `transfer_write(transpose)`
as a VLA (vector-length agnostic) loop.

Example:

```mlir
%transpose = vector.transpose %vec, [1, 0]
   : vector<4x[4]xf32> to vector<[4]x4xf32>
vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
   : vector<[4]x4xf32>,  memref<?x?xf32>
```

Becomes:

```mlir
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
%1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
%2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
%3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
%vscale = vector.vscale
%c4_vscale = arith.muli %vscale, %c4 : index
scf.for %idx = %c0 to %c4_vscale step %c1 {
  %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
  %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
  %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
  %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
  %slice_i = affine.apply #map(%idx)[%i]
  %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
  vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
    : vector<4xf32>, memref<?x?xf32>
}
```
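
For downstream users, here is a minimal sketch (not part of the patch) of opting into this lowering from C++, equivalent to the `convert-vector-to-scf{full-unroll=true lower-scalable=true}` pipeline options exercised by the tests; the wrapper function name is illustrative only:

```cpp
// Minimal sketch: enable the scalable transpose-write lowering together with
// full unrolling, using the options added/extended by this patch.
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper, for illustration only.
static void addVectorToSCFPatterns(RewritePatternSet &patterns) {
  VectorTransferToSCFOptions options;
  options.unroll = true;         // Unroll fixed dims until the first scalable dim.
  options.enableLowerScalable(); // Then lower transfer_write(transpose) as a loop.
  populateVectorToSCFConversionPatterns(patterns, options);
}
```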
---
 mlir/include/mlir/Conversion/Passes.td        |   4 +-
 .../mlir/Conversion/VectorToSCF/VectorToSCF.h |   8 +
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 187 +++++++++++++++++-
 .../VectorToSCF/tensor-transfer-ops.mlir      |  15 +-
 .../Conversion/VectorToSCF/vector-to-scf.mlir | 126 +++++++++++-
 5 files changed, 328 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961c..7bde9e490e4f4e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1300,7 +1300,9 @@ def ConvertVectorToSCF : Pass<"convert-vector-to-scf"> {
     Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
            "Target vector rank to which transfer ops should be lowered">,
     Option<"lowerTensors", "lower-tensors", "bool", /*default=*/"false",
-           "Lower transfer ops that operate on tensors">
+           "Lower transfer ops that operate on tensors">,
+    Option<"lowerScalable", "lower-scalable", "bool", /*default=*/"false",
+           "Add scalable vector specific lowerings (that introduce loops)">
   ];
 }
 
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index 1c834b6c690830..e0ef67c39a1013 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -69,6 +69,14 @@ struct VectorTransferToSCFOptions {
     unroll = u;
     return *this;
   }
+  /// Enable scalable vector specific lowerings (which introduce loops). These
+  /// work alongside fullUnroll (which unrolls until the first scalable
+  /// dimension).
+  bool lowerScalable = false;
+  VectorTransferToSCFOptions enableLowerScalable(bool enable = true) {
+    lowerScalable = enable;
+    return *this;
+  }
 };
 
 /// Collect a set of patterns to convert from the Vector dialect to SCF + func.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 19f02297bfbb71..c2fb8b6161ec32 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Pass/Pass.h"
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   }
 };
 
+/// Retrieves the dimensions sizes of a mask. Currently supports CreateMaskOp
+/// and ConstantMaskOp.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+  if (!mask)
+    return SmallVector<OpFoldResult>{};
+  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+      return OpFoldResult(dimSize);
+    });
+  }
+  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+    int dimIdx = 0;
+    VectorType maskType = constantMask.getVectorType();
+    auto indexType = IndexType::get(mask.getContext());
+    return llvm::map_to_vector(
+        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+          // A scalable dim in a constant_mask means vscale x dimSize.
+          if (maskType.getScalableDims()[dimIdx++])
+            return OpFoldResult(createVscaleMultiple(dimSize));
+          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+        });
+  }
+  return failure();
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+///    : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+///    : vector<[4]x4xf32>,  memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %slice_i = affine.apply #map(%idx)[%i]
+///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+///     : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+    : VectorToSCFPattern<vector::TransferWriteOp> {
+  using VectorToSCFPattern::VectorToSCFPattern;
+
+  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (isTensorOp(writeOp) && !options.lowerTensors) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "lowering tensor transfers is disabled");
+    }
+
+    auto vector = writeOp.getVector();
+    auto vectorType = vector.getType();
+    auto scalableFlags = vectorType.getScalableDims();
+    if (scalableFlags != ArrayRef<bool>{true, false}) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "expected vector of form vector<[*]x*xty>");
+    }
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "non-identity permutations are unsupported (lower first)");
+    }
+
+    if (!writeOp.isDimInBounds(0)) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "out-of-bounds dims are unsupported (use masking)");
+    }
+
+    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp ||
+        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
+      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
+    }
+
+    auto loc = writeOp.getLoc();
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
+
+    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
+    if (failed(maskDims)) {
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "failed to resolve mask dims");
+    }
+
+    int64_t fixedDimSize = vectorType.getDimSize(1);
+    auto fixedDimOffsets = llvm::seq(fixedDimSize);
+
+    // Extract all slices from the source of the transpose.
+    auto transposeSource = transposeOp.getVector();
+    SmallVector<Value> transposeSourceSlices =
+        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+          return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
+        });
+
+    // Loop bounds and step.
+    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto ub =
+        maskDims->empty()
+            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
+            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Generate a new mask for the slice.
+    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
+    Value sliceMask = nullptr;
+    if (!maskDims->empty()) {
+      sliceMask = rewriter.create<vector::CreateMaskOp>(
+          loc, sliceType.clone(rewriter.getI1Type()),
+          ArrayRef<OpFoldResult>(*maskDims).drop_front());
+    }
+
+    ValueRange initLoopArgs =
+        isTensorOp(writeOp) ? writeOp.getSource() : ValueRange{};
+    auto result = rewriter.create<scf::ForOp>(
+        loc, lb, ub, step, initLoopArgs,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
+          // Indices for the new transfer op.
+          SmallVector<Value, 8> xferIndices;
+          getXferIndices(b, writeOp, iv, xferIndices);
+
+          // Extract a transposed slice from the source vector.
+          SmallVector<Value> transposeElements =
+              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+                return b.create<vector::ExtractOp>(
+                    loc, transposeSourceSlices[idx], iv);
+              });
+          auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
+                                                           transposeElements);
+
+          // Create the transfer_write for the slice.
+          Value dest =
+              loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
+          auto newWriteOp = b.create<vector::TransferWriteOp>(
+              loc, sliceVec, dest, xferIndices,
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
+          if (sliceMask)
+            newWriteOp.getMaskMutable().assign(sliceMask);
+
+          // Yield from the loop.
+          b.create<scf::YieldOp>(loc, loopIterArgs.empty()
+                                          ? ValueRange{}
+                                          : newWriteOp.getResult());
+        });
+
+    if (isTensorOp(writeOp))
+      rewriter.replaceOp(writeOp, result);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
 } // namespace lowering_n_d
 
 namespace lowering_n_d_unrolled {
@@ -1503,7 +1683,10 @@ void mlir::populateVectorToSCFConversionPatterns(
                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
         patterns.getContext(), options);
   }
-
+  if (options.lowerScalable) {
+    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
+        patterns.getContext(), options);
+  }
   if (options.targetRank == 1) {
     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
@@ -1522,6 +1705,7 @@ struct ConvertVectorToSCFPass
     this->fullUnroll = options.unroll;
     this->targetRank = options.targetRank;
     this->lowerTensors = options.lowerTensors;
+    this->lowerScalable = options.lowerScalable;
   }
 
   void runOnOperation() override {
@@ -1529,6 +1713,7 @@ struct ConvertVectorToSCFPass
     options.unroll = fullUnroll;
     options.targetRank = targetRank;
     options.lowerTensors = lowerTensors;
+    options.lowerScalable = lowerScalable;
 
     // Lower permutation maps first.
     RewritePatternSet lowerTransferPatterns(&getContext());
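
The test updates below simply turn on the new option in the existing RUN lines. As a rough programmatic equivalent of `convert-vector-to-scf{lower-tensors=true lower-scalable=true}` (a sketch that assumes the existing `createConvertVectorToSCFPass(options)` factory and is not part of the patch):

```cpp
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Illustrative only: mirrors lower-tensors=true lower-scalable=true.
static void buildTensorTransferLowering(OpPassManager &pm) {
  VectorTransferToSCFOptions options;
  options.lowerTensors = true;   // Also lower transfer ops on tensors.
  options.enableLowerScalable(); // Option added by this patch.
  pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass(options));
}
```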
diff --git a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
index dac8e018f845ff..c542b79d5c80e4 100644
--- a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true lower-scalable=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: func @transfer_read_2d(
 //       CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
@@ -36,3 +36,16 @@ func.func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
   return %t : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @scalable_transpose_store
+//  CHECK-SAME: %[[TENSOR:[a-z0-9]+]]: tensor<?x?xf32>
+//       CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[ITER_ARG:.*]] = %[[TENSOR]]) -> (tensor<?x?xf32>)
+//       CHECK:   %[[WRITE_SLICE:.*]] = vector.transfer_write %{{.*}} %[[ITER_ARG]]
+//       CHECK:   scf.yield %[[WRITE_SLICE]]
+//       CHECK: return %[[RESULT]]
+func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %dest: tensor<?x?xf32>, %i: index, %j: index) -> tensor<?x?xf32> {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  %result = vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>,  tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 3f4e70a6835af5..f7ded02d83fe5f 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf))" -split-input-file -allow-unregistered-dialect | FileCheck %s
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true lower-scalable=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
 // RUN: mlir-opt %s "-convert-vector-to-scf=full-unroll target-rank=0" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=TARGET-RANK-ZERO
 
 // CHECK-LABEL: func @vector_transfer_ops_0d(
@@ -661,10 +661,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
 // CHECK:           memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
 // CHECK:           %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
 // CHECK:           %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
-// CHECK:             %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
-// CHECK:             memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK:           scf.for %[[VSCALE:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xi1>>
+// CHECK:             %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VSCALE]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+// CHECK:             memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xf32>>
 // CHECK:           }
 // CHECK:           %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
 // CHECK:           return %[[RESULT]] : vector<3x[4]xf32>
@@ -695,10 +695,10 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
 // CHECK:           memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
 // CHECK:           %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
 // CHECK:           %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
-// CHECK:             %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
-// CHECK:             vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+// CHECK:           scf.for %[[VSCALE:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xf32>>
+// CHECK:             %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xi1>>
+// CHECK:             vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VSCALE]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -803,3 +803,111 @@ func.func @unroll_transfer_write_target_rank_zero(%vec : vector<2xi32>) {
 // TARGET-RANK-ZERO: %[[EXTRACTED2:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
 // TARGET-RANK-ZERO: %[[BROADCASTED2:.*]] = vector.broadcast %[[EXTRACTED2]] : i32 to vector<i32>
 // TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED2]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>
+
+// -----
+
+func.func @scalable_transpose_store_unmasked(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL: #[[$SLICE_MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// FULL-UNROLL-LABEL:   func.func @scalable_transpose_store_unmasked(
+// FULL-UNROLL-SAME:                                                 %[[VEC:.*]]: vector<4x[4]xf32>,
+// FULL-UNROLL-SAME:                                                 %[[DEST:.*]]: memref<?x?xf32>,
+// FULL-UNROLL-SAME:                                                 %[[I:.*]]: index,
+// FULL-UNROLL-SAME:                                                 %[[J:.*]]: index)
+// FULL-UNROLL-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// FULL-UNROLL-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// FULL-UNROLL-DAG:       %[[C4:.*]] = arith.constant 4 : index
+// FULL-UNROLL:           %[[SLICE_0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL:           %[[SLICE_1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL:           %[[SLICE_2:.*]] = vector.extract %[[VEC]][2] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL:           %[[SLICE_3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL:           %[[VSCALE:.*]] = vector.vscale
+// FULL-UNROLL:           %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// FULL-UNROLL:           scf.for %[[VAL_13:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+// FULL-UNROLL:             %[[SLICE_I:.*]] = affine.apply #[[$SLICE_MAP]](%[[VAL_13]]){{\[}}%[[I]]]
+// FULL-UNROLL:             %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL:             %[[TRANSPOSE_SLICE:.*]] = vector.from_elements %[[ELEM_0]], %[[ELEM_1]], %[[ELEM_2]], %[[ELEM_3]] : vector<4xf32>
+// FULL-UNROLL:             vector.transfer_write %[[TRANSPOSE_SLICE]], %[[DEST]]{{\[}}%[[SLICE_I]], %[[J]]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+// FULL-UNROLL:           }
+// FULL-UNROLL:           return
+
+// -----
+
+func.func @scalable_transpose_store_dynamic_mask(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index, %a: index, %b: index) {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  %mask = vector.create_mask %a, %b : vector<[4]x4xi1>
+  vector.transfer_write %transpose, %dest[%i, %j], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL-LABEL:   func.func @scalable_transpose_store_dynamic_mask(
+// FULL-UNROLL-SAME:                                                     %{{.*}}, %[[A:.*]]: index, %[[B:.*]]: index)
+// FULL-UNROLL:           %[[SLICE_MASK:.*]] = vector.create_mask %[[B]] : vector<4xi1>
+// FULL-UNROLL:           scf.for %{{.*}} to %[[A]]
+// FULL-UNROLL:             vector.transfer_write {{.*}}, %[[SLICE_MASK]]
+
+// -----
+
+func.func @scalable_transpose_store_constant_mask(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  %mask = vector.constant_mask [4, 3] : vector<[4]x4xi1>
+  vector.transfer_write %transpose, %dest[%i, %j], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL-LABEL:   func.func @scalable_transpose_store_constant_mask
+// FULL-UNROLL:           %[[C3:.*]] = arith.constant 3 : index
+// FULL-UNROLL:           %[[C4:.*]] = arith.constant 4 : index
+// FULL-UNROLL:           %[[VSCALE:.*]] = vector.vscale
+// FULL-UNROLL:           %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// FULL-UNROLL:           %[[SLICE_MASK:.*]] = vector.create_mask %[[C3]] : vector<4xi1>
+// FULL-UNROLL:           scf.for %{{.*}} to %[[C4_VSCALE]]
+// FULL-UNROLL:             vector.transfer_write {{.*}}, %[[SLICE_MASK]]
+
+// -----
+
+/// Unsupported transpose.
+func.func @negative_scalable_transpose_store_0(%vec: vector<[4]x4xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  %transpose = vector.transpose %vec, [1, 0] : vector<[4]x4xf32> to vector<4x[4]xf32>
+  vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<4x[4]xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL-LABEL: @negative_scalable_transpose_store_0
+// FULL-UNROLL-NOT:   scf.for
+
+// -----
+
+/// Non-identity permutation map (should be lowered first).
+func.func @negative_scalable_transpose_store_1(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true], permutation_map = affine_map<(d0,d1) -> (d1, d0)> } : vector<[4]x4xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL-LABEL: @negative_scalable_transpose_store_1
+// FULL-UNROLL-NOT:   scf.for
+
+
+// -----
+
+/// Out-of-bounds dim.
+func.func @negative_scalable_transpose_store_2(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [false, true]} : vector<[4]x4xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL-LABEL: @negative_scalable_transpose_store_2
+// FULL-UNROLL-NOT:   scf.for
+
+// -----
+
+/// Source not a vector.transpose.
+func.func @negative_scalable_transpose_store_3(%vec: vector<[4]x4xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+  vector.transfer_write %vec, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>,  memref<?x?xf32>
+  return
+}
+// FULL-UNROLL-LABEL: @negative_scalable_transpose_store_3
+// FULL-UNROLL-NOT:   scf.for

>From ba251ee91035cf2abf89922d1706eb04a656ca32 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 2 Aug 2024 10:24:31 +0000
Subject: [PATCH 2/6] Fixups

---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp  |  2 +-
 .../Conversion/VectorToSCF/vector-to-scf.mlir    | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index c2fb8b6161ec32..140b5d2ba6aecd 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1068,7 +1068,7 @@ struct ScalableTransposeTransferWriteConversion
     auto scalableFlags = vectorType.getScalableDims();
     if (scalableFlags != ArrayRef<bool>{true, false}) {
       return rewriter.notifyMatchFailure(
-          writeOp, "expected vector of form vector<[*]x*xty>");
+          writeOp, "expected vector of the form vector<[N]xMxty>");
     }
 
     auto permutationMap = writeOp.getPermutationMap();
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index f7ded02d83fe5f..aede6624661c69 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -661,10 +661,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
 // CHECK:           memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
 // CHECK:           %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
 // CHECK:           %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
-// CHECK:           scf.for %[[VSCALE:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xi1>>
-// CHECK:             %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VSCALE]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
-// CHECK:             memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xf32>>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
+// CHECK:             %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+// CHECK:             memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
 // CHECK:           }
 // CHECK:           %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
 // CHECK:           return %[[RESULT]] : vector<3x[4]xf32>
@@ -695,10 +695,10 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
 // CHECK:           memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
 // CHECK:           %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
 // CHECK:           %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
-// CHECK:           scf.for %[[VSCALE:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xf32>>
-// CHECK:             %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xi1>>
-// CHECK:             vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VSCALE]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK:             %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK:             %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
+// CHECK:             vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }

>From f040900d3be0678a82b93a212a80885025f11b94 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Aug 2024 16:16:06 +0000
Subject: [PATCH 3/6] Fix UAF

---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 140b5d2ba6aecd..b4f0fa1038afc8 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1125,8 +1125,8 @@ struct ScalableTransposeTransferWriteConversion
           ArrayRef<OpFoldResult>(*maskDims).drop_front());
     }
 
-    ValueRange initLoopArgs =
-        isTensorOp(writeOp) ? writeOp.getSource() : ValueRange{};
+    Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
+    ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
     auto result = rewriter.create<scf::ForOp>(
         loc, lb, ub, step, initLoopArgs,
         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
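
For context on the "Fix UAF" change above: `ValueRange` is a non-owning view, so the original ternary converted the temporary `Value` returned by `writeOp.getSource()` into a `ValueRange` pointing at storage that died at the end of the statement; the fix keeps the `Value` alive in a named local. A standalone sketch of the same pitfall (not from the patch), using `llvm::ArrayRef` as an analogous non-owning view:

```cpp
// Illustrative sketch of the dangling-view bug fixed above, using ArrayRef
// (another non-owning view with an implicit single-element constructor).
#include "llvm/ADT/ArrayRef.h"
#include <cstdio>

static int makeValue() { return 42; }

int main() {
  // Bug pattern: the ternary materializes a temporary int, the view binds to
  // it, and the temporary dies at the end of this statement.
  llvm::ArrayRef<int> dangling = true ? makeValue() : llvm::ArrayRef<int>();
  (void)dangling; // Reading it here would be a use-after-free.

  // Fix pattern (as in the patch): keep the element in a named local that
  // outlives the view.
  int keepAlive = makeValue();
  llvm::ArrayRef<int> safe =
      keepAlive != 0 ? llvm::ArrayRef<int>(keepAlive) : llvm::ArrayRef<int>();
  std::printf("%d\n", safe.front());
  return 0;
}
```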

>From 85bb34336c20c9cd83a33492a4b4e381f127de02 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 8 Aug 2024 10:08:38 +0000
Subject: [PATCH 4/6] Fixups

---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp           | 8 ++++----
 mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir | 4 ++--
 mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir       | 2 --
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index b4f0fa1038afc8..8547556565d594 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1016,7 +1016,7 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
 }
 
 /// Scalable vector lowering of transfer_write(transpose). This lowering only
-/// supports rank 2 (scalable) vectors, but can be used in in conjunction with
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
 /// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
 /// unrolls until the first scalable dimension.
 ///
@@ -1063,9 +1063,9 @@ struct ScalableTransposeTransferWriteConversion
           writeOp, "lowering tensor transfers is disabled");
     }
 
-    auto vector = writeOp.getVector();
-    auto vectorType = vector.getType();
-    auto scalableFlags = vectorType.getScalableDims();
+    Value vector = writeOp.getVector();
+    VectorType vectorType = writeOp.getVectorType();
+    ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
     if (scalableFlags != ArrayRef<bool>{true, false}) {
       return rewriter.notifyMatchFailure(
           writeOp, "expected vector of the form vector<[N]xMxty>");
diff --git a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
index c542b79d5c80e4..6ec74f6b32db94 100644
--- a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -44,8 +44,8 @@ func.func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
 //       CHECK:   %[[WRITE_SLICE:.*]] = vector.transfer_write %{{.*}} %[[ITER_ARG]]
 //       CHECK:   scf.yield %[[WRITE_SLICE]]
 //       CHECK: return %[[RESULT]]
-func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %dest: tensor<?x?xf32>, %i: index, %j: index) -> tensor<?x?xf32> {
+func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %A: tensor<?x?xf32>, %base1: index, %base2: index) -> tensor<?x?xf32> {
   %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-  %result = vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>,  tensor<?x?xf32>
+  %result = vector.transfer_write %transpose, %A[%base1, %base2] {in_bounds = [true, true]} : vector<[4]x4xf32>,  tensor<?x?xf32>
   return %result : tensor<?x?xf32>
 }
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index aede6624661c69..d7620b74089925 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -834,8 +834,6 @@ func.func @scalable_transpose_store_unmasked(%vec: vector<4x[4]xf32>, %dest: mem
 // FULL-UNROLL:             %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
 // FULL-UNROLL:             %[[TRANSPOSE_SLICE:.*]] = vector.from_elements %[[ELEM_0]], %[[ELEM_1]], %[[ELEM_2]], %[[ELEM_3]] : vector<4xf32>
 // FULL-UNROLL:             vector.transfer_write %[[TRANSPOSE_SLICE]], %[[DEST]]{{\[}}%[[SLICE_I]], %[[J]]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
-// FULL-UNROLL:           }
-// FULL-UNROLL:           return
 
 // -----
 

>From d70d7969d77066d4d19ae18565bdc6e66a438659 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 9 Aug 2024 14:54:00 +0000
Subject: [PATCH 5/6] Comments

---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 8547556565d594..f62be1ec5000ee 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1063,8 +1063,10 @@ struct ScalableTransposeTransferWriteConversion
           writeOp, "lowering tensor transfers is disabled");
     }
 
-    Value vector = writeOp.getVector();
     VectorType vectorType = writeOp.getVectorType();
+
+    // Note: By comparing the scalable dims to an ArrayRef of length two this
+    // implicitly checks the rank (is also two).
     ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
     if (scalableFlags != ArrayRef<bool>{true, false}) {
       return rewriter.notifyMatchFailure(
@@ -1077,11 +1079,15 @@ struct ScalableTransposeTransferWriteConversion
           writeOp, "non-identity permutations are unsupported (lower first)");
     }
 
+    // Note: This pattern is only lowering the leading dimension (to a loop),
+    // so we only check if the leading dimension is in bounds. The in-bounds
+    // attribute for the trailing dimension will be propagated.
     if (!writeOp.isDimInBounds(0)) {
       return rewriter.notifyMatchFailure(
           writeOp, "out-of-bounds dims are unsupported (use masking)");
     }
 
+    Value vector = writeOp.getVector();
     auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
     if (!transposeOp ||
         transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {

>From 7b506081b4f7acd20abd997968c3078adc7fe390 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 12 Aug 2024 12:33:34 +0000
Subject: [PATCH 6/6] Add checkLowerTensors helper

---
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 47 ++++++++++---------
 1 file changed, 26 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index f62be1ec5000ee..3a4dc806efe976 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -45,6 +45,18 @@ namespace {
 /// Attribute name used for labeling transfer ops during progressive lowering.
 static const char kPassLabel[] = "__vector_to_scf_lowering__";
 
+/// Return true if this transfer op operates on a source tensor.
+static bool isTensorOp(VectorTransferOpInterface xferOp) {
+  if (isa<RankedTensorType>(xferOp.getShapedType())) {
+    if (isa<vector::TransferWriteOp>(xferOp)) {
+      // TransferWriteOps on tensors have a result.
+      assert(xferOp->getNumResults() > 0);
+    }
+    return true;
+  }
+  return false;
+}
+
 /// Patterns that inherit from this struct have access to
 /// VectorTransferToSCFOptions.
 template <typename OpTy>
@@ -53,6 +65,15 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
                               VectorTransferToSCFOptions opt)
       : OpRewritePattern<OpTy>(context), options(opt) {}
 
+  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
+                                  PatternRewriter &rewriter) const {
+    if (isTensorOp(xferOp) && !options.lowerTensors) {
+      return rewriter.notifyMatchFailure(
+          xferOp, "lowering tensor transfers is disabled");
+    }
+    return success();
+  }
+
   VectorTransferToSCFOptions options;
 };
 
@@ -258,19 +279,6 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
 }
 
-/// Return true if this transfer op operates on a source tensor.
-template <typename OpTy>
-static bool isTensorOp(OpTy xferOp) {
-  if (isa<RankedTensorType>(xferOp.getShapedType())) {
-    if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
-      // TransferWriteOps on tensors have a result.
-      assert(xferOp->getNumResults() > 0);
-    }
-    return true;
-  }
-  return false;
-}
-
 namespace lowering_n_d {
 
 /// Helper data structure for data and mask buffers.
@@ -1058,10 +1066,8 @@ struct ScalableTransposeTransferWriteConversion
 
   LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
-    if (isTensorOp(writeOp) && !options.lowerTensors) {
-      return rewriter.notifyMatchFailure(
-          writeOp, "lowering tensor transfers is disabled");
-    }
+    if (failed(checkLowerTensors(writeOp, rewriter)))
+      return failure();
 
     VectorType vectorType = writeOp.getVectorType();
 
@@ -1286,9 +1292,8 @@ struct UnrollTransferReadConversion
     if (xferOp.getVectorType().getRank() <= options.targetRank)
       return rewriter.notifyMatchFailure(
           xferOp, "vector rank is less or equal to target rank");
-    if (isTensorOp(xferOp) && !options.lowerTensors)
-      return rewriter.notifyMatchFailure(
-          xferOp, "transfers operating on tensors are excluded");
+    if (failed(checkLowerTensors(xferOp, rewriter)))
+      return failure();
     // Transfer ops that modify the element type are not supported atm.
     if (xferOp.getVectorType().getElementType() !=
         xferOp.getShapedType().getElementType())
@@ -1424,7 +1429,7 @@ struct UnrollTransferWriteConversion
     if (inputVectorTy.getRank() <= options.targetRank)
       return failure();
 
-    if (isTensorOp(xferOp) && !options.lowerTensors)
+    if (failed(checkLowerTensors(xferOp, rewriter)))
       return failure();
     // Transfer ops that modify the element type are not supported atm.
     if (inputVectorTy.getElementType() !=


