[Mlir-commits] [mlir] [mlir] Add direct vectorization lowering for `tensor.pack` ops (PR #78660)
llvmlistbot at llvm.org
Thu Jan 18 18:17:37 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: None (Max191)
This PR adds a direct vectorization lowering of `tensor.pack` into a `mask(vector.transfer_read)` -> `vector.shape_cast` -> `vector.transpose` -> `vector.transfer_write` sequence.
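For the `tensor.pack` used in the new doc comment below, the emitted IR would look roughly like this hand-written sketch (assuming an `f32` element type and vector sizes `[32, 8, 16]`; `%src`, `%pad`, and `%dest` are placeholder SSA values, not compiler output):

```mlir
// The tensor.pack being vectorized:
%pack = tensor.pack %src padding_value(%pad : f32)
    outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 2]
    into %dest : tensor<32x8x16xf32> -> tensor<32x1x4x16x2xf32>

// After vectorization (sketch):
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%mask = vector.create_mask %c32, %c8, %c16 : vector<32x8x16xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true, true, true]}
      : tensor<32x8x16xf32>, vector<32x8x16xf32>
} : vector<32x8x16xi1> -> vector<32x8x16xf32>
// "Tiled" shape: each source dim is followed by its inner tile
// (8 -> 4x2, 16 -> 1x16).
%cast = vector.shape_cast %read : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// Permute the tiled shape into the packed (dest) layout.
%tr = vector.transpose %cast, [0, 3, 1, 4, 2]
    : vector<32x4x2x1x16xf32> to vector<32x1x4x16x2xf32>
%empty = tensor.empty() : tensor<32x1x4x16x2xf32>
%res = vector.transfer_write %tr, %empty[%c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true, true, true]}
    : vector<32x1x4x16x2xf32>, tensor<32x1x4x16x2xf32>
```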
---
Patch is 34.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78660.diff
9 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+51-1)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+1-29)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (-1)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+147)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+7-68)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+10)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+85)
- (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+61)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 678081837b81382..b4f18d57404cc29 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1052,6 +1052,55 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
}
};
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+ using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::TransposeOp op,
+ PatternRewriter &rewriter) const final {
+ DenseIntElementsAttr perms;
+ if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
+ return rewriter.notifyMatchFailure(op, "unmatched permutation tensor");
+ }
+
+ auto loc = op.getLoc();
+ auto input = op->getOperand(0);
+ auto resultTy = cast<ShapedType>(op.getType());
+
+ SmallVector<Value> dynDims;
+ dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
+
+ SmallVector<AffineExpr, 2> inputExprs;
+ inputExprs.resize(resultTy.getRank());
+ for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
+ auto index = permutation.index();
+ auto value = permutation.value().getZExtValue();
+ if (!resultTy.hasRank() || resultTy.isDynamicDim(index)) {
+ dynDims[index] = rewriter.create<tensor::DimOp>(loc, input, value);
+ }
+ inputExprs[value] = rewriter.getAffineDimExpr(index);
+ }
+
+ SmallVector<Value> filteredDims = condenseValues(dynDims);
+
+ auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+ loc, resultTy.getShape(), resultTy.getElementType(), filteredDims);
+
+ SmallVector<AffineMap, 2> affineMaps = {
+ AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
+ rewriter.getContext()),
+ rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+ rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+ op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps,
+ getNParallelLoopsAttrs(resultTy.getRank()),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
+ });
+ return success();
+ }
+};
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -2408,6 +2457,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
ReverseConverter,
RFFT2dConverter,
TableConverter,
- TileConverter>(patterns->getContext());
+ TileConverter,
+ TransposeConverter>(patterns->getContext());
// clang-format on
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8dc2d27bd545ff8..b3fbc7dd0b22c19 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -19,7 +19,6 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@@ -985,31 +984,6 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
}
};
-class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
-public:
- using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tosa::TransposeOp op,
- PatternRewriter &rewriter) const final {
- SmallVector<int64_t> constantPerms;
- if (failed(op.getConstantPerms(constantPerms)))
- return failure();
-
- Location loc = op.getLoc();
- // The verifier should have made sure we have a valid permutation tensor.
- assert(isPermutationVector(constantPerms) && "Expected valid permutation");
- SmallVector<OpFoldResult> inputSizes =
- tensor::getMixedSizes(rewriter, loc, op.getInput1());
- auto permutedSizes =
- applyPermutation<OpFoldResult>(inputSizes, constantPerms);
-
- auto permutedInit = rewriter.create<tensor::EmptyOp>(
- loc, permutedSizes, op.getInput1().getType().getElementType());
- rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
- op, op.getInput1(), permutedInit, constantPerms);
- return success();
- }
-};
} // namespace
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
@@ -1030,8 +1004,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
MatMulConverter,
MaxPool2dConverter,
AvgPool2dConverter,
- FullyConnectedConverter,
- TransposeConverter
- >(patterns->getContext());
+ FullyConnectedConverter>(patterns->getContext());
// clang-format on
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b9d..5312dc164c26c5e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -60,7 +60,6 @@ struct TosaToLinalgNamed
target.addIllegalOp<tosa::AvgPool2dOp>();
target.addIllegalOp<tosa::MatMulOp>();
target.addIllegalOp<tosa::FullyConnectedOp>();
- target.addIllegalOp<tosa::TransposeOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5254aac976f462d..2e58eb3376a1c8e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3134,7 +3134,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
// TODO: Check that the correct number of vectorSizes was provided.
for (Operation *target : targets) {
- if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+ if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5d99951ef09a92b..b56289b560272d0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,10 +19,14 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -30,7 +34,9 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>
@@ -1393,6 +1399,117 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as follows:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and the
+/// inner_tile dimensions immediately follow their corresponding outer
+/// dimension.
+/// i.e. for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1 : f32)
+/// outer_dims_perm = [0, 2, 1]
+/// inner_dims_pos = [2, 1]
+/// inner_tiles = [16, 2]
+/// into %2 : tensor<32x8x16xf32> -> tensor<32x1x4x16x2xf32>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`
+/// The "tiled" shape is `32x(4x2)x(1x16)`
+static SmallVector<int64_t>
+getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+ auto innerTiles = packOp.getInnerTiles();
+ int64_t srcRank = packOp.getSourceRank();
+ // Copy into SmallVectors up front: assigning a temporary vector to an
+ // ArrayRef would leave the ArrayRef dangling.
+ SmallVector<int64_t> innerDimsPos = llvm::to_vector(packOp.getInnerDimsPos());
+ if (innerDimsPos.empty())
+ innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
+ SmallVector<int64_t> outerDimsPerm = llvm::to_vector(packOp.getOuterDimsPerm());
+ if (outerDimsPerm.empty())
+ outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
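+ // Map an index into the "packed" (dest) shape to its index in the "tiled"
+ // shape: an outer dim maps via outerDimsPerm to its source dim; an inner
+ // tile dim maps to the slot immediately after its tiled outer dim.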
+ auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
+ int64_t srcIdx;
+ if (idx >= srcRank)
+ srcIdx = innerDimsPos[idx - srcRank];
+ else
+ srcIdx = outerDimsPerm[idx];
+ int64_t tiledIdx = srcIdx;
+ for (int64_t pos : innerDimsPos)
+ if (pos < srcIdx)
+ tiledIdx++;
+ if (idx >= srcRank)
+ tiledIdx++;
+ return tiledIdx;
+ };
+ SmallVector<int64_t> perm;
+ for (int64_t i = 0; i < packOp.getDestRank(); i++)
+ perm.push_back(packedIdxToTiledIdx(i));
+ return perm;
+}
+
+/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
+/// above in `getTiledShapeToPackedShapePerm`.
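+/// For the example above, this is `32x4x2x1x16`.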
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+ auto perm = getTiledShapeToPackedShapePerm(packOp);
+ auto destShape = packOp.getDestType().getShape();
+ return applyPermutation(destShape, invertPermutationVector(perm));
+}
+
+/// Vectorize a `tensor.pack` op as a masked `vector.transfer_read` of the
+/// source, a `vector.shape_cast` to the "tiled" shape, a `vector.transpose`
+/// to the "packed" layout, and an in-bounds `vector.transfer_write` into a
+/// new `tensor.empty` dest.
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ Location loc = packOp.getLoc();
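+ // If the pack op has no explicit padding value, pad with zero; the masked
+ // transfer_read below produces this value for masked-off lanes.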
+ auto padValue = packOp.getPaddingValue();
+ if (!padValue) {
+ padValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+ }
+ int64_t inputRank = inputVectorSizes.size();
+ int64_t outputRank = packOp.getDestRank();
+ auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+ auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+ ReifiedRankedShapedTypeDims reifiedReturnShapes;
+ LogicalResult status =
+ cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+ .reifyResultShapes(rewriter, reifiedReturnShapes);
+ (void)status; // prevent unused variable warning on non-assert builds
+ assert(succeeded(status) && "failed to reify result shapes");
+ auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
+ padValue.getType());
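+ // Mask the read with the (possibly dynamic) source sizes so lanes past the
+ // source extent read the padding value instead.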
+ SmallVector<OpFoldResult> mixedSourceDims =
+ tensor::getMixedSizes(rewriter, loc, packOp.getSource());
+ Value mask =
+ rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+ loc,
+ /*vectorType=*/vectorType,
+ /*source=*/packOp.getSource(),
+ /*indices=*/SmallVector<Value>(inputRank, zero),
+ /*padding=*/padValue,
+ /*inBounds=*/SmallVector<bool>(inputRank, true));
+ auto maskedOp = cast<vector::MaskOp>(
+ mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+ // Shape-cast the read vector to the "tiled" shape, then transpose it into
+ // the "packed" (dest) layout.
+ auto tiledPackShape = getTiledPackShape(packOp);
+ auto tiledPackType =
+     VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+ auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+     loc, tiledPackType, maskedOp->getResult(0));
+ auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+ auto transposeOp = rewriter.create<vector::TransposeOp>(
+     loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+ Operation *write = rewriter.create<vector::TransferWriteOp>(
+ loc,
+ /*vector=*/transposeOp->getResult(0),
+ /*source=*/emptyOp,
+ /*indices=*/SmallVector<Value>(outputRank, zero),
+ /*inBounds=*/SmallVector<bool>(outputRank, true));
+ newResults.push_back(write->getResult(0));
+ return success();
+}
+
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1585,6 +1702,30 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success();
}
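+/// A `tensor.pack` op is vectorizable when the padding value (if present) is
+/// a constant, the given vector sizes form a valid mask for the source
+/// shape, and all inner tiles are static.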
+static LogicalResult
+vectorizePackOpPrecondition(tensor::PackOp packOp,
+ ArrayRef<int64_t> inputVectorSizes) {
+ auto padValue = packOp.getPaddingValue();
+ // Bail out when a padding value is present but is not a constant.
+ if (padValue && !getConstantIntValue(padValue).has_value()) {
+ LDBG("pad value is not constant: " << packOp << "\n");
+ return failure();
+ }
+
+ ArrayRef<int64_t> sourceShape = packOp.getSourceType().getShape();
+ if (failed(isValidMaskedInputVector(sourceShape, inputVectorSizes)))
+ return failure();
+
+ if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
+ std::optional<int64_t> res = getConstantIntValue(v);
+ return !res.has_value();
+ })) {
+ LDBG("inner_tiles must be constant: " << packOp << "\n");
+ return failure();
+ }
+
+ return success();
+}
+
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1785,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
})
+ .Case<tensor::PackOp>([&](auto packOp) {
+ return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+ })
.Default([](auto) { return failure(); });
}
@@ -1732,6 +1876,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
results);
})
+ .Case<tensor::PackOp>([&](auto packOp) {
+ return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+                                results);
+ })
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 6616ea7cf699fa5..aa010e759a0f201 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -88,8 +88,7 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
// CHECK-LABEL: @fully_connected
func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSEDINIT:.+]] = tensor.empty() : tensor<3x6xf32>
- // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT]] : tensor<3x6xf32>) permutation = [1, 0]
+ // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
@@ -111,7 +110,7 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
// CHECK-LABEL: @quantized_fully_connected
func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xi8>) permutation = [1, 0]
+ // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
@@ -137,7 +136,7 @@ func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xf32>) permutation = [1, 0]
+ // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
@@ -378,7 +377,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
- // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
// CHECK: arith.extsi
@@ -399,7 +398,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
- // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
+ // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
@@ -678,7 +677,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
// CHECK-LABEL: @conv3d_f32
func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
- // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -702,7 +701,7 @@ func.func @conv3...
[truncated]
``````````
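With the `transform::VectorizeOp` change above, `tensor.pack` ops can be targeted directly from the transform dialect. A minimal sketch of driving this (the surrounding sequence boilerplate and exact op syntax depend on your MLIR revision; the vector sizes are placeholders that must match the pack op's source shape):

```mlir
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
  %pack = transform.structured.match ops{["tensor.pack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
  // One vector size per source dimension of the tensor.pack.
  transform.structured.vectorize %pack vector_sizes [32, 8, 16]
      : !transform.any_op
}
```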
https://github.com/llvm/llvm-project/pull/78660