[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (PR #80170)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Feb 1 02:51:09 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80170

>From b92efad93455ac13f0a66ab8ac85802ed6a84372 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 31 Jan 2024 17:33:38 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add rewrite to lift illegal
 vector.transposes to memory

When unrolling the reduction dimension of something like a matmul for
SME, you can end up with transposed reads of illegal types, like so:

```mlir
%illegalRead = vector.transfer_read %memref[%a, %b]
                : memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
                : vector<[8]x4xf32> to vector<4x[8]xf32>
```

Here the `vector<[8]x4xf32>` is an illegal type: there's no way to
lower a scalable vector of fixed vectors. However, as the final type
`vector<4x[8]xf32>` is legal, we can instead lift the transpose to
memory (producing a strided memref), and eliminate all the illegal
types. This is shown below.

```mlir
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
                : memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
                : memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
                : memref<?x?xf32>, vector<4x[8]xf32>
```
---
 .../ArmSME/Transforms/VectorLegalization.cpp  | 163 +++++++++++++++++-
 .../Dialect/ArmSME/vector-legalization.mlir   |  75 ++++++++
 2 files changed, 235 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..a3db2d2395528 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,8 +7,6 @@
 //===----------------------------------------------------------------------===//
 //
 // This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
 //
 // Note: In the context of this pass 'tile' always refers to an SME tile.
 //
@@ -19,6 +17,7 @@
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
 // Common match failure reasons.
 static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
     "op vector size is not multiple of SME tiles");
@@ -338,13 +341,166 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a leading scalable dimension and a fixed trailing
+/// dimension, which has no valid lowering.
+///
+/// The memref.transpose is a metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %illegalRead = vector.transfer_read %memref[%a, %b]
+///                  : memref<?x?xf32>, vector<[8]x4xf32>
+///  %legalType = vector.transpose %illegalRead, [1, 0]
+///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %legalType = vector.transfer_read %transpose[%c0, %c0]
+///                  : memref<?x?xf32>, vector<4x[8]xf32>
+///  ```
+struct LiftIllegalVectorTransposeToMemory
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  static bool isIllegalVectorType(VectorType vType) {
+    bool seenFixedDim = false;
+    for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+      seenFixedDim |= !scalableFlag;
+      if (seenFixedDim && scalableFlag)
+        return true;
+    }
+    return false;
+  }
+
+  static Value getExtensionSource(Operation *op) {
+    if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
+      return signExtend.getIn();
+    if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
+      return zeroExtend.getIn();
+    if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
+      return floatExtend.getIn();
+    return {};
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "expected transpose from illegal type to legal type");
+
+    Value maybeRead = transposeOp.getVector();
+    auto *transposeSourceOp = maybeRead.getDefiningOp();
+    Operation *extendOp = nullptr;
+    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+      maybeRead = extendSource;
+      extendOp = transposeSourceOp;
+    }
+
+    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+    if (!illegalRead)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "expected source to be (possibility extended) transfer_read");
+
+    if (!illegalRead.getPermutationMap().isIdentity())
+      return rewriter.notifyMatchFailure(
+          illegalRead, "expected read to have identity permutation map");
+
+    auto loc = transposeOp.getLoc();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Create a subview that matches the size of the illegal read vector type.
+    auto readType = illegalRead.getVectorType();
+    auto readSizes = llvm::map_to_vector(
+        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
+        [&](auto dim) -> Value {
+          auto [size, isScalable] = dim;
+          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+          if (!isScalable)
+            return dimSize;
+          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+        });
+    SmallVector<Value> strides(readType.getRank(), Value(one));
+    auto readSubview = rewriter.create<memref::SubViewOp>(
+        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+        strides);
+
+    // Apply the transpose to all values/attributes of the transfer_read.
+    // The mask.
+    Value mask = illegalRead.getMask();
+    if (mask) {
+      // Note: The transpose for the mask should fold into the
+      // vector.create_mask/constant_mask op, which will then become legal.
+      mask = rewriter.create<vector::TransposeOp>(loc, mask,
+                                                  transposeOp.getPermutation());
+    }
+    // The source memref.
+    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
+        transposeOp.getPermutation(), getContext());
+    auto transposedSubview = rewriter.create<memref::TransposeOp>(
+        loc, readSubview, AffineMapAttr::get(transposeMap));
+    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+    // The `in_bounds` attribute.
+    if (inBoundsAttr) {
+      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
+                                            inBoundsAttr.end());
+      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
+      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+    }
+
+    VectorType legalReadType =
+        VectorType::Builder(resultType)
+            .setElementType(illegalRead.getVectorType().getElementType());
+    // Note: The indices are all zero as the subview is already offset.
+    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
+    Value legalRead = rewriter.create<vector::TransferReadOp>(
+        loc, legalReadType, transposedSubview, readIndices,
+        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
+        inBoundsAttr);
+
+    // Replace the transpose with the new read, extending the result if
+    // necessary.
+    rewriter.replaceOp(transposeOp, [&]() -> Value {
+      if (!extendOp)
+        return legalRead;
+      if (isa<arith::ExtSIOp>(extendOp))
+        return rewriter.create<arith::ExtSIOp>(loc, resultType, legalRead);
+      if (isa<arith::ExtUIOp>(extendOp))
+        return rewriter.create<arith::ExtUIOp>(loc, resultType, legalRead);
+      if (isa<arith::ExtFOp>(extendOp))
+        return rewriter.create<arith::ExtFOp>(loc, resultType, legalRead);
+      return legalRead;
+    }());
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
     OneToNTypeConverter converter;
     RewritePatternSet patterns(context);
-
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
         [](VectorType vectorType,
@@ -358,6 +514,7 @@ struct VectorLegalizationPass
           return success();
         });
 
+    patterns.add<LiftIllegalVectorTransposeToMemory>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..2317930d3d061 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,78 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
   vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
+// CHECK-SAME:                                            %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad, %mask : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
+// CHECK-SAME:                                                     %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                     %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                     %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+  // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
+  // CHECK-NEXT: return %[[EXT_TYPE]]
+  %pad = arith.constant 0 : i8
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xi8>, vector<[8]x4xi8>
+  %extRead = arith.extsi %illegalRead : vector<[8]x4xi8> to vector<[8]x4xi32>
+  %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
+  return %legalType : vector<4x[8]xi32>
+}

>From ab8c3d740e8fd04b998081be66e13bda81f7a758 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 1 Feb 2024 10:49:51 +0000
Subject: [PATCH 2/2] Fixups

---
 .../ArmSME/Transforms/VectorLegalization.cpp  | 34 ++++------
 .../Dialect/ArmSME/vector-legalization.mlir   | 68 +++++++++----------
 2 files changed, 47 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index a3db2d2395528..f960578e69ec6 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -388,12 +388,8 @@ struct LiftIllegalVectorTransposeToMemory
   }
 
   static Value getExtensionSource(Operation *op) {
-    if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
-      return signExtend.getIn();
-    if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
-      return zeroExtend.getIn();
-    if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
-      return floatExtend.getIn();
+    if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
+      return op->getOperand(0);
     return {};
   }
 
@@ -405,6 +401,7 @@ struct LiftIllegalVectorTransposeToMemory
       return rewriter.notifyMatchFailure(
           transposeOp, "expected transpose from illegal type to legal type");
 
+    // Look through extend for transfer_read.
     Value maybeRead = transposeOp.getVector();
     auto *transposeSourceOp = maybeRead.getDefiningOp();
     Operation *extendOp = nullptr;
@@ -417,7 +414,7 @@ struct LiftIllegalVectorTransposeToMemory
     if (!illegalRead)
       return rewriter.notifyMatchFailure(
           transposeOp,
-          "expected source to be (possibility extended) transfer_read");
+          "expected source to be (possibly extended) transfer_read");
 
     if (!illegalRead.getPermutationMap().isIdentity())
       return rewriter.notifyMatchFailure(
@@ -444,8 +441,8 @@ struct LiftIllegalVectorTransposeToMemory
         loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
         strides);
 
-    // Apply the transpose to all values/attributes of the transfer_read.
-    // The mask.
+    // Apply the transpose to all values/attributes of the transfer_read:
+    // - The mask
     Value mask = illegalRead.getMask();
     if (mask) {
       // Note: The transpose for the mask should fold into the
@@ -453,13 +450,13 @@ struct LiftIllegalVectorTransposeToMemory
       mask = rewriter.create<vector::TransposeOp>(loc, mask,
                                                   transposeOp.getPermutation());
     }
-    // The source memref.
+    // - The source memref
     mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
         transposeOp.getPermutation(), getContext());
     auto transposedSubview = rewriter.create<memref::TransposeOp>(
         loc, readSubview, AffineMapAttr::get(transposeMap));
     ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
-    // The `in_bounds` attribute.
+    // - The `in_bounds` attribute
     if (inBoundsAttr) {
       SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                             inBoundsAttr.end());
@@ -472,22 +469,17 @@ struct LiftIllegalVectorTransposeToMemory
             .setElementType(illegalRead.getVectorType().getElementType());
     // Note: The indices are all zero as the subview is already offset.
     SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
-    Value legalRead = rewriter.create<vector::TransferReadOp>(
+    auto legalRead = rewriter.create<vector::TransferReadOp>(
         loc, legalReadType, transposedSubview, readIndices,
         illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
         inBoundsAttr);
 
     // Replace the transpose with the new read, extending the result if
     // necessary.
-    rewriter.replaceOp(transposeOp, [&]() -> Value {
-      if (!extendOp)
-        return legalRead;
-      if (isa<arith::ExtSIOp>(extendOp))
-        return rewriter.create<arith::ExtSIOp>(loc, resultType, legalRead);
-      if (isa<arith::ExtUIOp>(extendOp))
-        return rewriter.create<arith::ExtUIOp>(loc, resultType, legalRead);
-      if (isa<arith::ExtFOp>(extendOp))
-        return rewriter.create<arith::ExtFOp>(loc, resultType, legalRead);
+    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
+      if (extendOp)
+        return rewriter.create(loc, extendOp->getName().getIdentifier(),
+                               Value(legalRead), resultType);
       return legalRead;
     }());
 
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 2317930d3d061..31ec7c5e97232 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -269,11 +269,11 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
 
 // -----
 
-// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
-// CHECK-SAME:                                            %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME:                                            %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME:                                            %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
-func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
@@ -292,23 +292,17 @@ func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memre
 
 // -----
 
-// CHECK-LABEL: @lift_illegal_transpose_to_memory(
-// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[DIM0:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[DIM1:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
-func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
-  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
-  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
-  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_mask(
+// CHECK-SAME:                                              %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                              %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                              %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
+func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
+  // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
+  // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
+  // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+  // CHECK:     %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
+  // CHECK-SAME:                       %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
   // CHECK-NEXT: return %[[LEGAL_READ]]
   %pad = arith.constant 0.0 : f32
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
@@ -320,19 +314,12 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index,
 // -----
 
 // CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
-// CHECK-SAME:                                                     %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME:                                                     %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME:                                                     %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+// CHECK-SAME:                                                     %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
 func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
-  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
-  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
-  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
-  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+  // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
+  // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
+  // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
+  // CHECK:     %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
   // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
   // CHECK-NEXT: return %[[EXT_TYPE]]
   %pad = arith.constant 0 : i8
@@ -341,3 +328,16 @@ func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: inde
   %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
   return %legalType : vector<4x[8]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_in_bounds_attr
+func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK: vector.transfer_read
+  // CHECK-SAME: in_bounds = [true, false]
+  // CHECK-NOT: in_bounds = [false, true]
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}



More information about the Mlir-commits mailing list