[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (PR #80170)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Feb 2 02:26:53 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80170

>From 84d68438dd373662ca66b6bbe62d12c80786db53 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 31 Jan 2024 17:33:38 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add rewrite to lift illegal
 vector.transposes to memory

When unrolling the reduction dimension of something like a matmul for
SME, you can end up with transposed reads of illegal types, like so:

```mlir
%illegalRead = vector.transfer_read %memref[%a, %b]
                : memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
                : vector<[8]x4xf32> to vector<4x[8]xf32>
```

Here `vector<[8]x4xf32>` is an illegal type: there's no way to
lower a scalable vector of fixed vectors. However, as the final type
`vector<4x[8]xf32>` is legal, we can instead lift the transpose to
memory (producing a strided memref), and eliminate all the illegal
types. This is shown below.

```mlir
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
                : memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
                : memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
                : memref<?x?xf32>, vector<4x[8]xf32>
```
---
 .../ArmSME/Transforms/VectorLegalization.cpp  | 144 +++++++++++++++++-
 .../Dialect/ArmSME/vector-legalization.mlir   |  75 +++++++++
 2 files changed, 218 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 14b9d8e34da65..e88f82c92eba7 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -415,6 +416,172 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
   }
 };
 
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a scalable dimension followed by a fixed (trailing)
+/// dimension, e.g. vector<[8]x4xf32>, which has no valid lowering.
+///
+/// The memref.transpose is a metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// Only reads from memrefs with an identity permutation map are handled
+/// (memref.subview cannot be applied to tensors). An arith.extsi, extui, or
+/// extf between the read and the transpose is looked through and re-applied
+/// to the new (legal) read.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %illegalRead = vector.transfer_read %memref[%a, %b]
+///                  : memref<?x?xf32>, vector<[8]x4xf32>
+///  %legalType = vector.transpose %illegalRead, [1, 0]
+///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+///                  : memref<?x?xf32> to memref<?x?xf32>
+///  %legalType = vector.transfer_read %transpose[%c0, %c0]
+///                  : memref<?x?xf32>, vector<4x[8]xf32>
+///  ```
+struct LiftIllegalVectorTransposeToMemory
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  /// Returns true if some scalable dimension of `vType` is outer to (comes
+  /// before) a fixed-size dimension, e.g. vector<[8]x4xf32>.
+  static bool isIllegalVectorType(VectorType vType) {
+    // Walk from the innermost dimension outwards; once a fixed dimension has
+    // been seen, any further scalable dimension makes the type illegal.
+    bool seenFixedDim = false;
+    for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+      seenFixedDim |= !scalableFlag;
+      if (seenFixedDim && scalableFlag)
+        return true;
+    }
+    return false;
+  }
+
+  /// Returns the input of `op` if it is an arith.extsi/extui/extf, otherwise a
+  /// null Value. `op` may be null (e.g. for a value with no defining op).
+  static Value getExtensionSource(Operation *op) {
+    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
+      return op->getOperand(0);
+    return {};
+  }
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = transposeOp.getSourceVectorType();
+    auto resultType = transposeOp.getResultVectorType();
+    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "expected transpose from illegal type to legal type");
+
+    // Look through extend for transfer_read. Note: `transposeSourceOp` is null
+    // if the transpose source is a block argument.
+    Value maybeRead = transposeOp.getVector();
+    auto *transposeSourceOp = maybeRead.getDefiningOp();
+    Operation *extendOp = nullptr;
+    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+      maybeRead = extendSource;
+      extendOp = transposeSourceOp;
+    }
+
+    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+    if (!illegalRead)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "expected source to be (possibly extended) transfer_read");
+
+    // memref.subview/memref.transpose only apply to memrefs, but a
+    // transfer_read may also read from a tensor.
+    if (!isa<MemRefType>(illegalRead.getSource().getType()))
+      return rewriter.notifyMatchFailure(illegalRead,
+                                         "expected read to be from a memref");
+
+    if (!illegalRead.getPermutationMap().isIdentity())
+      return rewriter.notifyMatchFailure(
+          illegalRead, "expected read to have identity permutation map");
+
+    // All match checks have passed; IR is only created from this point on.
+    auto loc = transposeOp.getLoc();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+    // Create a subview that matches the size of the illegal read vector type.
+    // NOTE(review): The subview always has the full vector size, so the new
+    // read is bounds-checked against the subview rather than the source
+    // memref. A read with out-of-bounds dims may therefore access memory past
+    // the end of the source instead of padding -- confirm this cannot happen.
+    auto readType = illegalRead.getVectorType();
+    auto readSizes = llvm::map_to_vector(
+        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
+        [&](auto dim) -> Value {
+          auto [size, isScalable] = dim;
+          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+          if (!isScalable)
+            return dimSize;
+          // A scalable dim has `size * vscale` elements.
+          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+        });
+    SmallVector<Value> strides(readType.getRank(), Value(one));
+    auto readSubview = rewriter.create<memref::SubViewOp>(
+        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+        strides);
+
+    // Apply the transpose to all values/attributes of the transfer_read:
+    // - The mask
+    Value mask = illegalRead.getMask();
+    if (mask) {
+      // Note: The transpose for the mask should fold into the
+      // vector.create_mask/constant_mask op, which will then become legal.
+      mask = rewriter.create<vector::TransposeOp>(loc, mask,
+                                                  transposeOp.getPermutation());
+    }
+    // - The source memref
+    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
+        transposeOp.getPermutation(), getContext());
+    auto transposedSubview = rewriter.create<memref::TransposeOp>(
+        loc, readSubview, AffineMapAttr::get(transposeMap));
+    // - The `in_bounds` attribute
+    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+    if (inBoundsAttr) {
+      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
+                                            inBoundsAttr.end());
+      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
+      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+    }
+
+    // The new read produces the un-extended element type; any extension is
+    // re-applied when the transpose is replaced below.
+    VectorType legalReadType = resultType.clone(readType.getElementType());
+    // Note: The indices are all zero as the subview is already offset.
+    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
+    auto legalRead = rewriter.create<vector::TransferReadOp>(
+        loc, legalReadType, transposedSubview, readIndices,
+        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
+        inBoundsAttr);
+
+    // Replace the transpose with the new read, extending the result if
+    // necessary.
+    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
+      if (extendOp)
+        return rewriter.create(loc, extendOp->getName().getIdentifier(),
+                               Value(legalRead), resultType);
+      return legalRead;
+    }());
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
@@ -434,7 +575,8 @@ struct VectorLegalizationPass
           return success();
         });
 
-    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks>(context);
+    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
+                 LiftIllegalVectorTransposeToMemory>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a2526db9b4831..d3104c549229b 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -302,3 +302,78 @@ func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: ind
   %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
   return %extract : vector<[4]x[4]xi1>
 }
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
+// CHECK-SAME:                                            %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                            %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%[[C0]], %[[C0]]], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+  // CHECK-NEXT: return %[[LEGAL_READ]]
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad, %mask : memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
+// CHECK-SAME:                                                     %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                     %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                     %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
+  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+  // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
+  // CHECK-NEXT: return %[[EXT_TYPE]]
+  %pad = arith.constant 0 : i8
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xi8>, vector<[8]x4xi8>
+  %extRead = arith.extsi %illegalRead : vector<[8]x4xi8> to vector<[8]x4xi32>
+  %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
+  return %legalType : vector<4x[8]xi32>
+}

>From 08dd0e49545a25fb0087a2232b6b05b73925aec9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 1 Feb 2024 10:49:51 +0000
Subject: [PATCH 2/2] Fixups

---
 .../Dialect/ArmSME/vector-legalization.mlir   | 68 +++++++++----------
 1 file changed, 34 insertions(+), 34 deletions(-)

diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d3104c549229b..11888c675f0b0 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -305,11 +305,11 @@ func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: ind
 
 // -----
 
-// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
-// CHECK-SAME:                                            %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME:                                            %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME:                                            %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
-func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
@@ -328,23 +328,17 @@ func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memre
 
 // -----
 
-// CHECK-LABEL: @lift_illegal_transpose_to_memory(
-// CHECK-SAME:                                    %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[DIM0:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[DIM1:[a-z0-9]+]]: index,
-// CHECK-SAME:                                    %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
-func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
-  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
-  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
-  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  // CHECK-NEXT: %[[LEGAL_READ:.*]]  = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_mask(
+// CHECK-SAME:                                              %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                              %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                              %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
+func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
+  // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
+  // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
+  // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
+  // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+  // CHECK:     %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
+  // CHECK-SAME:                       %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
   // CHECK-NEXT: return %[[LEGAL_READ]]
   %pad = arith.constant 0.0 : f32
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
@@ -356,19 +350,12 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index,
 // -----
 
 // CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
-// CHECK-SAME:                                                     %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME:                                                     %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME:                                                     %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+// CHECK-SAME:                                                     %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
 func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
-  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
-  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
-  // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
-  // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
-  // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+  // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
+  // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
+  // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
+  // CHECK:     %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
   // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
   // CHECK-NEXT: return %[[EXT_TYPE]]
   %pad = arith.constant 0 : i8
@@ -377,3 +364,16 @@ func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: inde
   %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
   return %legalType : vector<4x[8]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_in_bounds_attr
+func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+  // CHECK: vector.transfer_read
+  // CHECK-SAME: in_bounds = [true, false]
+  // CHECK-NOT: in_bounds = [false, true]
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[8]x4xf32>
+  %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+  return %legalType : vector<4x[8]xf32>
+}



More information about the Mlir-commits mailing list