[Mlir-commits] [mlir] [mlir][tosa] Change the start and size of slice to tosa shape type (PR #124209)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 23 15:31:19 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Jerry-Ge (Jerry-Ge)
<details>
<summary>Changes</summary>
Update to use getConstShapeValue to collect shape information along the graph.
Change-Id: Ic6fc2341e3bcfbec06a1d08986e26dd08573bd9c
---
Patch is 38.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124209.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-2)
- (modified) mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp (+30-4)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+31-12)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+18-4)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+2-2)
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir (+4-2)
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+7-3)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+48-23)
- (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+7-2)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+6-2)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+8-5)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+13-2)
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+12-6)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+30-10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2953e006bbe8d1..d2ac206263bb13 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1664,8 +1664,8 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
let arguments = (ins
Tosa_Tensor:$input1,
- DenseI64ArrayAttr:$start,
- DenseI64ArrayAttr:$size
+ Tosa_Shape:$start,
+ Tosa_Shape:$size
);
let results = (outs
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5aa0269a675cbe..d6319c1c3fd891 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -268,12 +268,28 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
ShapedType resultType = cast<ShapedType>(sliceOp.getType());
if (llvm::isa<UnrankedTensorType>(resultType))
return failure();
+
+ ElementsAttr StartElems;
+ ElementsAttr SizeElems;
+
+ if (!matchPattern(sliceOp.getStart(), m_Constant(&StartElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "start of slice must be a static ranked shape");
+
+ if (!matchPattern(sliceOp.getSize(), m_Constant(&SizeElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "size of slice must be a static ranked shape");
+
+ llvm::SmallVector<int64_t> SliceStarts =
+ llvm::to_vector(StartElems.getValues<int64_t>());
+ llvm::SmallVector<int64_t> SliceSizes =
+ llvm::to_vector(SizeElems.getValues<int64_t>());
+
SmallVector<int64_t> strides, sizes;
- ArrayRef<int64_t> starts = sliceOp.getStart();
strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
SmallVector<Value> dynSizes;
- for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
+ for (const auto &i : llvm::enumerate(SliceSizes)) {
int64_t size = i.value();
size_t index = i.index();
sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
@@ -282,17 +298,27 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
auto offset = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(starts[index]));
+ loc, rewriter.getIndexAttr(SliceStarts[index]));
dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
}
auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
- ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
+ ValueRange({}), rewriter.getDenseI64ArrayAttr(SliceStarts),
rewriter.getDenseI64ArrayAttr(sizes),
rewriter.getDenseI64ArrayAttr(strides));
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
+
+ auto removeIfRedundant = [&](Operation *op) {
+ if (op->getResults().size() == 1 && op->getResult(0).hasOneUse())
+ rewriter.eraseOp(op);
+ };
+
+ // Remove const_shape ops when it no longer has use point.
+ removeIfRedundant(sliceOp.getStart().getDefiningOp());
+ removeIfRedundant(sliceOp.getSize().getDefiningOp());
+
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f7a596f1ccb192..7f3b929d276024 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -393,8 +393,21 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
sliceOp, "slice input must be a static ranked tensor");
int32_t axis = concatOp.getAxis();
- llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
- llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
+ DenseElementsAttr StartElems;
+ DenseElementsAttr SizeElems;
+
+ if (!matchPattern(sliceOp.getStart(), m_Constant(&StartElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "start of slice must be a static ranked shape");
+
+ if (!matchPattern(sliceOp.getSize(), m_Constant(&SizeElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "size of slice must be a static ranked shape");
+
+ llvm::SmallVector<int64_t> SliceStarts =
+ llvm::to_vector(StartElems.getValues<int64_t>());
+ llvm::SmallVector<int64_t> SliceSizes =
+ llvm::to_vector(SizeElems.getValues<int64_t>());
// Validate slice on the concatenated axis. Slicing along this
// axis should span only one of the inputs to the concatenate
@@ -406,17 +419,19 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
return rewriter.notifyMatchFailure(
sliceOp, "concat input must be a static ranked tensor");
- if (sliceStart[axis] >= 0 &&
- (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
- replaceWithSlice = rewriter
- .create<tosa::SliceOp>(
- sliceOp.getLoc(), sliceOp.getType(), input,
- rewriter.getDenseI64ArrayAttr(sliceStart),
- rewriter.getDenseI64ArrayAttr(sliceSize))
- .getResult();
+ if (SliceStarts[axis] >= 0 &&
+ (SliceStarts[axis] + SliceSizes[axis]) <= inputType.getDimSize(axis)) {
+ auto start_op =
+ getTosaConstShape(rewriter, sliceOp.getLoc(), SliceStarts);
+ auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), SliceSizes);
+ replaceWithSlice =
+ rewriter
+ .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(),
+ input, start_op, size_op)
+ .getResult();
break;
}
- sliceStart[axis] -= inputType.getDimSize(axis);
+ SliceStarts[axis] -= inputType.getDimSize(axis);
}
if (!replaceWithSlice)
@@ -963,7 +978,11 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
outputTy.getNumElements() == 1) {
- llvm::SmallVector<uint64_t> indices(getStart());
+ DenseElementsAttr StartElems;
+ if (!matchPattern(getStart(), m_Constant(&StartElems)))
+ return {};
+
+ llvm::SmallVector<uint64_t> indices = llvm::to_vector(StartElems.getValues<uint64_t>());
auto value = operand.getValues<Attribute>()[indices];
return SplatElementsAttr::get(outputTy, value);
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fdccce60fe1d86..2f7bd75974b79b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -891,8 +891,18 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- auto start = adaptor.getStart();
- auto size = adaptor.getSize();
+
+ Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+ SmallVector<int64_t> start;
+ SmallVector<int64_t> size;
+
+ if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
+ !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
+ auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
+ SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+ inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+ return success();
+ }
// if size[i] is -1, all remaining elements in dimension i are included
// in the slice, similar to TF.
@@ -933,11 +943,15 @@ LogicalResult tosa::SliceOp::verify() {
if (!inputType)
return success();
- if (static_cast<size_t>(inputType.getRank()) != getStart().size())
+ auto StartShapeRank =
+ llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
+ if (inputType.getRank() != StartShapeRank)
return emitOpError(
"length of start attribute is not equal rank of input shape");
- if (static_cast<size_t>(inputType.getRank()) != getSize().size())
+ auto SizeShapeRank =
+ llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
+ if (inputType.getRank() != SizeShapeRank)
return emitOpError(
"length of size attribute is not equal rank of input shape");
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 1b97f0b245d9ba..807f9cd683bb8c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -302,8 +302,8 @@ class TransposeConvStridedConverter
auto slice = CreateOpAndInferShape<tosa::SliceOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
- rewriter.getDenseI64ArrayAttr(sliceBegin),
- rewriter.getDenseI64ArrayAttr(sliceSize))
+ getTosaConstShape(rewriter, loc, sliceBegin),
+ getTosaConstShape(rewriter, loc, sliceSize))
.getResult();
llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
index 36eb4d4669b07a..a72d6b333f7ea4 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor-invalid.mlir
@@ -2,7 +2,9 @@
// CHECK-LABEL: @slice_resultType_unranked
func.func @slice_resultType_unranked(%arg0: tensor<?xf32>) -> (tensor<*xf32>) {
+ %0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.const_shape {value = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1>
// expected-error@+1 {{failed to legalize operation 'tosa.slice'}}
- %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 0>} : (tensor<?xf32>) -> (tensor<*xf32>)
- return %0 : tensor<*xf32>
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<?xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<*xf32>
+ return %2 : tensor<*xf32>
}
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index 2f11b31aad2307..dd97a872a07b46 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -437,7 +437,9 @@ func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3x
// CHECK-LABEL: func @slice
func.func @slice(%arg0: tensor<6xf32>) ->() {
// CHECK: [[SLICE:%.+]] = tensor.extract_slice %arg0[2] [1] [1]
- %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: 1>} : (tensor<6xf32>) -> (tensor<1xf32>)
+ %0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<6xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<1xf32>
return
}
@@ -450,8 +452,10 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
// CHECK: tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
- %0 = "tosa.slice"(%arg0) {start = array<i64: 2>, size = array<i64: -1>} : (tensor<?xf32>) -> (tensor<?xf32>)
- return %0 : tensor<?xf32>
+ %0 = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %1 = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<?xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e394188e9a9311..4a8694c4713e4e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -572,18 +572,22 @@ func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3x!quant.uniform<
// CHECK-LABEL: @slice_fold
func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
+ %0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: return %arg0
- %0 = tosa.slice %arg0 { size = array<i64: 3, 4>, start = array<i64: 0, 0>}: (tensor<3x4xf32>) -> tensor<3x4xf32>
- return %0 : tensor<3x4xf32>
+ %3 = tosa.slice %arg0, %0, %1 : (tensor<3x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x4xf32>
+ return %3 : tensor<3x4xf32>
}
// -----
// CHECK-LABEL: @slice_nofold
func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+ %0 = tosa.const_shape {value = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: tosa.slice
- %0 = tosa.slice %arg0 { size = array<i64: 3, 4>, start = array<i64: 0, 0>}: (tensor<?x4xf32>) -> tensor<?x4xf32>
- return %0 : tensor<?x4xf32>
+ %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+ return %3 : tensor<?x4xf32>
}
// -----
@@ -663,9 +667,12 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
%0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
- %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
- %2 = tosa.slice %0 {size = array<i64: 1, 12, 12, 1>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
- return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+ %1 = tosa.const_shape {value = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %2 = tosa.const_shape {value = dense<[0, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %3 = tosa.const_shape {value = dense<[1, 12, 12, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %4 = tosa.slice %0, %1, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32>
+ %5 = tosa.slice %0, %2, %3 : (tensor<1x12x12x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x12x12x1xf32>
+ return %4, %5 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
}
// -----
@@ -675,38 +682,56 @@ func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %
// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
- %1 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
- %2 = tosa.slice %0 {size = array<i64: 1, 12, 12>, start = array<i64: 0, 12, 0>} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
- return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+ %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %2 = tosa.const_shape {value = dense<[0, 12, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %3 = tosa.const_shape {value = dense<[1, 12, 12]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %4 = tosa.slice %0, %1, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32>
+ %5 = tosa.slice %0, %2, %3 : (tensor<1x24x12xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x12xf32>
+ return %4, %5 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
}
// -----
// CHECK-LABEL: @canonicalize_cross_concat_inputs
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
-// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_2]] {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
-// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 12, 20]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 12, 15]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 4]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>}
+// CHECK: %[[VAL_6:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]], %[[VAL_5]], %[[VAL_3]]
+// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_6]], %[[VAL_4]], %[[VAL_2]]
+// CHECK: return %[[VAL_7]], %[[VAL_8]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
%0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
- %1 = tosa.slice %0 {size = array<i64: 1, 12, 15>, start = array<i64: 0, 0, 0>} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
- %2 = tosa.slice %0 {size = array<i64: 1, 12, 20>, start = array<i64: 0, 0, 4>} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
- return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+ %1 = tosa.const_shape {value = dense<[0, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %2 = tosa.const_shape {value = dense<[0, 0, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %3 = tosa.const_shape {value = dense<[1, 12, 15]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %4 = tosa.const_shape {value = dense<[1, 12, 20]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %5 = tosa.slice %0, %1, %3 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x15xf32>
+ %6 = tosa.slice %0, %2, %4 : (tensor<1x12x24xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x12x20xf32>
+ return %5, %6 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
}
// -----
// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.slice %[[VAL_0]] {size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 0>} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
-// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 3, 0]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 3, 12]> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>}
+// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 6, 12]> : tensor<3xindex>}
+// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_0]], %[[VAL_4]], %[[VAL_5]]
+// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]
+// CHECK: return %[[VAL_6]], %[[VAL_7]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
%0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<1x12x12xf3...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/124209
More information about the Mlir-commits mailing list