[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 1 15:16:19 PST 2024
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/78660
>From 2fee1229aa88f412c47290a00e5d48753491642e Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 5 Jan 2024 13:50:50 -0500
Subject: [PATCH 1/7] [mlir] Add vectorization support for tensor.pack
---
.../TransformOps/LinalgTransformOps.cpp | 2 +-
.../Linalg/Transforms/Vectorization.cpp | 151 ++++++++++++++++++
2 files changed, 152 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6431bbd25396a..585fd14b40d76 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3152,7 +3152,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
- if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0707625819d1a..f42e85c68f84b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,10 +19,14 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -30,7 +34,9 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>
@@ -1393,6 +1399,121 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as the following:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and the
+/// inner_tile dimensions immediately follow their corresponding outer
+/// dimension.
+/// i.e. for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1)
+/// outer_dims_perm = [0, 2, 1]
+/// inner_dims_pos = [2, 1]
+/// inner_tiles = [16, 2]
+/// into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`
+/// The "tiled" shape is `32x(4x2)x(1x16)`
+static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+ auto innerTiles = packOp.getInnerTiles();
+ int64_t srcRank = packOp.getSourceRank();
+ auto innerDimsPos = packOp.getInnerDimsPos();
+ if (innerDimsPos.empty())
+ innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+ if (outerDimsPerm.empty())
+ outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
+ auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
+ int64_t srcIdx;
+ if (idx >= srcRank)
+ srcIdx = innerDimsPos[idx - srcRank];
+ else
+ srcIdx = outerDimsPerm[idx];
+ int64_t tiledIdx = srcIdx;
+ for (int64_t pos : innerDimsPos)
+ if (pos < srcIdx)
+ tiledIdx++;
+ if (idx >= srcRank)
+ tiledIdx++;
+ return tiledIdx;
+ };
+ SmallVector<int64_t> perm;
+ for (int i = 0; i < packOp.getDestRank(); i++)
+ perm.push_back(packedIdxToTiledIdx(i));
+ return perm;
+}
+
+/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
+/// above in `getTiledShapeToPackedShapePerm`.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+ auto perm = getTiledShapeToPackedShapePerm(packOp);
+ auto destShape = packOp.getDestType().getShape();
+ return applyPermutation(destShape, invertPermutationVector(perm));
+}
+
+///
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ auto padValue = packOp.getPaddingValue();
+ Location loc = packOp.getLoc();
+ int64_t inputRank = inputVectorSizes.size();
+ int64_t outputRank = packOp.getDestRank();
+ auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+ auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ ReifiedRankedShapedTypeDims reifiedReturnShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedReturnShapes);
+ (void)status; // prevent unused variable warning on non-assert builds
+ assert(succeeded(status) && "failed to reify result shapes");
+ auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
+ padValue.getType());
+ SmallVector<OpFoldResult> mixedSourceDims =
+ tensor::getMixedSizes(rewriter, loc, packOp.getSource());
+ Value mask =
+ rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/packOp.getSource(),
+ /*indices=*/SmallVector<Value>(inputRank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/SmallVector<bool>(inputRank, true));
+ auto maskedOp = cast<vector::MaskOp>(
+ mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+ // ShapeCast
+ auto tiledPackShape = getTiledPackShape(packOp);
+ auto tiledPackType = VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+ auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult(0));
+ auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+ auto transposeOp = rewriter.create<vector::TransposeOp>(loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+ Operation *write = rewriter.create<vector::TransferWriteOp>(
+ loc,
+ /*vector=*/transposeOp->getResult(0),
+ /*source=*/emptyOp,
+ /*indices=*/SmallVector<Value>(outputRank, zero),
+ /*inBounds=*/SmallVector<bool>(outputRank, true));
+ // bool needMaskForWrite = llvm::any_of(
+ // llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
+ // [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+ // if (needMaskForWrite) {
+ // Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
+ // loc, maskType, reifiedReturnShapes[0]);
+ // write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
+ // }
+ newResults.push_back(write->getResult(0));
+ return success();
+}
+
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1585,6 +1706,30 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}
+static LogicalResult
+vectorizePackOpPrecondition(tensor::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto padValue = packOp.getPaddingValue();
+ if (!padValue) {
+ LDBG("pad value is not constant: " << packOp << "\n");
+ return failure();
+ }
+
+ ArrayRef<int64_t> resultTensorShape = packOp.getSourceType().getShape();
+ if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
+ return failure();
+
+ if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
+ std::optional<int64_t> res = getConstantIntValue(v);
+ return !res.has_value();
+ })) {
+ LDBG("inner_tiles must be constant: " << packOp << "\n");
+ return failure();
+ }
+
+ return success();
+}
+
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1789,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
+ .Case<tensor::PackOp>([&](auto packOp) {
+ return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+ })
.Default([](auto) { return failure(); });
}
@@ -1732,6 +1880,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
results);
})
+ .Case<tensor::PackOp>([&](auto packOp) {
+ return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results);
+ })
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
>From 0b12182bc2881b8ec30ed666cc0a4957f0ea709f Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 20:12:13 -0500
Subject: [PATCH 2/7] Support pack with no padding value
---
.../Linalg/Transforms/Vectorization.cpp | 22 ++++++++-----------
1 file changed, 9 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f42e85c68f84b..d0e3b7f4e8028 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1458,16 +1458,20 @@ static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
- auto padValue = packOp.getPaddingValue();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
Location loc = packOp.getLoc();
+ auto padValue = packOp.getPaddingValue();
+ if (!padValue) {
+ padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+ }
int64_t inputRank = inputVectorSizes.size();
int64_t outputRank = packOp.getDestRank();
auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(packOp);
-
ReifiedRankedShapedTypeDims reifiedReturnShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
@@ -1502,14 +1506,6 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
/*source=*/emptyOp,
/*indices=*/SmallVector<Value>(outputRank, zero),
/*inBounds=*/SmallVector<bool>(outputRank, true));
- // bool needMaskForWrite = llvm::any_of(
- // llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
- // [](auto it) { return std::get<0>(it) != std::get<1>(it); });
- // if (needMaskForWrite) {
- // Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
- // loc, maskType, reifiedReturnShapes[0]);
- // write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
- // }
newResults.push_back(write->getResult(0));
return success();
}
@@ -1710,7 +1706,7 @@ static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
- if (!padValue) {
+ if (padValue && getConstantIntValue(padValue) != std::nullopt) {
LDBG("pad value is not constant: " << packOp << "\n");
return failure();
}
>From b9d35372349ccae8da8d6d7da68f8607e46fc1a9 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 21:11:49 -0500
Subject: [PATCH 3/7] Add tests for tensor.pack vectorization
---
mlir/test/Dialect/Linalg/vectorization.mlir | 61 +++++++++++++++++++++
1 file changed, 61 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d5fb0cbb9c723..af1c1337224fa 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -501,6 +501,67 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1x16x2xf32>) -> tensor<4x1x16x2xf32> {
+ %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<4x1x16x2xf32>
+ return %pack : tensor<4x1x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
+// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
+// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<4x1x16x2xf32>
+
+// -----
+
+func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+ %pack = tensor.pack %arg0 inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+ return %pack : tensor<32x4x1x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [32, 8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+// -----
+
func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?xf32>)
>From 7d4f716e1a62ebce0435ee730f736e0d2b5476ad Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 21:26:08 -0500
Subject: [PATCH 4/7] Apply clang-format
---
.../Linalg/Transforms/Vectorization.cpp | 35 +++++++++++--------
1 file changed, 20 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d0e3b7f4e8028..37829fbeb79f7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1408,15 +1408,16 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
/// dimension.
/// i.e. for the following tensor.pack:
/// ```mlir
-/// %pack = tensor.pack %0 padding_value(%1)
-/// outer_dims_perm = [0, 2, 1]
-/// inner_dims_pos = [2, 1]
-/// inner_tiles = [16, 2]
+/// %pack = tensor.pack %0 padding_value(%1)
+/// outer_dims_perm = [0, 2, 1]
+/// inner_dims_pos = [2, 1]
+/// inner_tiles = [16, 2]
/// into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
/// ```
/// The "packed" shape is `32x1x4x16x2`
/// The "tiled" shape is `32x(4x2)x(1x16)`
-static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+static SmallVector<int64_t>
+getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
auto innerTiles = packOp.getInnerTiles();
int64_t srcRank = packOp.getSourceRank();
auto innerDimsPos = packOp.getInnerDimsPos();
@@ -1425,7 +1426,7 @@ static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (outerDimsPerm.empty())
outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
- auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
+ auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
int64_t srcIdx;
if (idx >= srcRank)
srcIdx = innerDimsPos[idx - srcRank];
@@ -1440,7 +1441,7 @@ static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp
return tiledIdx;
};
SmallVector<int64_t> perm;
- for (int i = 0; i < packOp.getDestRank(); i++)
+ for (int i = 0; i < packOp.getDestRank(); i++)
perm.push_back(packedIdxToTiledIdx(i));
return perm;
}
@@ -1453,11 +1454,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
return applyPermutation(destShape, invertPermutationVector(perm));
}
-///
+///
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
- ArrayRef<int64_t> inputVectorSizes,
- SmallVectorImpl<Value> &newResults) {
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
@@ -1496,10 +1497,13 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
mlir::vector::maskOperation(rewriter, transferReadOp, mask));
// ShapeCast
auto tiledPackShape = getTiledPackShape(packOp);
- auto tiledPackType = VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
- auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult(0));
+ auto tiledPackType =
+ VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+ auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ loc, tiledPackType, maskedOp->getResult(0));
auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
- auto transposeOp = rewriter.create<vector::TransposeOp>(loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+ auto transposeOp = rewriter.create<vector::TransposeOp>(
+ loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
Operation *write = rewriter.create<vector::TransferWriteOp>(
loc,
/*vector=*/transposeOp->getResult(0),
@@ -1704,7 +1708,7 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
- ArrayRef<int64_t> inputVectorSizes) {
+ ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
if (padValue && getConstantIntValue(padValue) != std::nullopt) {
LDBG("pad value is not constant: " << packOp << "\n");
@@ -1877,7 +1881,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
results);
})
.Case<tensor::PackOp>([&](auto packOp) {
- return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results);
+ return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+ results);
})
.Default([](auto) { return failure(); });
>From 70d37052e34374bc16f84eb05d7e7e2d21110c6e Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 19 Jan 2024 16:25:58 -0500
Subject: [PATCH 5/7] Take tensor.pack vector sizes in result (outer-dim) shape; factor out masked read/write helpers
---
.../Linalg/Transforms/Vectorization.cpp | 175 +++++++++++-------
mlir/test/Dialect/Linalg/vectorization.mlir | 26 +--
2 files changed, 119 insertions(+), 82 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 37829fbeb79f7..78c8a62933324 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -1454,7 +1455,73 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
return applyPermutation(destShape, invertPermutationVector(perm));
}
-///
+/// Create a masked TransferReadOp from `source` with shape `readShape`.
+static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
+ Value source,
+ ArrayRef<int64_t> readShape,
+ Value padValue) {
+ auto maskType = VectorType::get(readShape, builder.getI1Type());
+ auto vectorType = VectorType::get(readShape, padValue.getType());
+ SmallVector<OpFoldResult> mixedSourceDims =
+ tensor::getMixedSizes(builder, loc, source);
+ Value mask =
+ builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ int64_t readRank = readShape.size();
+ auto transferReadOp = builder.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/source,
+ /*indices=*/SmallVector<Value>(readRank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/SmallVector<bool>(readRank, true));
+ return cast<vector::MaskOp>(
+ mlir::vector::maskOperation(builder, transferReadOp, mask));
+}
+
+/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
+/// create an empty destination tensor and create a TransferWriteOp from the
+/// input to the empty tensor. If the destination shape is not the same as the
+/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
+/// mask for the write.
+static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+ Value input,
+ SmallVector<OpFoldResult> destSizes,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto inputType = cast<VectorType>(input.getType());
+ Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
+ inputType.getElementType());
+ int64_t rank = cast<ShapedType>(dest.getType()).getRank();
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Operation *write = builder.create<vector::TransferWriteOp>(
+ loc,
+ /*vector=*/input,
+ /*source=*/dest,
+ /*indices=*/SmallVector<Value>(rank, zero),
+ /*inBounds=*/SmallVector<bool>(rank, true));
+ auto destShape = cast<ShapedType>(dest.getType()).getShape();
+ bool needMaskForWrite =
+ llvm::any_of(llvm::zip(inputVectorSizes, destShape),
+ [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+ if (needMaskForWrite) {
+ SmallVector<int64_t> writeMaskShape;
+ writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+ writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+ destShape.end());
+ auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
+ Value maskForWrite =
+ builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+ write = mlir::vector::maskOperation(builder, write, maskForWrite);
+ }
+ return write;
+}
+
+/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
+/// padding value into
+/// transfer_write_in_bounds(
+/// transpose(
+/// shape_cast(
+/// transfer_read_masked(pack_source, pad_value))))
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1468,48 +1535,41 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
}
- int64_t inputRank = inputVectorSizes.size();
- int64_t outputRank = packOp.getDestRank();
- auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
- auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
-
ReifiedRankedShapedTypeDims reifiedReturnShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
- auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
- padValue.getType());
- SmallVector<OpFoldResult> mixedSourceDims =
- tensor::getMixedSizes(rewriter, loc, packOp.getSource());
- Value mask =
- rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc,
- /*vectorType=*/vectorType,
- /*source=*/packOp.getSource(),
- /*indices=*/SmallVector<Value>(inputRank, zero),
- /*padding=*/padValue,
- /*inBounds=*/SmallVector<bool>(inputRank, true));
- auto maskedOp = cast<vector::MaskOp>(
- mlir::vector::maskOperation(rewriter, transferReadOp, mask));
- // ShapeCast
- auto tiledPackShape = getTiledPackShape(packOp);
- auto tiledPackType =
- VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+
+ // Create masked TransferReadOp
+ SmallVector<int64_t> inputShape(inputVectorSizes);
+ auto innerTiles = packOp.getStaticInnerTiles();
+ auto innerDimsPos = packOp.getInnerDimsPos();
+ auto outerDimsPerm = packOp.getOuterDimsPerm();
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(inputShape,
+ invertPermutationVector(outerDimsPerm));
+ for (auto [idx, size] : enumerate(innerTiles))
+ inputShape[innerDimsPos[idx]] *= size;
+ auto maskedOp = createMaskedTransferRead(rewriter, loc, packOp.getSource(),
+ inputShape, padValue);
+
+ // Create ShapeCastOp
+ auto tiledPackType = VectorType::get(getTiledPackShape(packOp),
+ packOp.getDestType().getElementType());
auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, tiledPackType, maskedOp->getResult(0));
+
+ // Create TransposeOp
auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
auto transposeOp = rewriter.create<vector::TransposeOp>(
- loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
- Operation *write = rewriter.create<vector::TransferWriteOp>(
- loc,
- /*vector=*/transposeOp->getResult(0),
- /*source=*/emptyOp,
- /*indices=*/SmallVector<Value>(outputRank, zero),
- /*inBounds=*/SmallVector<bool>(outputRank, true));
+ loc, shapeCastOp.getResult(), tiledShapeToPackedShapePerm);
+
+ // Create TransferWriteOp
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
+ reifiedReturnShapes[0], inputVectorSizes);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1523,9 +1583,6 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
SmallVectorImpl<Value> &newResults) {
auto padValue = padOp.getConstantPaddingValue();
Location loc = padOp.getLoc();
- int64_t rank = inputVectorSizes.size();
- auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
- auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
// transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
OpBuilder::InsertionGuard g(rewriter);
@@ -1537,36 +1594,11 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
- auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
- padValue.getType());
- SmallVector<OpFoldResult> mixedSourceDims =
- tensor::getMixedSizes(rewriter, loc, padOp.getSource());
- Value mask =
- rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
- auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc,
- /*vectorType=*/vectorType,
- /*source=*/padOp.getSource(),
- /*indices=*/SmallVector<Value>(rank, zero),
- /*padding=*/padValue,
- /*inBounds=*/SmallVector<bool>(rank, true));
- auto maskedOp = cast<vector::MaskOp>(
- mlir::vector::maskOperation(rewriter, transferReadOp, mask));
- Operation *write = rewriter.create<vector::TransferWriteOp>(
- loc,
- /*vector=*/maskedOp->getResult(0),
- /*source=*/emptyOp,
- /*indices=*/SmallVector<Value>(rank, zero),
- /*inBounds=*/SmallVector<bool>(rank, true));
- bool needMaskForWrite = llvm::any_of(
- llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
- [](auto it) { return std::get<0>(it) != std::get<1>(it); });
- if (needMaskForWrite) {
- Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
- loc, maskType, reifiedReturnShapes[0]);
- write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
- }
+ auto maskedOp = createMaskedTransferRead(rewriter, loc, padOp.getSource(),
+ inputVectorSizes, padValue);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, maskedOp->getResult(0),
+ reifiedReturnShapes[0], inputVectorSizes);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1710,18 +1742,19 @@ static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
- if (padValue && getConstantIntValue(padValue) != std::nullopt) {
+ if (padValue && !matchPattern(padValue, m_Constant())) {
LDBG("pad value is not constant: " << packOp << "\n");
return failure();
}
- ArrayRef<int64_t> resultTensorShape = packOp.getSourceType().getShape();
- if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
+ ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+ if (failed(isValidMaskedInputVector(
+ resultTensorShape.take_front(packOp.getSourceRank()),
+ inputVectorSizes)))
return failure();
if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
- std::optional<int64_t> res = getConstantIntValue(v);
- return !res.has_value();
+ return !getConstantIntValue(v).has_value();
})) {
LDBG("inner_tiles must be constant: " << packOp << "\n");
return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index af1c1337224fa..ed9a8eb9183bd 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -426,7 +426,6 @@ func.func @test_masked_vectorize_pad(
{
// CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
// CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
@@ -435,7 +434,9 @@ func.func @test_masked_vectorize_pad(
// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
// CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
- // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+ // CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+ // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
// CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
%cst = arith.constant 42.43 : f32
%c0 = arith.constant 0 : index
@@ -467,7 +468,6 @@ func.func @test_masked_vectorize_dynamic_pad(
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[res_d0:.+]] = affine.apply #[[MAP]]()
// CHECK-DAG: %[[res_d1:.+]] = affine.apply #[[MAP]]()
- // CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
// CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
@@ -476,9 +476,11 @@ func.func @test_masked_vectorize_dynamic_pad(
// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
// CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+ // CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
+ // CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
// CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
// CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
- // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_2]], %[[c0_2]]]
+ // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
// CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
// CHECK: return %[[masked_write]] : tensor<?x?xf32>
%cst = arith.constant 42.43 : f32
@@ -508,7 +510,7 @@ func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
transform.yield
}
}
@@ -517,15 +519,16 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
// CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
-// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
-// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
-// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[cst]]
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
// CHECK: return %[[write]] : tensor<4x1x16x2xf32>
@@ -539,7 +542,7 @@ func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [32, 8, 16] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
transform.yield
}
}
@@ -547,7 +550,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
-// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
@@ -556,7 +558,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]]
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
>From c776290f312705feb9d7c17356bfc01ba6086a01 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 19 Jan 2024 17:25:50 -0500
Subject: [PATCH 6/7] add dynamic test
---
.../Linalg/Transforms/Vectorization.cpp | 10 +++--
mlir/test/Dialect/Linalg/vectorization.mlir | 40 +++++++++++++++++++
2 files changed, 46 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 78c8a62933324..2961e8cbee7a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1442,16 +1442,16 @@ getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
return tiledIdx;
};
SmallVector<int64_t> perm;
- for (int i = 0; i < packOp.getDestRank(); i++)
+ for (size_t i = 0; i < packOp.getDestRank(); i++)
perm.push_back(packedIdxToTiledIdx(i));
return perm;
}
/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
/// above in `getTiledShapeToPackedShapePerm`.
-static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
+ ArrayRef<int64_t> destShape) {
auto perm = getTiledShapeToPackedShapePerm(packOp);
- auto destShape = packOp.getDestType().getShape();
return applyPermutation(destShape, invertPermutationVector(perm));
}
@@ -1556,7 +1556,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
inputShape, padValue);
// Create ShapeCastOp
- auto tiledPackType = VectorType::get(getTiledPackShape(packOp),
+ SmallVector<int64_t> destShape(inputVectorSizes);
+ destShape.append(innerTiles.begin(), innerTiles.end());
+ auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
packOp.getDestType().getElementType());
auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, tiledPackType, maskedOp->getResult(0));
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index ed9a8eb9183bd..d9546f6da38a3 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -566,6 +566,46 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @test_vectorize_dynamic_result_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+ %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+ return %pack : tensor<?x?x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?x16x2xf32>
+// CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?x16x2xf32>
+// CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1_0:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[d0_0:.*]] = tensor.dim {{.*}} %[[c0_0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[d1_0:.*]] = tensor.dim {{.*}} %[[c1_0]] : tensor<?x?xf32>
+// CHECK: %[[mask:.*]] = vector.create_mask %[[d0_0]], %[[d1_0]] : vector<8x16xi1>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[d0]], %[[d1]]) : tensor<?x?x16x2xf32>
+// CHECK: %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+// CHECK: %[[masked_write:.*]] = vector.mask %[[mask_0]] {
+// CHECK-SAME: vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+// CHECK: return %[[masked_write]] : tensor<?x?x16x2xf32>
+
+// -----
+
func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?xf32>)
>From 9acc090c0069d8a09b640a1ae107cd521339dddd Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 1 Feb 2024 18:14:20 -0500
Subject: [PATCH 7/7] address comments
---
.../include/mlir/Dialect/Tensor/Utils/Utils.h | 8 +
.../Dialect/Linalg/Transforms/Transforms.cpp | 36 +---
.../Linalg/Transforms/Vectorization.cpp | 160 ++++++++----------
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 29 ++++
mlir/test/Dialect/Linalg/vectorization.mlir | 95 +++++------
5 files changed, 164 insertions(+), 164 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 04b4de4a33a52..fe9b16cb44b3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -32,6 +32,14 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);
+/// Given a tensor::PackOp, compute the permutation vector to shuffle the
+/// packed shape into the shape before any outer or inner permutations have
+/// been applied.
+/// i.e. for a pack from an ABCD layout to an ABCDba:
+/// The packed shape would be ABCDba.
+/// The pre-permutation shape would be AaBbCD.
+SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
+
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 02bc3e672bf7a..596b7c50c1e4e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -233,31 +233,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
rewriter.setInsertionPoint(packOp);
// 2. Compute the permutation vector to shuffle packed shape into the shape
- // before any outer or inner permutations have been applied. The permutation
- // can be obtained from two permutations:
- // a) Compute the permutation vector to move the last `numPackedDims` into
- // the `innerPosDims` of a shape of rank `packedRank`.
- // b) Compute the permutation vector to move outer dims if the pack op
- // has outer_dims_perm.
- // Apply (b) permutation on (a) permutation to get the final permutation.
- int64_t numPackedDims = packOp.getInnerDimsPos().size();
- int64_t packedRank = packedTensorType.getRank();
- auto lastDims = llvm::to_vector(
- llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+ // before any outer or inner permutations have been applied.
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
- SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
- packedRank, lastDims, packingMetadata.insertPositions);
-
- SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
- ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
- if (!outerPerm.empty())
- applyPermutationToVector(outerPos, outerPerm);
- SmallVector<int64_t> outerPositionPerm = computePermutationVector(
- packedRank, packingMetadata.outerPositions, outerPos);
-
- SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
- applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
+ SmallVector<int64_t> packedToStripMinedShapePerm =
+ tensor::getPackInverseDestPermutation(packOp);
// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
@@ -304,10 +284,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
- llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
- DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
- DBGS() << "innerPositionsPerm: ");
- DBGSNL();
llvm::interleaveComma(packedToStripMinedShapePerm,
DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
@@ -332,9 +308,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
auto emptyOp =
rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
// Offsets.
- SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
+ rewriter.getIndexAttr(0));
// Strides.
- SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> ones(packOp.getDestRank(),
+ rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2961e8cbee7a1..7e7de846d9954 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -1400,74 +1401,26 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
-/// Given a tensor::PackOp, return the permutation from the "tiled"
-/// shape to the "packed" shape, defined as the following:
-/// The "packed" shape is the same as the `dest` shape of the pack op.
-/// The "tiled" shape is a permutation of the `dest` shape such that
-/// each outer dimension is in the original `source` order, and the
-/// inner_tile dimensions immediately follow their corresponding outer
-/// dimension.
-/// i.e. for the following tensor.pack:
-/// ```mlir
-/// %pack = tensor.pack %0 padding_value(%1)
-/// outer_dims_perm = [0, 2, 1]
-/// inner_dims_pos = [2, 1]
-/// inner_tiles = [16, 2]
-/// into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
-/// ```
-/// The "packed" shape is `32x1x4x16x2`
-/// The "tiled" shape is `32x(4x2)x(1x16)`
-static SmallVector<int64_t>
-getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
- auto innerTiles = packOp.getInnerTiles();
- int64_t srcRank = packOp.getSourceRank();
- auto innerDimsPos = packOp.getInnerDimsPos();
- if (innerDimsPos.empty())
- innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
- auto outerDimsPerm = packOp.getOuterDimsPerm();
- if (outerDimsPerm.empty())
- outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
- auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
- int64_t srcIdx;
- if (idx >= srcRank)
- srcIdx = innerDimsPos[idx - srcRank];
- else
- srcIdx = outerDimsPerm[idx];
- int64_t tiledIdx = srcIdx;
- for (int64_t pos : innerDimsPos)
- if (pos < srcIdx)
- tiledIdx++;
- if (idx >= srcRank)
- tiledIdx++;
- return tiledIdx;
- };
- SmallVector<int64_t> perm;
- for (size_t i = 0; i < packOp.getDestRank(); i++)
- perm.push_back(packedIdxToTiledIdx(i));
- return perm;
-}
-
-/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
-/// above in `getTiledShapeToPackedShapePerm`.
+/// Given a tensor::PackOp, return the `dest` shape before any packing
+/// permutations.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
ArrayRef<int64_t> destShape) {
- auto perm = getTiledShapeToPackedShapePerm(packOp);
- return applyPermutation(destShape, invertPermutationVector(perm));
+ return applyPermutation(destShape,
+ tensor::getPackInverseDestPermutation(packOp));
}
-/// Create a masked TransferReadOp from `source` with shape `readShape`.
-static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
- Value source,
- ArrayRef<int64_t> readShape,
- Value padValue) {
+/// Create a TransferReadOp from `source` with static shape `readShape`. If
+/// `readShape` differs from the shape of `source` (or `source` has dynamic
+/// dims), the read is masked using the runtime sizes of `source`.
+static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
+ Value source, ArrayRef<int64_t> readShape,
+ Value padValue) {
+ assert(llvm::none_of(readShape,
+ [](int64_t s) { return s == ShapedType::kDynamic; }));
auto maskType = VectorType::get(readShape, builder.getI1Type());
auto vectorType = VectorType::get(readShape, padValue.getType());
- SmallVector<OpFoldResult> mixedSourceDims =
- tensor::getMixedSizes(builder, loc, source);
- Value mask =
- builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
- auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
int64_t readRank = readShape.size();
+ auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto transferReadOp = builder.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
@@ -1475,8 +1428,20 @@ static vector::MaskOp createMaskedTransferRead(OpBuilder &builder, Location loc,
/*indices=*/SmallVector<Value>(readRank, zero),
/*padding=*/padValue,
/*inBounds=*/SmallVector<bool>(readRank, true));
- return cast<vector::MaskOp>(
- mlir::vector::maskOperation(builder, transferReadOp, mask));
+ auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
+ if (sourceShape.size() == readShape.size() &&
+ llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
+ return std::get<0>(it) != ShapedType::kDynamic &&
+ std::get<0>(it) == std::get<1>(it);
+ })) {
+ return transferReadOp;
+ }
+ SmallVector<OpFoldResult> mixedSourceDims =
+ tensor::getMixedSizes(builder, loc, source);
+ Value mask =
+ builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+ return mlir::vector::maskOperation(builder, transferReadOp, mask)
+ ->getResult(0);
}
/// Given an input, the mixed destSizes, and the vector sizes for vectorization,
@@ -1500,9 +1465,14 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
/*indices=*/SmallVector<Value>(rank, zero),
/*inBounds=*/SmallVector<bool>(rank, true));
auto destShape = cast<ShapedType>(dest.getType()).getShape();
- bool needMaskForWrite =
- llvm::any_of(llvm::zip(inputVectorSizes, destShape),
- [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+ assert(llvm::none_of(
+ destShape.drop_front(inputVectorSizes.size()),
+ [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+ "Only dims aligned with inputVectorSizes may be dynamic");
+ bool needMaskForWrite = llvm::any_of(
+ llvm::zip_equal(inputVectorSizes,
+ destShape.take_front(inputVectorSizes.size())),
+ [](auto it) { return std::get<0>(it) != std::get<1>(it); });
if (needMaskForWrite) {
SmallVector<int64_t> writeMaskShape;
writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
@@ -1517,11 +1487,28 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
}
/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
-/// padding value into
-/// transfer_write_in_bounds(
-/// transpose(
-/// shape_cast(
-/// transfer_read_masked(pack_source, pad_value))))
+/// padding value into:
+/// (masked) transfer_read->shape_cast->transpose->(masked) transfer_write
+/// As in the following example:
+/// ```mlir
+/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
+/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// This pack would be vectorized to:
+/// ```mlir
+/// // The read shape equals the static source shape, so no mask is
+/// // needed; otherwise the read is wrapped in a `vector.mask`.
+/// %load = vector.transfer_read %src[%c0, %c0, %c0], %cst
+/// {in_bounds = [true, true, true]}
+/// : tensor<32x8x16xf32>, vector<32x8x16xf32>
+/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
+/// to vector<32x4x2x1x16xf32>
+/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
+/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %transpose,
+/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
+/// {in_bounds = [true, true, true, true, true]}
+/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1539,10 +1526,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
.reifyResultShapes(rewriter, reifiedReturnShapes);
- (void)status; // prevent unused variable warning on non-assert builds
+ (void)status; // prevent unused variable warning on non-assert builds.
assert(succeeded(status) && "failed to reify result shapes");
- // Create masked TransferReadOp
+ // Create masked TransferReadOp.
SmallVector<int64_t> inputShape(inputVectorSizes);
auto innerTiles = packOp.getStaticInnerTiles();
auto innerDimsPos = packOp.getInnerDimsPos();
@@ -1552,23 +1539,24 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
invertPermutationVector(outerDimsPerm));
for (auto [idx, size] : enumerate(innerTiles))
inputShape[innerDimsPos[idx]] *= size;
- auto maskedOp = createMaskedTransferRead(rewriter, loc, packOp.getSource(),
+ auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
inputShape, padValue);
- // Create ShapeCastOp
+ // Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
destShape.append(innerTiles.begin(), innerTiles.end());
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
packOp.getDestType().getElementType());
- auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
- loc, tiledPackType, maskedOp->getResult(0));
+ auto shapeCastOp =
+ rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
- // Create TransposeOp
- auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+ // Create TransposeOp.
+ auto destPermutation =
+ invertPermutationVector(tensor::getPackInverseDestPermutation(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
- loc, shapeCastOp.getResult(), tiledShapeToPackedShapePerm);
+ loc, shapeCastOp.getResult(), destPermutation);
- // Create TransferWriteOp
+ // Create TransferWriteOp.
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
reifiedReturnShapes[0], inputVectorSizes);
@@ -1596,11 +1584,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
- auto maskedOp = createMaskedTransferRead(rewriter, loc, padOp.getSource(),
+ auto maskedRead = createReadOrMaskedRead(rewriter, loc, padOp.getSource(),
inputVectorSizes, padValue);
- Operation *write =
- createWriteOrMaskedWrite(rewriter, loc, maskedOp->getResult(0),
- reifiedReturnShapes[0], inputVectorSizes);
+ Operation *write = createWriteOrMaskedWrite(
+ rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1740,11 +1727,12 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}
+/// TODO: Use a matcher to check for a constant padding value.
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
- if (padValue && !getConstantIntValue(padValue).has_value()) {
+ if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
LDBG("pad value is not constant: " << packOp << "\n");
return failure();
}
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 24cbceb3d1179..f20008a1ed2b2 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -73,6 +73,35 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
return transposedTensorType;
}
+SmallVector<int64_t>
+mlir::tensor::getPackInverseDestPermutation(PackOp packOp) {
+ // The permutation can be obtained from two permutations:
+ // a) Compute the permutation vector to move the last `numPackedDims` into
+ // the `innerPosDims` of a shape of rank `packedRank`.
+ // b) Compute the permutation vector to move outer dims if the pack op
+ // has outer_dims_perm.
+ // Apply (b) permutation on (a) permutation to get the final permutation.
+ int64_t numPackedDims = packOp.getInnerDimsPos().size();
+ int64_t packedRank = packOp.getDestType().getRank();
+ auto lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+ PackingMetadata packingMetadata = computePackingMetadata(
+ packOp.getDestType().getRank(), packOp.getInnerDimsPos());
+ SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
+ packedRank, lastDims, packingMetadata.insertPositions);
+
+ SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
+ ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
+ if (!outerPerm.empty())
+ applyPermutationToVector(outerPos, outerPerm);
+ SmallVector<int64_t> outerPositionPerm = computePermutationVector(
+ packedRank, packingMetadata.outerPositions, outerPos);
+
+ SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
+ applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
+ return packInverseDestPermutation;
+}
+
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
llvm::SmallBitVector droppedDims = op.getDroppedDims();
int64_t srcDim = 0;
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d9546f6da38a3..5d1bef478ee98 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -426,17 +426,17 @@ func.func @test_masked_vectorize_pad(
{
// CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
// CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
- // CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
- // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
+ // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_0]], %[[c0_0]]], %[[c42]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
// CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
- // CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
- // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
+ // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_1]], %[[c0_1]]]
// CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
%cst = arith.constant 42.43 : f32
%c0 = arith.constant 0 : index
@@ -468,10 +468,10 @@ func.func @test_masked_vectorize_dynamic_pad(
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[res_d0:.+]] = affine.apply #[[MAP]]()
// CHECK-DAG: %[[res_d1:.+]] = affine.apply #[[MAP]]()
+ // CHECK: %[[c0_2:.*]] = arith.constant 0 : index
// CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
- // CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[c42]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
@@ -503,58 +503,46 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1x16x2xf32>) -> tensor<4x1x16x2xf32> {
- %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<4x1x16x2xf32>
- return %pack : tensor<4x1x16x2xf32>
+func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
+ %pack = tensor.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
+ return %pack : tensor<4x1x32x16x2xf32>
}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[read:.*]] = vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x32x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32>
+// CHECK: return %[[write]] : tensor<4x1x32x16x2xf32>
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1, 32] : !transform.any_op
transform.yield
}
}
-// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
-// CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
-// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
-// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
-// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
-// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
-// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
-// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
-// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
-// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
-// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
-// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
-// CHECK: return %[[write]] : tensor<4x1x16x2xf32>
// -----
-func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
- %pack = tensor.pack %arg0 inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+ %pad = arith.constant 0.000000e+00 : f32
+ %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
return %pack : tensor<32x4x1x16x2xf32>
}
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
- transform.yield
- }
-}
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
-// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
-// CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
+// CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c7]], %[[c15]] : vector<32x8x16xi1>
// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
-// CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+// CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
@@ -564,30 +552,31 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
-// -----
-
-func.func @test_vectorize_dynamic_result_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
- %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
- return %pack : tensor<?x?x16x2xf32>
-}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
transform.yield
}
}
+
+// -----
+
+func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+ %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+ return %pack : tensor<?x?x16x2xf32>
+}
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?x16x2xf32>
// CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?x16x2xf32>
+// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1_0:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[d0_0:.*]] = tensor.dim {{.*}} %[[c0_0]] : tensor<?x?xf32>
// CHECK-DAG: %[[d1_0:.*]] = tensor.dim {{.*}} %[[c1_0]] : tensor<?x?xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0_0]], %[[d1_0]] : vector<8x16xi1>
-// CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
@@ -604,6 +593,14 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
// CHECK: return %[[masked_write]] : tensor<?x?x16x2xf32>
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
// -----
func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
More information about the Mlir-commits
mailing list