[Mlir-commits] [mlir] [mlir][tosa] Change Transpose perms operand to attribute (PR #128115)
llvmlistbot at llvm.org
Thu Feb 20 18:59:36 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: Tai Ly (Tai78641)
This patch changes the `perms` operand of the TOSA `Transpose` operator to an i32 array attribute (`DenseI32ArrayAttr`).
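For illustration, a hedged before/after sketch of the change at the IR level (the shapes and permutation values here are made up for the example, not taken from the patch's tests):

```mlir
// Before: the permutation was a separate constant tensor operand.
%perms = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32>
%0 = tosa.transpose %arg0, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>

// After: the permutation is an inline DenseI32ArrayAttr on the op itself.
%1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
```

This removes the need for consumers to prove the operand is constant before reading it.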
---
Patch is 109.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128115.diff
20 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-5)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+5-11)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+7-24)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+48-93)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+2-10)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp (+1-6)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+5-10)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (-10)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+4-12)
- (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+2-4)
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+1-2)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+16-9)
- (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+18-28)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+21-47)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+1-2)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+3-6)
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+6-12)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+2-21)
- (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+91-171)
- (modified) mlir/test/Dialect/Tosa/transpose-fold.mlir (+9-18)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3de1c21f40b43..a06e03f831985 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2011,7 +2011,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
let arguments = (ins
Tosa_Tensor:$input1,
- Tosa_Int32Tensor:$perms
+ DenseI32ArrayAttr:$perms
);
let results = (
@@ -2023,10 +2023,6 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
- let extraClassDeclaration = [{
- LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
- }];
-
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..42e88ee9026ac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -329,13 +329,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
- Value weightPermValue =
- rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermValue);
+ weightPermAttr);
}
}
@@ -353,13 +351,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
- Value weightPermValue =
- rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermValue);
+ weightPermAttr);
}
// Extract the attributes for convolution.
@@ -970,9 +966,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
- SmallVector<int32_t> constantPerms;
- if (failed(op.getConstantPerms(constantPerms)))
- return failure();
+ const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
Location loc = op.getLoc();
// The verifier should have made sure we have a valid TOSA permutation
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..8e2d8662ece8d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,13 +88,10 @@ struct ConsolidateTransposeOptimization
return rewriter.notifyMatchFailure(transposeOp,
"input must be transpose operation");
- SmallVector<int32_t> transposePerms, innerTransposePerms;
- if (transposeOp.getConstantPerms(transposePerms).failed())
- return rewriter.notifyMatchFailure(transposeOp,
- "transpose perms must be constant");
- if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
- return rewriter.notifyMatchFailure(
- transposeOp, "inner transpose perms must be constant");
+ const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
+ const llvm::ArrayRef<int32_t> innerTransposePerms =
+ innerTranspose.getPerms();
+
if (transposePerms.size() != innerTransposePerms.size())
return rewriter.notifyMatchFailure(
transposeOp,
@@ -108,15 +105,9 @@ struct ConsolidateTransposeOptimization
for (int i = 0, s = transposePerms.size(); i < s; ++i)
perms[i] = innerTransposePerms[transposePerms[i]];
- auto permsTy =
- RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
- auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
- Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
- permsTy, permsAttr);
-
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
- innerTranspose.getInput1(), permsValue);
+ innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
return success();
}
@@ -128,10 +119,6 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
- DenseIntElementsAttr permAttr;
- if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
- return rewriter.notifyMatchFailure(op, "Non-constant permutation");
-
if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
return rewriter.notifyMatchFailure(
op, "Src is from transpose, can compose transposes");
@@ -156,9 +143,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
if (numDynDims > 1)
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
- SmallVector<int64_t> permValues = llvm::to_vector<6>(
- llvm::map_range(permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ const llvm::ArrayRef<int32_t> permValues = op.getPerms();
SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
@@ -1175,9 +1160,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
}
// Transpose is not the identity transpose.
- SmallVector<int32_t> perms;
- if (getConstantPerms(perms).failed())
- return {};
+ const llvm::ArrayRef<int32_t> perms = getPerms();
if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e9c33e1b1bf10..7030dccd693a4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1374,41 +1374,22 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
-LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
- // Perms must be constants.
- DenseIntElementsAttr permsAttr;
- if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
- return failure();
-
- perms.clear();
- for (auto v : permsAttr.getValues<APInt>())
- perms.push_back(v.getSExtValue());
-
- return success();
-}
-
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
- ShapeAdaptor permsShape(adaptor.getPerms().getType());
-
- // We cannot infer anything from a rank-0 "permutation" tensor.
- if (permsShape.hasRank() && permsShape.getRank() == 0)
- return failure();
// If input rank and permutation length is unknown, the output rank is
// unknown.
- if (!inputShape.hasRank() || !permsShape.hasRank() ||
- permsShape.isDynamicDim(0)) {
+ if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
return success();
}
// This would imply the number of permutations does not match the rank of
// the input which is illegal.
- if (permsShape.getDimSize(0) != inputShape.getRank()) {
+ if (adaptor.getPerms().size() != static_cast<size_t>(inputShape.getRank())) {
return failure();
}
@@ -1437,28 +1418,16 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
- // If the permuations are a constant we can directly determine the output
- // shape.
- DenseIntElementsAttr attr;
- if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
- attr.getType().getRank() == 1) {
- ShapeAdaptor permShape = attr;
- // Constant permutation must be the same length as the input rank.
- if (inputShape.getRank() != permShape.getRank())
- return emitOptionalError(location,
- "constant permutation must be the same length"
- " as the input rank");
-
- // Constant permutation values must be within the input rank.
- for (int i = 0, e = inputShape.getRank(); i < e; i++) {
- if (inputShape.getRank() <= permShape.getDimSize(i))
- return failure();
- }
- outputShape.reserve(inputShape.getRank());
- for (int i = 0, s = inputShape.getRank(); i < s; i++) {
- outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
- }
+ // Constant permutation values must be within the input rank.
+ for (auto i : adaptor.getPerms()) {
+ if (inputShape.getRank() <= i)
+ return failure();
+ }
+
+ outputShape.reserve(inputShape.getRank());
+ for (int i = 0, s = inputShape.getRank(); i < s; i++) {
+ outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
}
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -1467,75 +1436,61 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
LogicalResult tosa::TransposeOp::verify() {
TensorType inputType = getInput1().getType();
- TensorType permType = getPerms().getType();
TensorType outputType = getOutput().getType();
+ const llvm::ArrayRef<int32_t> constantPerms = getPerms();
- if (permType.hasRank() && permType.getRank() != 1)
- return emitOpError()
- << "expected permutation tensor to be rank 1 but got rank "
- << permType.getRank();
- if (inputType.hasRank() && permType.hasRank())
- if (!permType.isDynamicDim(0) &&
- permType.getDimSize(0) != inputType.getRank())
- return emitOpError() << "expected permutation tensor dim 0 to have size "
+ if (inputType.hasRank())
+ if (constantPerms.size() != static_cast<size_t>(inputType.getRank()))
+ return emitOpError() << "expected perms attribute to have size "
<< inputType.getRank()
<< " (input rank) but got size "
- << permType.getDimSize(0);
+ << constantPerms.size();
if (inputType.hasRank() && outputType.hasRank() &&
inputType.getRank() != outputType.getRank())
return emitOpError()
<< "expected input tensor rank to equal result tensor rank";
- if (outputType.hasRank() && permType.hasRank())
- if (!permType.isDynamicDim(0) &&
- permType.getDimSize(0) != outputType.getRank())
- return emitOpError() << "expected permutation tensor dim 0 to have size "
+ if (outputType.hasRank())
+ if (constantPerms.size() != static_cast<size_t>(outputType.getRank()))
+ return emitOpError() << "expected perms attribute to have size "
<< outputType.getRank()
<< " (output rank) but got size "
- << permType.getDimSize(0);
-
- SmallVector<int32_t> constantPerms;
- if (succeeded(getConstantPerms(constantPerms))) {
- // Assert that the permutation tensor has a rank, which means that the
- // rank has been verified above.
- assert(permType.hasRank() &&
- "Unexpectedly found permutation tensor without rank");
- if (!llvm::all_of(constantPerms,
- [&constantPerms](int32_t s) {
- return s >= 0 &&
- static_cast<size_t>(s) < constantPerms.size();
- }) ||
- !isPermutationVector(llvm::to_vector(llvm::map_range(
- constantPerms, [](int32_t v) -> int64_t { return v; }))))
- return emitOpError() << "expected valid permutation tensor";
-
- // Verify that the types of the input and output tensors are properly
- // permuted.
- if (inputType.hasRank() && outputType.hasRank()) {
- assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
- inputType.getRank() == outputType.getRank());
-
- for (auto i = 0; i < outputType.getRank(); i++) {
- if (inputType.isDynamicDim(constantPerms[i]) ||
- outputType.isDynamicDim(i))
- continue;
-
- if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
- return emitOpError()
- << "expected output tensor dim " << i << " to match "
- << "input dim " << constantPerms[i] << " with value of "
- << inputType.getDimSize(constantPerms[i]);
- }
+ << constantPerms.size();
+
+ if (!llvm::all_of(constantPerms,
+ [&constantPerms](int32_t s) {
+ return s >= 0 &&
+ static_cast<size_t>(s) < constantPerms.size();
+ }) ||
+ !isPermutationVector(llvm::to_vector(llvm::map_range(
+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
+ return emitOpError() << "expected valid permutation indices";
+
+ // Verify that the types of the input and output tensors are properly
+ // permuted.
+ if (inputType.hasRank() && outputType.hasRank()) {
+ assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
+ inputType.getRank() == outputType.getRank());
+
+ for (auto i = 0; i < outputType.getRank(); i++) {
+ if (inputType.isDynamicDim(constantPerms[i]) ||
+ outputType.isDynamicDim(i))
+ continue;
+
+ if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+ return emitOpError()
+ << "expected output tensor dim " << i << " to match "
+ << "input dim " << constantPerms[i] << " with value of "
+ << inputType.getDimSize(constantPerms[i]);
}
}
+
return success();
}
LogicalResult TransposeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
- SmallVector<int32_t> transposePerms;
- if (getConstantPerms(transposePerms).failed())
- return failure();
+ const llvm::ArrayRef<int32_t> transposePerms = getPerms();
Value input = getInput1();
auto inputType = cast<TensorType>(input.getType());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 26baddcf1dd15..61011b6df4617 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -166,13 +166,9 @@ class TransposeConvStridedConverter
getTosaConstShape(rewriter, loc, weightReshapeDims0));
// Transpose the factored-out stride to the output channels.
- Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
- loc, RankedTensorType::get({6}, rewriter.getI32Type()),
- rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
-
weight = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- transposeWeightVal);
+ rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
// Collapse the strides and output channels into a single dimension.
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ class TransposeConvStridedConverter
convReshapeDims0Value);
// Transpose the factored-out stride to the output channels.
- Value transposeConvVal = rewriter.create<tosa::ConstOp>(
- loc, RankedTensorType::get({6}, rewriter.getI32Type()),
- rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
-
conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
- transposeConvVal);
+ rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 403ac48b91559..43e9507b4d95a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
return failure();
- DenseIntElementsAttr permAttr;
- if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
- return failure();
auto permValues = llvm::map_to_vector(
- // TOSA allows both 32- and 64-bit integer tensors here.
- permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); });
+ op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
auto inputType = cast<ShapedType>(op.getInput1().getType());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 64e5c31793f84..d4d8aae8b0316 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -367,9 +367,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
- SmallVector<int32_t> perms;
- if (failed(transposeOp.getConstantPerms(perms)) ||
- !areInvolutionTransposes(hoistedPerms, perms))
+ if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
return std::nullopt;
return transposeOp.getInput1();
}
@@ -506,14 +504,11 @@ bool TosaReduceTransposes::dependenciesAreValid(
// replaced.
Operation *user = use.getOwner();
if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
- SmallVector<int32_t> otherPerms;
-
// Can later think about cases where transpose -> transpose
// or reshape -> transpose, where the transposes are not necessarily
// the same perms as the hoisted, if implementing a more general
// transform. These could be permitted.
- if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
- !llvm::equal(perms, otherPerms))
+ if (!llvm::equal(perms, otherTranspose.getPerms()))
return false;
} else if (userNotContainedInValidTransposeDependencies(
user, validTransposes, transposeInfo)) {
@@ -607,9 +602,9 @@ void TosaReduceTransposes::runOnOperation() {
!llvm::isa<RankedTensorType>(output.getType()))
return;
- // No transformation when transpose permutation non-constant.
- if (failed(transposeOp.getConstantPerms(perms)))
- return;
+ for (int32_t v : transposeOp.getPerms()) {
+ perms.push_back(v);
+ }
// We let --canonicalize deal with identity transpose.
if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..f2abb29b4fe66 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -56,15 +56,6 @@ static LogicalResult checkConstantOperandPad(Operation *op) {
return success();
}
-static LogicalResult checkConstantOpe...
[truncated]
``````````
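As a minimal builder-side sketch of the new API, assuming the usual rewriter context; `resultTy` and `input` are placeholders supplied by the caller, and the calls mirror those in the diff above:

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"

// Create a tosa.transpose with the permutation inlined as an attribute,
// then read it back directly; no tosa.const operand or matchPattern needed.
static mlir::Value makeTranspose(mlir::PatternRewriter &rewriter,
                                 mlir::Location loc, mlir::Type resultTy,
                                 mlir::Value input) {
  auto permsAttr = rewriter.getDenseI32ArrayAttr({1, 0});
  auto op = rewriter.create<mlir::tosa::TransposeOp>(loc, resultTy, input,
                                                     permsAttr);
  llvm::ArrayRef<int32_t> perms = op.getPerms(); // direct, infallible access
  (void)perms;
  return op.getResult();
}
```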
https://github.com/llvm/llvm-project/pull/128115