[Mlir-commits] [mlir] [mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (PR #80170)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Feb 1 02:51:09 PST 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80170
>From b92efad93455ac13f0a66ab8ac85802ed6a84372 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 31 Jan 2024 17:33:38 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add rewrite to lift illegal
vector.transposes to memory
When unrolling the reduction dimension of something like a matmul for
SME, you can end up with transposed reads of illegal types, like so:
```mlir
%illegalRead = vector.transfer_read %memref[%a, %b]
: memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
: vector<[8]x4xf32> to vector<4x[8]xf32>
```
Here the `vector<[8]x4xf32>` is an illegal type: there is no way to
lower a scalable vector of fixed vectors. However, as the final type
`vector<4x[8]xf32>` is legal, we can instead lift the transpose to
memory (producing a strided memref), and eliminate all the illegal
types. This is shown below.
```mlir
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
: memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
: memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
: memref<?x?xf32>, vector<4x[8]xf32>
```
---
.../ArmSME/Transforms/VectorLegalization.cpp | 163 +++++++++++++++++-
.../Dialect/ArmSME/vector-legalization.mlir | 75 ++++++++
2 files changed, 235 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..a3db2d2395528 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,8 +7,6 @@
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
@@ -19,6 +17,7 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
namespace {
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
// Common match failure reasons.
static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
"op vector size is not multiple of SME tiles");
@@ -338,13 +341,166 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Lifts an illegal vector.transpose and vector.transfer_read to a
+/// memref.subview + memref.transpose, followed by a legal read.
+///
+/// 'Illegal' here means a leading scalable dimension and a fixed trailing
+/// dimension, which has no valid lowering.
+///
+/// The memref.transpose is a metadata-only transpose that produces a strided
+/// memref, which eventually becomes a loop reading individual elements.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %illegalRead = vector.transfer_read %memref[%a, %b]
+/// : memref<?x?xf32>, vector<[8]x4xf32>
+/// %legalType = vector.transpose %illegalRead, [1, 0]
+/// : vector<[8]x4xf32> to vector<4x[8]xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
+/// : memref<?x?xf32> to memref<?x?xf32>
+/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
+/// : memref<?x?xf32> to memref<?x?xf32>
+/// %legalType = vector.transfer_read %transpose[%c0, %c0]
+/// : memref<?x?xf32>, vector<4x[8]xf32>
+/// ```
+struct LiftIllegalVectorTransposeToMemory
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+ static bool isIllegalVectorType(VectorType vType) {
+ bool seenFixedDim = false;
+ for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+ seenFixedDim |= !scalableFlag;
+ if (seenFixedDim && scalableFlag)
+ return true;
+ }
+ return false;
+ }
+
+ static Value getExtensionSource(Operation *op) {
+ if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
+ return signExtend.getIn();
+ if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
+ return zeroExtend.getIn();
+ if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
+ return floatExtend.getIn();
+ return {};
+ }
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = transposeOp.getSourceVectorType();
+ auto resultType = transposeOp.getResultVectorType();
+ if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(
+ transposeOp, "expected transpose from illegal type to legal type");
+
+ Value maybeRead = transposeOp.getVector();
+ auto *transposeSourceOp = maybeRead.getDefiningOp();
+ Operation *extendOp = nullptr;
+ if (Value extendSource = getExtensionSource(transposeSourceOp)) {
+ maybeRead = extendSource;
+ extendOp = transposeSourceOp;
+ }
+
+ auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
+ if (!illegalRead)
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "expected source to be (possibility extended) transfer_read");
+
+ if (!illegalRead.getPermutationMap().isIdentity())
+ return rewriter.notifyMatchFailure(
+ illegalRead, "expected read to have identity permutation map");
+
+ auto loc = transposeOp.getLoc();
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+ // Create a subview that matches the size of the illegal read vector type.
+ auto readType = illegalRead.getVectorType();
+ auto readSizes = llvm::map_to_vector(
+ llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
+ [&](auto dim) -> Value {
+ auto [size, isScalable] = dim;
+ auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
+ if (!isScalable)
+ return dimSize;
+ auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
+ });
+ SmallVector<Value> strides(readType.getRank(), Value(one));
+ auto readSubview = rewriter.create<memref::SubViewOp>(
+ loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+ strides);
+
+ // Apply the transpose to all values/attributes of the transfer_read.
+ // The mask.
+ Value mask = illegalRead.getMask();
+ if (mask) {
+ // Note: The transpose for the mask should fold into the
+ // vector.create_mask/constant_mask op, which will then become legal.
+ mask = rewriter.create<vector::TransposeOp>(loc, mask,
+ transposeOp.getPermutation());
+ }
+ // The source memref.
+ mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
+ transposeOp.getPermutation(), getContext());
+ auto transposedSubview = rewriter.create<memref::TransposeOp>(
+ loc, readSubview, AffineMapAttr::get(transposeMap));
+ ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+ // The `in_bounds` attribute.
+ if (inBoundsAttr) {
+ SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
+ inBoundsAttr.end());
+ applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
+ inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+ }
+
+ VectorType legalReadType =
+ VectorType::Builder(resultType)
+ .setElementType(illegalRead.getVectorType().getElementType());
+ // Note: The indices are all zero as the subview is already offset.
+ SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
+ Value legalRead = rewriter.create<vector::TransferReadOp>(
+ loc, legalReadType, transposedSubview, readIndices,
+ illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
+ inBoundsAttr);
+
+ // Replace the transpose with the new read, extending the result if
+ // necessary.
+ rewriter.replaceOp(transposeOp, [&]() -> Value {
+ if (!extendOp)
+ return legalRead;
+ if (isa<arith::ExtSIOp>(extendOp))
+ return rewriter.create<arith::ExtSIOp>(loc, resultType, legalRead);
+ if (isa<arith::ExtUIOp>(extendOp))
+ return rewriter.create<arith::ExtUIOp>(loc, resultType, legalRead);
+ if (isa<arith::ExtFOp>(extendOp))
+ return rewriter.create<arith::ExtFOp>(loc, resultType, legalRead);
+ return legalRead;
+ }());
+
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
OneToNTypeConverter converter;
RewritePatternSet patterns(context);
-
converter.addConversion([](Type type) { return type; });
converter.addConversion(
[](VectorType vectorType,
@@ -358,6 +514,7 @@ struct VectorLegalizationPass
return success();
});
+ patterns.add<LiftIllegalVectorTransposeToMemory>(context);
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..2317930d3d061 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,78 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
+// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+ // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+ // CHECK-NEXT: return %[[LEGAL_READ]]
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[8]x4xf32>
+ %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+ return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
+ // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+ // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+ // CHECK-NEXT: return %[[LEGAL_READ]]
+ %pad = arith.constant 0.0 : f32
+ %mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad, %mask : memref<?x?xf32>, vector<[8]x4xf32>
+ %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+ return %legalType : vector<4x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
+// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
+ // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
+ // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+ // CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
+ // CHECK-NEXT: return %[[EXT_TYPE]]
+ %pad = arith.constant 0 : i8
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xi8>, vector<[8]x4xi8>
+ %extRead = arith.extsi %illegalRead : vector<[8]x4xi8> to vector<[8]x4xi32>
+ %legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
+ return %legalType : vector<4x[8]xi32>
+}
>From ab8c3d740e8fd04b998081be66e13bda81f7a758 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 1 Feb 2024 10:49:51 +0000
Subject: [PATCH 2/2] Fixups
---
.../ArmSME/Transforms/VectorLegalization.cpp | 34 ++++------
.../Dialect/ArmSME/vector-legalization.mlir | 68 +++++++++----------
2 files changed, 47 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index a3db2d2395528..f960578e69ec6 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -388,12 +388,8 @@ struct LiftIllegalVectorTransposeToMemory
}
static Value getExtensionSource(Operation *op) {
- if (auto signExtend = dyn_cast<arith::ExtSIOp>(op))
- return signExtend.getIn();
- if (auto zeroExtend = dyn_cast<arith::ExtUIOp>(op))
- return zeroExtend.getIn();
- if (auto floatExtend = dyn_cast<arith::ExtFOp>(op))
- return floatExtend.getIn();
+ if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
+ return op->getOperand(0);
return {};
}
@@ -405,6 +401,7 @@ struct LiftIllegalVectorTransposeToMemory
return rewriter.notifyMatchFailure(
transposeOp, "expected transpose from illegal type to legal type");
+ // Look through extend for transfer_read.
Value maybeRead = transposeOp.getVector();
auto *transposeSourceOp = maybeRead.getDefiningOp();
Operation *extendOp = nullptr;
@@ -417,7 +414,7 @@ struct LiftIllegalVectorTransposeToMemory
if (!illegalRead)
return rewriter.notifyMatchFailure(
transposeOp,
- "expected source to be (possibility extended) transfer_read");
+ "expected source to be (possibly extended) transfer_read");
if (!illegalRead.getPermutationMap().isIdentity())
return rewriter.notifyMatchFailure(
@@ -444,8 +441,8 @@ struct LiftIllegalVectorTransposeToMemory
loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
strides);
- // Apply the transpose to all values/attributes of the transfer_read.
- // The mask.
+ // Apply the transpose to all values/attributes of the transfer_read:
+ // - The mask
Value mask = illegalRead.getMask();
if (mask) {
// Note: The transpose for the mask should fold into the
@@ -453,13 +450,13 @@ struct LiftIllegalVectorTransposeToMemory
mask = rewriter.create<vector::TransposeOp>(loc, mask,
transposeOp.getPermutation());
}
- // The source memref.
+ // - The source memref
mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
transposeOp.getPermutation(), getContext());
auto transposedSubview = rewriter.create<memref::TransposeOp>(
loc, readSubview, AffineMapAttr::get(transposeMap));
ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
- // The `in_bounds` attribute.
+ // - The `in_bounds` attribute
if (inBoundsAttr) {
SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
inBoundsAttr.end());
@@ -472,22 +469,17 @@ struct LiftIllegalVectorTransposeToMemory
.setElementType(illegalRead.getVectorType().getElementType());
// Note: The indices are all zero as the subview is already offset.
SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
- Value legalRead = rewriter.create<vector::TransferReadOp>(
+ auto legalRead = rewriter.create<vector::TransferReadOp>(
loc, legalReadType, transposedSubview, readIndices,
illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
inBoundsAttr);
// Replace the transpose with the new read, extending the result if
// necessary.
- rewriter.replaceOp(transposeOp, [&]() -> Value {
- if (!extendOp)
- return legalRead;
- if (isa<arith::ExtSIOp>(extendOp))
- return rewriter.create<arith::ExtSIOp>(loc, resultType, legalRead);
- if (isa<arith::ExtUIOp>(extendOp))
- return rewriter.create<arith::ExtUIOp>(loc, resultType, legalRead);
- if (isa<arith::ExtFOp>(extendOp))
- return rewriter.create<arith::ExtFOp>(loc, resultType, legalRead);
+ rewriter.replaceOp(transposeOp, [&]() -> Operation * {
+ if (extendOp)
+ return rewriter.create(loc, extendOp->getName().getIdentifier(),
+ Value(legalRead), resultType);
return legalRead;
}());
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 2317930d3d061..31ec7c5e97232 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -269,11 +269,11 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
// -----
-// CHECK-LABEL: @lift_illegal_transpose_to_memory_no_mask(
-// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
-func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+// CHECK-LABEL: @lift_illegal_transpose_to_memory(
+// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
+func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
@@ -292,23 +292,17 @@ func.func @lift_illegal_transpose_to_memory_no_mask(%a: index, %b: index, %memre
// -----
-// CHECK-LABEL: @lift_illegal_transpose_to_memory(
-// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>)
-func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index, %dim1: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
- // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
- // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
- // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
- // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xf32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
- // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
- // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_F32]], %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_mask(
+// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xf32>
+func.func @lift_illegal_transpose_to_memory_with_mask(%dim0: index, %dim1: index, %memref: memref<?x?xf32>, %a: index, %b: index) -> vector<4x[8]xf32> {
+ // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
+ // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
+ // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
+ // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM1]], %[[DIM0]] : vector<4x[8]xi1>
+ // CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
+ // CHECK-SAME: %[[MASK]] : memref<?x?xf32, strided<[?, ?], offset: ?>>, vector<4x[8]xf32>
// CHECK-NEXT: return %[[LEGAL_READ]]
%pad = arith.constant 0.0 : f32
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x4xi1>
@@ -320,19 +314,12 @@ func.func @lift_illegal_transpose_to_memory(%a: index, %b: index, %dim0: index,
// -----
// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_arith_extop(
-// CHECK-SAME: %[[INDEXA:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[INDEXB:[a-z0-9]+]]: index,
-// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>)
+// CHECK-SAME: %[[MEMREF:[a-z0-9]+]]: memref<?x?xi8>
func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: index, %memref: memref<?x?xi8>) -> vector<4x[8]xi32> {
- // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
- // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
- // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
- // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
- // CHECK-NEXT: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]][%[[INDEXA]], %[[INDEXB]]] [%[[C8_VSCALE]], 4] [1, 1] : memref<?x?xi8> to memref<?x4xi8, strided<[?, 1], offset: ?>>
- // CHECK-NEXT: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]] : memref<?x4xi8, strided<[?, 1], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
- // CHECK-NEXT: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]] (d0, d1) -> (d1, d0) : memref<?x?xi8, strided<[?, ?], offset: ?>> to memref<?x?xi8, strided<[?, ?], offset: ?>>
- // CHECK-NEXT: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]][%c0, %c0], %[[C0_I8]] : memref<?x?xi8, strided<[?, ?], offset: ?>>, vector<4x[8]xi8>
+ // CHECK-DAG: %[[READ_SUBVIEW:.*]] = memref.subview %[[MEMREF]]
+ // CHECK-DAG: %[[CAST:.*]] = memref.cast %[[READ_SUBVIEW]]
+ // CHECK-DAG: %[[TRANSPOSE:.*]] = memref.transpose %[[CAST]]
+ // CHECK: %[[LEGAL_READ:.*]] = vector.transfer_read %[[TRANSPOSE]]
// CHECK-NEXT: %[[EXT_TYPE:.*]] = arith.extsi %[[LEGAL_READ]] : vector<4x[8]xi8> to vector<4x[8]xi32>
// CHECK-NEXT: return %[[EXT_TYPE]]
%pad = arith.constant 0 : i8
@@ -341,3 +328,16 @@ func.func @lift_illegal_transpose_to_memory_with_arith_extop(%a: index, %b: inde
%legalType = vector.transpose %extRead, [1, 0] : vector<[8]x4xi32> to vector<4x[8]xi32>
return %legalType : vector<4x[8]xi32>
}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_transpose_to_memory_with_in_bounds_attr
+func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<4x[8]xf32> {
+ // CHECK: vector.transfer_read
+ // CHECK-SAME: in_bounds = [true, false]
+ // CHECK-NOT: in_bounds = [false, true]
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[8]x4xf32>
+ %legalType = vector.transpose %illegalRead, [1, 0] : vector<[8]x4xf32> to vector<4x[8]xf32>
+ return %legalType : vector<4x[8]xf32>
+}
More information about the Mlir-commits
mailing list