[Mlir-commits] [mlir] [mlir][vector] Add scalable lowering for `transfer_write(transpose)` (PR #101353)
llvmlistbot at llvm.org
Wed Jul 31 08:52:21 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This specifically handles the case of a transpose from a vector type like `vector<8x[4]xf32>` to `vector<[4]x8xf32>`. Such transposes occur fairly frequently when scalably vectorizing `linalg.generic`s. There is no direct lowering for these (as types like `vector<[4]x8xf32>` cannot be represented in LLVM IR). However, if the only use of the transpose is a write, then it is possible to lower the `transfer_write(transpose)` as a VLA (vector-length agnostic) loop.
Example:
```mlir
%transpose = vector.transpose %vec, [1, 0]
: vector<4x[4]xf32> to vector<[4]x4xf32>
vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
: vector<[4]x4xf32>, memref<?x?xf32>
```
Becomes:
```mlir
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0 = arith.constant 0 : index
%0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
%1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
%2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
%3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
%vscale = vector.vscale
%c4_vscale = arith.muli %vscale, %c4 : index
scf.for %idx = %c0 to %c4_vscale step %c1 {
%4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
%5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
%6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
%7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
%slice_i = affine.apply #map(%idx)[%i]
%slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
: vector<4xf32>, memref<?x?xf32>
}
```
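The new lowering is off by default; it is enabled with the `lower-scalable` pass option (or `VectorTransferToSCFOptions::enableLowerScalable()` in C++). A minimal usage sketch, mirroring the tests added in this patch (the RUN pipeline and the function are taken from the updated `vector-to-scf.mlir`):
```mlir
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true lower-scalable=true}))"

// The transpose feeding the write is rewritten into a VLA loop over the
// scalable dimension (see the "Becomes" IR above).
func.func @scalable_transpose_store_unmasked(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
  %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
  vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
  return
}
```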
---
Patch is 23.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101353.diff
5 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+3-1)
- (modified) mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h (+8)
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+186-1)
- (modified) mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir (+14-1)
- (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+117-9)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b5bb2f42f2961..7bde9e490e4f4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1300,7 +1300,9 @@ def ConvertVectorToSCF : Pass<"convert-vector-to-scf"> {
Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
"Target vector rank to which transfer ops should be lowered">,
Option<"lowerTensors", "lower-tensors", "bool", /*default=*/"false",
- "Lower transfer ops that operate on tensors">
+ "Lower transfer ops that operate on tensors">,
+ Option<"lowerScalable", "lower-scalable", "bool", /*default=*/"false",
+ "Add scalable vector specific lowerings (that introduce loops)">
];
}
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index 1c834b6c69083..e0ef67c39a101 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -69,6 +69,14 @@ struct VectorTransferToSCFOptions {
unroll = u;
return *this;
}
+ /// Enable scalable vector specific lowerings (which introduce loops). These
+ /// work alongside fullUnroll (which unrolls until the first scalable
+ /// dimension).
+ bool lowerScalable = false;
+ VectorTransferToSCFOptions enableLowerScalable(bool enable = true) {
+ lowerScalable = enable;
+ return *this;
+ }
};
/// Collect a set of patterns to convert from the Vector dialect to SCF + func.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 19f02297bfbb7..c2fb8b6161ec3 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
@@ -987,6 +988,185 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
}
};
+/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
+/// and ConstantMaskOp.
+template <typename VscaleConstantBuilder>
+static FailureOr<SmallVector<OpFoldResult>>
+getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
+ if (!mask)
+ return SmallVector<OpFoldResult>{};
+ if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
+ return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
+ return OpFoldResult(dimSize);
+ });
+ }
+ if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
+ int dimIdx = 0;
+ VectorType maskType = constantMask.getVectorType();
+ auto indexType = IndexType::get(mask.getContext());
+ return llvm::map_to_vector(
+ constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
+ // A scalable dim in a constant_mask means vscale x dimSize.
+ if (maskType.getScalableDims()[dimIdx++])
+ return OpFoldResult(createVscaleMultiple(dimSize));
+ return OpFoldResult(IntegerAttr::get(indexType, dimSize));
+ });
+ }
+ return failure();
+}
+
+/// Scalable vector lowering of transfer_write(transpose). This lowering only
+/// supports rank 2 (scalable) vectors, but can be used in conjunction with
+/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
+/// unrolls until the first scalable dimension.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+/// : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
+/// : vector<[4]x4xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %c0 = arith.constant 0 : index
+/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// scf.for %idx = %c0 to %c4_vscale step %c1 {
+/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
+/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
+/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
+/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+/// %slice_i = affine.apply #map(%idx)[%i]
+/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
+/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
+/// : vector<4xf32>, memref<?x?xf32>
+/// }
+/// ```
+struct ScalableTransposeTransferWriteConversion
+ : VectorToSCFPattern<vector::TransferWriteOp> {
+ using VectorToSCFPattern::VectorToSCFPattern;
+
+ LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ if (isTensorOp(writeOp) && !options.lowerTensors) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "lowering tensor transfers is disabled");
+ }
+
+ auto vector = writeOp.getVector();
+ auto vectorType = vector.getType();
+ auto scalableFlags = vectorType.getScalableDims();
+ if (scalableFlags != ArrayRef<bool>{true, false}) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "expected vector of form vector<[*]x*xty>");
+ }
+
+ auto permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isIdentity()) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "non-identity permutations are unsupported (lower first)");
+ }
+
+ if (!writeOp.isDimInBounds(0)) {
+ return rewriter.notifyMatchFailure(
+ writeOp, "out-of-bounds dims are unsupported (use masking)");
+ }
+
+ auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp ||
+ transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
+ return rewriter.notifyMatchFailure(writeOp, "source not transpose");
+ }
+
+ auto loc = writeOp.getLoc();
+ auto createVscaleMultiple =
+ vector::makeVscaleConstantBuilder(rewriter, loc);
+
+ auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
+ if (failed(maskDims)) {
+ return rewriter.notifyMatchFailure(writeOp,
+ "failed to resolve mask dims");
+ }
+
+ int64_t fixedDimSize = vectorType.getDimSize(1);
+ auto fixedDimOffsets = llvm::seq(fixedDimSize);
+
+ // Extract all slices from the source of the transpose.
+ auto transposeSource = transposeOp.getVector();
+ SmallVector<Value> transposeSourceSlices =
+ llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+ return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
+ });
+
+ // Loop bounds and step.
+ auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto ub =
+ maskDims->empty()
+ ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
+ : vector::getAsValues(rewriter, loc, maskDims->front()).front();
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+ // Generate a new mask for the slice.
+ VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
+ Value sliceMask = nullptr;
+ if (!maskDims->empty()) {
+ sliceMask = rewriter.create<vector::CreateMaskOp>(
+ loc, sliceType.clone(rewriter.getI1Type()),
+ ArrayRef<OpFoldResult>(*maskDims).drop_front());
+ }
+
+ ValueRange initLoopArgs =
+ isTensorOp(writeOp) ? writeOp.getSource() : ValueRange{};
+ auto result = rewriter.create<scf::ForOp>(
+ loc, lb, ub, step, initLoopArgs,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
+ // Indices for the new transfer op.
+ SmallVector<Value, 8> xferIndices;
+ getXferIndices(b, writeOp, iv, xferIndices);
+
+ // Extract a transposed slice from the source vector.
+ SmallVector<Value> transposeElements =
+ llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
+ return b.create<vector::ExtractOp>(
+ loc, transposeSourceSlices[idx], iv);
+ });
+ auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
+ transposeElements);
+
+ // Create the transfer_write for the slice.
+ Value dest =
+ loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
+ auto newWriteOp = b.create<vector::TransferWriteOp>(
+ loc, sliceVec, dest, xferIndices,
+ ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
+ if (sliceMask)
+ newWriteOp.getMaskMutable().assign(sliceMask);
+
+ // Yield from the loop.
+ b.create<scf::YieldOp>(loc, loopIterArgs.empty()
+ ? ValueRange{}
+ : newWriteOp.getResult());
+ });
+
+ if (isTensorOp(writeOp))
+ rewriter.replaceOp(writeOp, result);
+ else
+ rewriter.eraseOp(writeOp);
+
+ return success();
+ }
+};
+
} // namespace lowering_n_d
namespace lowering_n_d_unrolled {
@@ -1503,7 +1683,10 @@ void mlir::populateVectorToSCFConversionPatterns(
lowering_n_d::TransferOpConversion<TransferWriteOp>>(
patterns.getContext(), options);
}
-
+ if (options.lowerScalable) {
+ patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
+ patterns.getContext(), options);
+ }
if (options.targetRank == 1) {
patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
@@ -1522,6 +1705,7 @@ struct ConvertVectorToSCFPass
this->fullUnroll = options.unroll;
this->targetRank = options.targetRank;
this->lowerTensors = options.lowerTensors;
+ this->lowerScalable = options.lowerScalable;
}
void runOnOperation() override {
@@ -1529,6 +1713,7 @@ struct ConvertVectorToSCFPass
options.unroll = fullUnroll;
options.targetRank = targetRank;
options.lowerTensors = lowerTensors;
+ options.lowerScalable = lowerScalable;
// Lower permutation maps first.
RewritePatternSet lowerTransferPatterns(&getContext());
diff --git a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
index dac8e018f845f..c542b79d5c80e 100644
--- a/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true lower-scalable=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: func @transfer_read_2d(
// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
@@ -36,3 +36,16 @@ func.func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
return %t : tensor<?x?xf32>
}
+// -----
+
+// CHECK-LABEL: func @scalable_transpose_store
+// CHECK-SAME: %[[TENSOR:[a-z0-9]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[ITER_ARG:.*]] = %[[TENSOR]]) -> (tensor<?x?xf32>)
+// CHECK: %[[WRITE_SLICE:.*]] = vector.transfer_write %{{.*}} %[[ITER_ARG]]
+// CHECK: scf.yield %[[WRITE_SLICE]]
+// CHECK: return %[[RESULT]]
+func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %dest: tensor<?x?xf32>, %i: index, %j: index) -> tensor<?x?xf32> {
+ %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ %result = vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>, tensor<?x?xf32>
+ return %result : tensor<?x?xf32>
+}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 3f4e70a6835af..f7ded02d83fe5 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf))" -split-input-file -allow-unregistered-dialect | FileCheck %s
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true lower-scalable=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
// RUN: mlir-opt %s "-convert-vector-to-scf=full-unroll target-rank=0" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=TARGET-RANK-ZERO
// CHECK-LABEL: func @vector_transfer_ops_0d(
@@ -661,10 +661,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
// CHECK: memref.store %[[MASK]], %[[ALLOCA_MASK]][] : memref<vector<3x[4]xi1>>
// CHECK: %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
// CHECK: %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
-// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
-// CHECK: %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VAL_11]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
-// CHECK: memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
+// CHECK: scf.for %[[VSCALE:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xi1>>
+// CHECK: %[[READ_SLICE:.*]] = vector.transfer_read %[[ARG]]{{\[}}%[[VSCALE]], %[[C0]]], %[[PADDING]], %[[MASK_SLICE]] {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+// CHECK: memref.store %[[READ_SLICE]], %[[UNPACK_VECTOR]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xf32>>
// CHECK: }
// CHECK: %[[RESULT:.*]] = memref.load %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
// CHECK: return %[[RESULT]] : vector<3x[4]xf32>
@@ -695,10 +695,10 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
// CHECK: memref.store %[[VEC]], %[[ALLOCA_VEC]][] : memref<vector<3x[4]xf32>>
// CHECK: %[[UNPACK_VECTOR:.*]] = vector.type_cast %[[ALLOCA_VEC]] : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
// CHECK: %[[UNPACK_MASK:.*]] = vector.type_cast %[[ALLOCA_MASK]] : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
-// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xf32>>
-// CHECK: %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VAL_11]]] : memref<3xvector<[4]xi1>>
-// CHECK: vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VAL_11]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+// CHECK: scf.for %[[VSCALE:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[MASK_SLICE:.*]] = memref.load %[[UNPACK_VECTOR]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xf32>>
+// CHECK: %[[VECTOR_SLICE:.*]] = memref.load %[[UNPACK_MASK]]{{\[}}%[[VSCALE]]] : memref<3xvector<[4]xi1>>
+// CHECK: vector.transfer_write %[[MASK_SLICE]], %[[MEMREF]]{{\[}}%[[VSCALE]], %[[C0]]], %[[VECTOR_SLICE]] {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
// CHECK: }
// CHECK: return
// CHECK: }
@@ -803,3 +803,111 @@ func.func @unroll_transfer_write_target_rank_zero(%vec : vector<2xi32>) {
// TARGET-RANK-ZERO: %[[EXTRACTED2:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
// TARGET-RANK-ZERO: %[[BROADCASTED2:.*]] = vector.broadcast %[[EXTRACTED2]] : i32 to vector<i32>
// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED2]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>
+
+// -----
+
+func.func @scalable_transpose_store_unmasked(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+ %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+ return
+}
+// FULL-UNROLL: #[[$SLICE_MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// FULL-UNROLL-LABEL: func.func @scalable_transpose_store_unmasked(
+// FULL-UNROLL-SAME: %[[VEC:.*]]: vector<4x[4]xf32>,
+// FULL-UNROLL-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// FULL-UNROLL-SAME: %[[I:.*]]: index,
+// FULL-UNROLL-SAME: %[[J:.*]]: index)
+// FULL-UNROLL-DAG: %[[C0:.*]] = arith.constant 0 : index
+// FULL-UNROLL-DAG: %[[C1:.*]] = arith.constant 1 : index
+// FULL-UNROLL-DAG: %[[C4:.*]] = arith.constant 4 : index
+// FULL-UNROLL: %[[SLICE_0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL: %[[SLICE_1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL: %[[SLICE_2:.*]] = vector.extract %[[VEC]][2] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL: %[[SLICE_3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<4x[4]xf32>
+// FULL-UNROLL: %[[VSCALE:.*]] = vector.vscale
+// FULL-UNROLL: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// FULL-UNROLL: scf.for %[[VAL_13:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+// FULL-UNROLL: %[[SLICE_I:.*]] = affine.apply #[[$SLICE_MAP]](%[[VAL_13]]){{\[}}%[[I]]]
+// FULL-UNROLL: %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL: %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL: %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL: %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32>
+// FULL-UNROLL: %[[TRANSPOSE_SLICE:.*]] = vector.from_elements %[[ELEM_0]], %[[ELEM_1]], %[[ELEM_2]], %[[ELEM_3]] : vector<4xf32>
+// FULL-UNROLL: vector.transfer_write %[[TRANSPOSE_SLICE]], %[[DEST]]{{\[}}%[[SLICE_I]], %[[J]]] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+// FULL-UNROLL: }
+// FULL-UNROLL: return
+
+// -----
+
+func.func @scalable_transpose_store_dynamic_mask(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index, %a: index, %b: index) {
+ %transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ %mask = vector.create_mask %a, %b : vector<[4]x4xi1>
+ vector.transfer_write %transpose, %dest[%i, %j], %mask {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+ return
+}
+// FULL-UNROLL-LABEL: func.func @scalable_transpose_store_dynamic_mask(
+// FULL-UNR...
[truncated]
``````````
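The FileCheck lines for the masked test are cut off by the truncation. Based on the pattern implementation (the loop upper bound is taken from the leading `vector.create_mask` operand, and a rank-1 slice mask is built from the trailing operand before the loop), the masked variant lowers roughly as sketched below. This is an illustrative fragment reconstructed from the code, not the actual CHECK output:
```mlir
// Illustrative shape of the masked lowering (fragment, not the truncated CHECK lines).
// %a / %b are the create_mask operands from @scalable_transpose_store_dynamic_mask.
%slice_mask = vector.create_mask %b : vector<4xi1>   // trailing mask dim
scf.for %idx = %c0 to %a step %c1 {                  // bound = leading mask dim
  // ...extract the four elements and build %slice as in the unmasked case...
  vector.transfer_write %slice, %dest[%slice_i, %j], %slice_mask {in_bounds = [true]}
    : vector<4xf32>, memref<?x?xf32>
}
```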
</details>
https://github.com/llvm/llvm-project/pull/101353