[Mlir-commits] [mlir] [mlir] Add ReifyRankedShapedTypeOpInterface to tosa::TransposeOp (PR #88890)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 16 06:10:59 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Maya Amrami (amrami)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/88890.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+9-8) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+26) 
- (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+28) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 306e4a43952088..231a2aca079c6e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1501,7 +1501,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
@@ -1651,7 +1651,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -1707,7 +1707,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 //===----------------------------------------------------------------------===//
 // Operator: transpose
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose"> {
+def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
+                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "Transpose operator";
 
   let description = [{
@@ -1834,9 +1835,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 
     | Mode                     | Input   | Output  |
     |--------------------------|---------|---------|
-    | signed 8 to bool         | int8    | Boolean | 
-    | signed 16 to bool        | int16   | Boolean | 
-    | signed 32 to bool        | int32   | Boolean | 
+    | signed 8 to bool         | int8    | Boolean |
+    | signed 16 to bool        | int16   | Boolean |
+    | signed 32 to bool        | int32   | Boolean |
     | bool to 8                | Boolean | int8    |
     | bool to 16               | Boolean | int16   |
     | bool to 32               | Boolean | int32   |
@@ -1850,8 +1851,8 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
     | float to signed 16       | float   | int16   |
     | signed 8 to float        | int8    | float   |
     | signed 16 to float       | int16   | float   |
-    | float 32 to float 64     | float32 | float64 | 
-    | float 64 to float 32     | float64 | float32 | 
+    | float 32 to float 64     | float32 | float64 |
+    | float 64 to float 32     | float64 | float32 |
   }];
 
   let arguments = (ins
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e06ac9a27ae4cc..d93e262ac596b6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1119,6 +1119,32 @@ LogicalResult tosa::TransposeOp::verify() {
   return success();
 }
 
+/// Result dim `i` is input dim `perms[i]`. Fails unless the perms are
+/// constant and the input is ranked with one perm entry per dimension.
+LogicalResult TransposeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  SmallVector<int64_t> transposePerms;
+  if (getConstantPerms(transposePerms).failed())
+    return failure();
+
+  // Unranked inputs have no queryable dims (getRank() would assert).
+  Value input = getInput1();
+  auto inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
+  if (!inputType || transposePerms.size() != inputType.getShape().size())
+    return failure();
+
+  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
+  for (auto [outDim, inDim] : llvm::enumerate(transposePerms)) {
+    if (inputType.isDynamicDim(inDim))
+      returnedDims[outDim] =
+          builder.create<tensor::DimOp>(getLoc(), input, inDim).getResult();
+    else
+      returnedDims[outDim] = builder.getIndexAttr(inputType.getDimSize(inDim));
+  }
+  reifiedReturnShapes.emplace_back(std::move(returnedDims));
+  return success();
+}
+
 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     GatherOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 18e9a9d02e1081..40f88de01b8bd7 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -25,3 +25,31 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
   %0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @dynamic_dim_of_transpose_op(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<1x2x?x8xi8>) -> index {
+//  CHECK-NEXT:     %[[c2:.*]] = arith.constant 2
+//  CHECK-NEXT:     %[[dim:.*]] = tensor.dim %[[arg]], %[[c2]]
+//  CHECK-NEXT:     return %[[dim]]
+func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
+  %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  %1 = tosa.transpose %arg0, %0 : (tensor<1x2x?x8xi8>, tensor<4xi32>) -> tensor<1x8x2x?xi8>
+  %c3 = arith.constant 3 : index
+  %dim = tensor.dim %1, %c3 : tensor<1x8x2x?xi8>
+  return %dim : index
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @static_dim_of_transpose_op(
+//       CHECK:     %[[c100:.*]] = arith.constant 100 : index
+//       CHECK:     return %[[c100]] : index
+func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
+  %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  %1 = tosa.transpose %arg0, %0 : (tensor<1x100x?x8xi8>, tensor<4xi32>) -> tensor<1x8x100x?xi8>
+  %c2 = arith.constant 2 : index
+  %dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
+  return %dim : index
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/88890


More information about the Mlir-commits mailing list