[Mlir-commits] [mlir] [mlir] Add ReifyRankedShapedTypeOpInterface to tosa::TransposeOp (PR #88890)
Maya Amrami
llvmlistbot at llvm.org
Tue Apr 16 06:10:28 PDT 2024
https://github.com/amrami created https://github.com/llvm/llvm-project/pull/88890
None
>From 657873c655e05862cb412e83da13b2b9279bfd99 Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Sun, 14 Apr 2024 11:41:29 +0300
Subject: [PATCH] [mlir] Add ReifyRankedShapedTypeOpInterface to
tosa::TransposeOp
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 17 +++++------
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 26 +++++++++++++++++
mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 28 +++++++++++++++++++
3 files changed, 63 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 306e4a43952088..231a2aca079c6e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1501,7 +1501,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
let hasFolder = 1;
let hasVerifier = 1;
-
+
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
/// Method used by InferTypeOpInterface.
@@ -1651,7 +1651,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
let hasFolder = 1;
let hasVerifier = 1;
-
+
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
@@ -1707,7 +1707,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
//===----------------------------------------------------------------------===//
// Operator: transpose
//===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose"> {
+def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "Transpose operator";
let description = [{
@@ -1834,9 +1835,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
| Mode | Input | Output |
|--------------------------|---------|---------|
- | signed 8 to bool | int8 | Boolean |
- | signed 16 to bool | int16 | Boolean |
- | signed 32 to bool | int32 | Boolean |
+ | signed 8 to bool | int8 | Boolean |
+ | signed 16 to bool | int16 | Boolean |
+ | signed 32 to bool | int32 | Boolean |
| bool to 8 | Boolean | int8 |
| bool to 16 | Boolean | int16 |
| bool to 32 | Boolean | int32 |
@@ -1850,8 +1851,8 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
| float to signed 16 | float | int16 |
| signed 8 to float | int8 | float |
| signed 16 to float | int16 | float |
- | float 32 to float 64 | float32 | float64 |
- | float 64 to float 32 | float64 | float32 |
+ | float 32 to float 64 | float32 | float64 |
+ | float 64 to float 32 | float64 | float32 |
}];
let arguments = (ins
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e06ac9a27ae4cc..d93e262ac596b6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1119,6 +1119,32 @@ LogicalResult tosa::TransposeOp::verify() {
return success();
}
+/// Reifies the result shape of a transpose whose permutation is a
+/// compile-time constant.
+///
+/// Result dimension `i` has the extent of input dimension `perms[i]`. Static
+/// extents are returned as index attributes; dynamic extents are materialized
+/// as `tensor.dim` ops on the input, created through `builder` at this op's
+/// location. On success one shape is appended to `reifiedReturnShapes`.
+///
+/// Returns failure, without creating any IR, when the permutation is not
+/// constant, the input is unranked, or the permutation is inconsistent with
+/// the input rank.
+LogicalResult TransposeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  SmallVector<int64_t> transposePerms;
+  if (getConstantPerms(transposePerms).failed())
+    return failure();
+
+  Value input = getInput1();
+  // Per-dimension queries are only meaningful on a ranked input; bail out
+  // instead of asserting when the input is unranked.
+  auto inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
+  if (!inputType)
+    return failure();
+
+  const int64_t rank = inputType.getRank();
+  // Validate everything up front so that no tensor.dim op is left behind on a
+  // failing path and no index used below can be out of bounds.
+  if (static_cast<int64_t>(transposePerms.size()) != rank)
+    return failure();
+  for (int64_t dimInInput : transposePerms)
+    if (dimInInput < 0 || dimInInput >= rank)
+      return failure();
+
+  // Walk the result dimensions in order (not the permutation values) so every
+  // entry is filled exactly once and ops are emitted in result-dim order.
+  SmallVector<OpFoldResult> returnedDims(rank);
+  for (int64_t dim = 0; dim < rank; ++dim) {
+    const int64_t dimInInput = transposePerms[dim];
+    if (inputType.isDynamicDim(dimInInput))
+      returnedDims[dim] =
+          builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
+              .getResult();
+    else
+      returnedDims[dim] =
+          builder.getIndexAttr(inputType.getDimSize(dimInInput));
+  }
+
+  reifiedReturnShapes.emplace_back(std::move(returnedDims));
+  return success();
+}
+
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
GatherOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 18e9a9d02e1081..40f88de01b8bd7 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -25,3 +25,31 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
%0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
return %0 : index
}
+
+// -----
+
+// A tensor.dim of a dynamic transpose result dimension is expected to be
+// rewritten into a tensor.dim of the transpose input: with perms [0, 3, 1, 2],
+// result dimension 3 is taken from input dimension 2.
+// CHECK-LABEL: func.func @dynamic_dim_of_transpose_op(
+// CHECK-SAME: %[[arg:.*]]: tensor<1x2x?x8xi8>) -> index {
+// CHECK-NEXT: %[[c2:.*]] = arith.constant 2
+// CHECK-NEXT: tensor.dim %[[arg]], %[[c2]]
+// CHECK-NEXT: return
+func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
+ %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<1x2x?x8xi8>, tensor<4xi32>) -> tensor<1x8x2x?xi8>
+ %c3 = arith.constant 3 : index
+ %dim = tensor.dim %1, %c3 : tensor<1x8x2x?xi8>
+ return %dim : index
+}
+
+// -----
+
+// A tensor.dim of a static transpose result dimension is expected to become a
+// constant: with perms [0, 3, 1, 2], result dimension 2 is input dimension 1,
+// whose size is 100.
+// CHECK-LABEL: func.func @static_dim_of_transpose_op(
+// CHECK: arith.constant 100 : index
+// CHECK: return
+func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
+ %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<1x100x?x8xi8>, tensor<4xi32>) -> tensor<1x8x100x?xi8>
+ %c2 = arith.constant 2 : index
+ %dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
+ return %dim : index
+}
More information about the Mlir-commits
mailing list