[Mlir-commits] [mlir] [mlir][tosa] Change 'shape' of RESHAPE from attribute to input shape type (PR #125789)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 4 16:19:37 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: TatWai Chong (tatwaichong)
<details>
<summary>Changes</summary>
The `shape` operand of RESHAPE is changed from an attribute (`new_shape`) to an input of shape type, as required since TOSA v1.0.
Change-Id: I508cc1d67e9b017048b3f29fecf202cb7d707110
---
Patch is 115.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125789.diff
25 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
- (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+3)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+3-2)
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+7-1)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+6-2)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+23-9)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+17-20)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+7-5)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+17-6)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+9-2)
- (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+12-7)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+2-1)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-5)
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+72-38)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+37-20)
- (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+2-1)
- (modified) mlir/test/Dialect/Tosa/inlining.mlir (+2-1)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+33-12)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+2-1)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16-8)
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+25-12)
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+29-15)
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+42-26)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+27-14)
- (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+44-33)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8ede271cc56a8a..869ab913a715ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1621,7 +1621,7 @@ def Tosa_ReshapeOp : Tosa_InferTensorTypeOp<"reshape"> {
let arguments = (ins
Tosa_Tensor:$input1,
- DenseI64ArrayAttr:$new_shape
+ Tosa_Shape:$shape
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 78a8828855437e..88c21629286525 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -230,8 +230,11 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
}
// Computes shape value using tosa const_shape op.
+Value getTosaConstShape(ImplicitLocOpBuilder &builder,
+ llvm::ArrayRef<int64_t> shape);
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape);
+
SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
bool getConstShapeValue(Operation *op,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6cbc30b6..edb04010d53fd9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1952,9 +1952,10 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
});
+ auto shapeValue = getTosaConstShape(
+ rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultTy, genericOp.getResult(0),
- rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
+ op, resultTy, genericOp.getResult(0), shapeValue);
return success();
}
};
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d5c865b0..fdb8b1e1471a73 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -235,7 +236,12 @@ class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
return rewriter.notifyMatchFailure(reshape.getLoc(),
"expected input type to be tensor");
}
- auto newShape = reshape.getNewShape();
+
+ llvm::SmallVector<int64_t> newShape;
+ if (!tosa::getConstShapeValue(reshape.getShape().getDefiningOp(),
+ newShape)) {
+ return failure();
+ }
// Infer all intermediate types
auto inputType = inferReshapeInputType(input, newShape);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947b4352bb..229719f5ef84d4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -180,7 +180,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, op.getType(), op.getInput1(),
- rewriter.getDenseI64ArrayAttr(newShape));
+ getTosaConstShape(rewriter, op.getLoc(), newShape));
return success();
}
};
@@ -948,8 +948,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!getInput1().hasOneUse())
return {};
+ llvm::SmallVector<int64_t> shapeVec;
+ if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeVec))
+ return {};
+
return operand.reshape(
- llvm::cast<ShapedType>(operand.getType()).clone(getNewShape()));
+ llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
}
return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b28906135edf..f88c6df8e2b458 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1309,8 +1309,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
- llvm::SmallVector<int64_t> newShapeValue =
- convertToMlirShape(adaptor.getNewShape());
+ llvm::SmallVector<int64_t> newShapeValue;
+ if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
+ newShapeValue)) {
+ auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
+ SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+ inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+ return success();
+ } else {
+ newShapeValue = convertToMlirShape(newShapeValue);
+ }
// We cannot infer from the total number of elements so we must take the
// shape attribute as exact.
@@ -1346,13 +1354,19 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();
- if ((int64_t)getNewShape().size() != outputType.getRank())
+ SmallVector<int64_t> shapeValues;
+ if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
+ // skip following checks if shape is not constant
+ return mlir::success();
+ }
+
+ if ((int64_t)shapeValues.size() != outputType.getRank())
return emitOpError() << "new shape does not match result rank";
for (auto [newShapeDim, outputShapeDim] :
- zip(getNewShape(), outputType.getShape())) {
- if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
- newShapeDim != outputShapeDim)
+ zip(shapeValues, outputType.getShape())) {
+ if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
+ outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
return emitOpError() << "new shape is inconsistent with result shape";
if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
@@ -1371,10 +1385,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}
int64_t newShapeElementsNum = std::accumulate(
- getNewShape().begin(), getNewShape().end(), 1LL,
+ shapeValues.begin(), shapeValues.end(), 1LL,
[](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
bool isStaticNewShape =
- llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
+ llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
(!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
return emitOpError() << "cannot reshape " << inputElementsNum
@@ -1382,7 +1396,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
}
}
- int missingDims = llvm::count(getNewShape(), -1);
+ int missingDims = llvm::count(shapeValues, -1);
if (missingDims > 1)
return emitOpError() << "expected at most one target dimension to be -1";
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae3330afe..04e8ad31cf2e2e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -20,12 +20,6 @@ using namespace mlir::tosa;
namespace {
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
- return to_vector(llvm::map_range(shape, [](int64_t dim) {
- return ShapedType::isDynamic(dim) ? -1 : dim;
- }));
-}
-
struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
explicit Conv2DIsFullyConnected(MLIRContext *context)
: OpRewritePattern(context) {}
@@ -98,12 +92,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
auto revisedInputShapeType =
RankedTensorType::get(revisedInputShape, inputType.getElementType());
- auto reshapedInput = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedInputShapeType, input,
- rewriter.getDenseI64ArrayAttr(
- convertFromMlirShape(revisedInputShape)))
- .getResult();
+ auto revisedInputShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(revisedInputShape));
+ auto reshapedInput =
+ rewriter
+ .create<tosa::ReshapeOp>(op.getLoc(), revisedInputShapeType, input,
+ revisedInputShapeValue)
+ .getResult();
// Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
@@ -111,12 +106,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
dyn_cast<RankedTensorType>(weight.getType()).getElementType());
- auto reshapedWeight = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedWeightShapeType, weight,
- rewriter.getDenseI64ArrayAttr(
- convertFromMlirShape(revisedWeightShape)))
- .getResult();
+ auto revisedWeightShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(revisedWeightShape));
+ auto reshapedWeight =
+ rewriter
+ .create<tosa::ReshapeOp>(op.getLoc(), revisedWeightShapeType,
+ weight, revisedWeightShapeValue)
+ .getResult();
// Perform a fully connected network over the reshaped input and weight.
llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
@@ -149,9 +145,10 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
// Reshape output to [N, IH, IW, OC].
llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
inputShape[2], weightShape[0]};
+ auto outputShapeValue = getTosaConstShape(
+ rewriter, op.getLoc(), convertFromMlirShape(outputShape));
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultType, fullyConnectedValue,
- rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
+ op, resultType, fullyConnectedValue, outputShapeValue);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index ee857f1998a54d..b26397d0e3ed7a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -55,10 +55,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
inputType = RankedTensorType::get(
revisedInputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
+ auto revisedInputShapeValue =
+ getTosaConstShape(rewriter, op.getLoc(), revisedInputShape);
input = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), inputType, input,
- rewriter.getDenseI64ArrayAttr(revisedInputShape))
+ .create<tosa::ReshapeOp>(op.getLoc(), inputType, input,
+ revisedInputShapeValue)
.getResult();
Type inputETy = inputType.getElementType();
@@ -153,9 +154,10 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
auto outputShapeType = RankedTensorType::get(
outputShape,
dyn_cast<RankedTensorType>(input.getType()).getElementType());
+ auto outputShapeValue =
+ getTosaConstShape(rewriter, op->getLoc(), outputShape);
Value outputValue = rewriter.create<tosa::ReshapeOp>(
- op.getLoc(), outputShapeType, mulValue,
- rewriter.getDenseI64ArrayAttr(outputShape));
+ op.getLoc(), outputShapeType, mulValue, outputShapeValue);
Value bias = op.getBias();
if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae224671e304f2..69a66c98307e94 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -160,9 +160,11 @@ class TransposeConvStridedConverter
outputChannels, weightHeight / stride[0],
stride[0], weightWidth / stride[1],
stride[1], inputChannels};
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
- rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
+ builder, UnrankedTensorType::get(weightETy), weight,
+ getTosaConstShape(rewriter, loc, weightReshapeDims0));
// Transpose the factored-out stride to the output channels.
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
@@ -174,12 +176,13 @@ class TransposeConvStridedConverter
transposeWeightVal);
// Collapse the strides and output channels into a single dimension.
- llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
+ llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
weightWidth / stride[1], inputChannels};
+
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
+ getTosaConstShape(rewriter, loc, weightReshapeDims1));
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
weight = CreateOpAndInferShape<tosa::ReverseOp>(
@@ -258,9 +261,13 @@ class TransposeConvStridedConverter
// Factor striding out of the convolution result.
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
+
+ auto convReshapeDims0Value =
+ getTosaConstShape(rewriter, loc, convReshapeDims0);
+
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
- rewriter.getDenseI64ArrayAttr(convReshapeDims0));
+ convReshapeDims0Value);
// Transpose the factored-out stride to the output channels.
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
@@ -274,9 +281,13 @@ class TransposeConvStridedConverter
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
+
+ auto convReshapeDims1Value =
+ getTosaConstShape(rewriter, loc, convReshapeDims1);
+
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
- rewriter.getDenseI64ArrayAttr(convReshapeDims1));
+ convReshapeDims1Value);
// Determine the amount to slice / pad from the result start.
int64_t resultSliceTop = std::max<int64_t>(0, -pad[0]);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 520f283a3ba888..281f0529a5c081 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -402,13 +402,20 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
return std::nullopt;
// Do not insert a TransposeOp, instead we fold the reshape and its attribute.
+ llvm::SmallVector<int64_t> newShape;
+ if (!tosa::getConstShapeValue(reshapeOp.getShape().getDefiningOp(),
+ newShape)) {
+ // This means the shape is not constant.
+ return std::nullopt;
+ }
+ ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
auto foldedReshape = rewriter.create<ReshapeOp>(
reshapeOp.getLoc(),
RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
reshapeOutputType.getElementType()),
reshapeOp.getInput1(),
- rewriter.getDenseI64ArrayAttr(
- applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
+ getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
+ hoistedPerms)));
return foldedReshape->getResult(0);
}
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 62b0bc1857e395..8ab12d038849f4 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -145,10 +145,10 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
llvm::cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
+ auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape);
auto reshapeLower = builder.create<tosa::ReshapeOp>(
- reshapeOutputType, lowerTensorValue,
- builder.getDenseI64ArrayAttr(reshapeOutputShape));
+ reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue);
if (input1Rank > input2Rank) {
input1 = higherTensorValue;
@@ -161,15 +161,20 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
return success();
}
-Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
llvm::ArrayRef<int64_t> shape) {
- auto attr = rewriter.getIndexTensorAttr(shape);
- auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
- mlir::Operation *mlir_op =
- rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+ auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
+ auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
+ mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
return mlir_op->getResult(0);
}
+Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
+ llvm::ArrayRef<int64_t> shape) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ return getTosaConstShape(builder, shape);
+}
+
SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
return to_vector(llvm::map_range(shape, [](int64_t dim) {
return ShapedType::isDynamic(dim) ? -1 : dim;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 75b48f2b06d899..460e207d62de6a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -24,7 +24,8 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
%reduce = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<10x10xf32>) -> tensor<10x1xf32>
%1 = tosa.add %reduce, %arg1 : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%0 = tosa.add %1, %arg2 : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
- %2 = tosa.reshape %0 {new_shape = array<i64: 10, 10>} : (tensor<...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/125789
More information about the Mlir-commits mailing list