[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 18 18:24:06 PST 2024
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/78660
>From 55fadd51d1bddb5e28f88b14a657583ed494ed3e Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 5 Jan 2024 13:50:50 -0500
Subject: [PATCH 1/3] [mlir] Add vectorization support for tensor.pack
---
.../TransformOps/LinalgTransformOps.cpp | 2 +-
.../Linalg/Transforms/Vectorization.cpp | 151 ++++++++++++++++++
2 files changed, 152 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 140bdd1f2db3618..6f6abf56acfd96b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3123,7 +3123,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
- if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0610f24ddaf471a..1507eceac8f0b28 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,10 +19,14 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -30,7 +34,9 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>
@@ -1393,6 +1399,121 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as the following:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and the
+/// inner_tile dimensions immediately follow their corresponding outer
+/// dimension.
+/// i.e. for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1)
+///   outer_dims_perm = [0, 2, 1]
+///   inner_dims_pos = [2, 1]
+///   inner_tiles = [16, 2]
+///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`
+/// The "tiled" shape is `32x(4x2)x(1x16)`
+/// The returned permutation is `[0, 3, 1, 4, 2]`: entry `i` is the index,
+/// within the "tiled" shape, of dimension `i` of the "packed" shape.
+static SmallVector<int64_t>
+getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+  int64_t srcRank = packOp.getSourceRank();
+  int64_t destRank = packOp.getDestRank();
+
+  // Keep owning copies of both index arrays. The op accessors hand back
+  // non-owning array views; re-pointing such a view at a temporary vector
+  // (as the identity fallbacks below would) leaves it dangling as soon as
+  // the temporary is destroyed.
+  SmallVector<int64_t> innerDimsPos = to_vector(packOp.getInnerDimsPos());
+  if (innerDimsPos.empty())
+    innerDimsPos = to_vector(llvm::seq<int64_t>(packOp.getInnerTiles().size()));
+  SmallVector<int64_t> outerDimsPerm = to_vector(packOp.getOuterDimsPerm());
+  if (outerDimsPerm.empty())
+    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
+
+  // Map a dimension index of the "packed" shape to its index in the "tiled"
+  // shape.
+  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
+    // Packed dims [0, srcRank) are outer dims, the rest are inner tiles.
+    bool isInnerTile = idx >= srcRank;
+    int64_t srcIdx =
+        isInnerTile ? innerDimsPos[idx - srcRank] : outerDimsPerm[idx];
+    // Every tiled source dim that precedes `srcIdx` contributes one extra
+    // (inner tile) dim in front of it in the "tiled" shape.
+    int64_t tiledIdx = srcIdx;
+    for (int64_t pos : innerDimsPos)
+      if (pos < srcIdx)
+        tiledIdx++;
+    // An inner tile sits immediately after its own outer dim.
+    if (isInnerTile)
+      tiledIdx++;
+    return tiledIdx;
+  };
+
+  SmallVector<int64_t> perm;
+  perm.reserve(destRank);
+  for (int64_t i = 0; i < destRank; ++i)
+    perm.push_back(packedIdxToTiledIdx(i));
+  return perm;
+}
+
+/// Compute the "tiled" form of the pack op's `dest` shape (see
+/// `getTiledShapeToPackedShapePerm` for what "tiled" means): the `dest`
+/// shape permuted by the inverse of the tiled-to-packed permutation.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+  SmallVector<int64_t> packedToTiledPerm =
+      invertPermutationVector(getTiledShapeToPackedShapePerm(packOp));
+  ArrayRef<int64_t> packedShape = packOp.getDestType().getShape();
+  return applyPermutation(packedShape, packedToTiledPerm);
+}
+
+/// Vectorize `packOp` as a masked `vector.transfer_read` of the source,
+/// followed by a `vector.shape_cast` to the "tiled" shape, a
+/// `vector.transpose` to the "packed" (`dest`) shape and a
+/// `vector.transfer_write` into a new `tensor.empty`. The result of the
+/// write is appended to `newResults`.
+/// `inputVectorSizes` is the shape of the vector read from the source.
+/// The padding value's type is queried unconditionally, so `packOp` is
+/// expected to carry a padding value.
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ auto padValue = packOp.getPaddingValue();
+ Location loc = packOp.getLoc();
+ int64_t inputRank = inputVectorSizes.size();
+ int64_t outputRank = packOp.getDestRank();
+ auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+ auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ ReifiedRankedShapedTypeDims reifiedReturnShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedReturnShapes);
+ (void)status; // prevent unused variable warning on non-assert builds
+ assert(succeeded(status) && "failed to reify result shapes");
+ // Destination of the final write: a fresh tensor.empty with the reified
+ // result shape.
+ auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
+ padValue.getType());
+ // Build the read mask from the sizes of the source tensor.
+ SmallVector<OpFoldResult> mixedSourceDims =
+ tensor::getMixedSizes(rewriter, loc, packOp.getSource());
+ Value mask =
+ rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ // Read the source from offset zero, using the padding value as the
+ // transfer_read padding, and wrap the read in the mask.
+ auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/packOp.getSource(),
+ /*indices=*/SmallVector<Value>(inputRank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/SmallVector<bool>(inputRank, true));
+ auto maskedOp = cast<vector::MaskOp>(
+ mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+ // Reshape the read vector to the "tiled" shape, then transpose it into
+ // the "packed" shape.
+ auto tiledPackShape = getTiledPackShape(packOp);
+ auto tiledPackType = VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+ auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult(0));
+ auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+ auto transposeOp = rewriter.create<vector::TransposeOp>(loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+ // Write the packed vector into the empty tensor at offset zero.
+ Operation *write = rewriter.create<vector::TransferWriteOp>(
+ loc,
+ /*vector=*/transposeOp->getResult(0),
+ /*source=*/emptyOp,
+ /*indices=*/SmallVector<Value>(outputRank, zero),
+ /*inBounds=*/SmallVector<bool>(outputRank, true));
+ // bool needMaskForWrite = llvm::any_of(
+ // llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
+ // [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+ // if (needMaskForWrite) {
+ // Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
+ // loc, maskType, reifiedReturnShapes[0]);
+ // write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
+ // }
+ newResults.push_back(write->getResult(0));
+ return success();
+}
+
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1585,6 +1706,30 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}
+/// Precondition for vectorizing `packOp` with `inputVectorSizes`. Fails when
+/// the op has no padding value, when `isValidMaskedInputVector` rejects
+/// `inputVectorSizes` for the shape of the pack source, or when any inner
+/// tile size is not a constant integer.
+/// NOTE(review): the first check fires when the padding value is absent,
+/// while its debug message says "not constant" - confirm which is intended.
+static LogicalResult
+vectorizePackOpPrecondition(tensor::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto padValue = packOp.getPaddingValue();
+ if (!padValue) {
+ LDBG("pad value is not constant: " << packOp << "\n");
+ return failure();
+ }
+
+ // Despite the variable name, this is the shape of the pack *source*; the
+ // vector sizes are validated against it.
+ ArrayRef<int64_t> resultTensorShape = packOp.getSourceType().getShape();
+ if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
+ return failure();
+
+ // Reject the op if any inner tile size has no constant integer value.
+ if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
+ std::optional<int64_t> res = getConstantIntValue(v);
+ return !res.has_value();
+ })) {
+ LDBG("inner_tiles must be constant: " << packOp << "\n");
+ return failure();
+ }
+
+ return success();
+}
+
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1789,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
+ .Case<tensor::PackOp>([&](auto packOp) {
+ return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+ })
.Default([](auto) { return failure(); });
}
@@ -1732,6 +1880,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
results);
})
+ .Case<tensor::PackOp>([&](auto packOp) {
+ return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results);
+ })
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
>From 1ea5d0e88b04233695049d7492dafe2ef8a3cba2 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 20:12:13 -0500
Subject: [PATCH 2/3] Support pack with no padding value
---
.../Linalg/Transforms/Vectorization.cpp | 22 ++++++++-----------
1 file changed, 9 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1507eceac8f0b28..cc3ee8938c9a9f5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1458,16 +1458,20 @@ static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
- auto padValue = packOp.getPaddingValue();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
Location loc = packOp.getLoc();
+ auto padValue = packOp.getPaddingValue();
+ if (!padValue) {
+ padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+ }
int64_t inputRank = inputVectorSizes.size();
int64_t outputRank = packOp.getDestRank();
auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(packOp);
-
ReifiedRankedShapedTypeDims reifiedReturnShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
@@ -1502,14 +1506,6 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
/*source=*/emptyOp,
/*indices=*/SmallVector<Value>(outputRank, zero),
/*inBounds=*/SmallVector<bool>(outputRank, true));
- // bool needMaskForWrite = llvm::any_of(
- // llvm::zip_equal(inputVectorSizes, packOp.getResultType().getShape()),
- // [](auto it) { return std::get<0>(it) != std::get<1>(it); });
- // if (needMaskForWrite) {
- // Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
- // loc, maskType, reifiedReturnShapes[0]);
- // write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
- // }
newResults.push_back(write->getResult(0));
return success();
}
@@ -1710,7 +1706,7 @@ static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
- if (!padValue) {
+ if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
LDBG("pad value is not constant: " << packOp << "\n");
return failure();
}
>From 06f86da2219c52f5222e11553795c44ad057117a Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 18 Jan 2024 21:11:49 -0500
Subject: [PATCH 3/3] add tests
---
mlir/test/Dialect/Linalg/vectorization.mlir | 61 +++++++++++++++++++++
1 file changed, 61 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d5fb0cbb9c723b0..af1c1337224fa2c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -501,6 +501,67 @@ module attributes {transform.with_named_sequence} {
// -----
+// Pack of a dynamically shaped source (no padding value, no outer_dims_perm)
+// vectorized with vector sizes [8, 16]. Expects a read masked by the two
+// runtime source dims, then shape_cast, transpose and an in-bounds write.
+func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<4x1x16x2xf32>) -> tensor<4x1x16x2xf32> {
+ %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<4x1x16x2xf32>
+ return %pack : tensor<4x1x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<4x1x16x2xf32>
+// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<8x16xi1>
+// CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_2]], %[[c0_2]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+// CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<4x1x16x2xf32>
+
+// -----
+
+// Pack of a statically shaped 32x8x16 source (no padding value, no
+// outer_dims_perm) vectorized with vector sizes equal to the source shape.
+// Expects a masked read, shape_cast, transpose and an in-bounds write.
+func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+ %pack = tensor.pack %arg0 inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+ return %pack : tensor<32x4x1x16x2xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [32, 8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+// CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c8]], %[[c16]] : vector<32x8x16xi1>
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+// CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+// CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]]
+// CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+// CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+// -----
+
func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?xf32>)
More information about the Mlir-commits
mailing list