[Mlir-commits] [mlir] 5cd074f - [mlir] Add ReifyRankedShapedTypeOpInterface to tosa::TransposeOp (#88890)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 30 04:49:02 PDT 2024


Author: Maya Amrami
Date: 2024-04-30T14:48:58+03:00
New Revision: 5cd074fa57c2a22312f479a9529c0eac10013043

URL: https://github.com/llvm/llvm-project/commit/5cd074fa57c2a22312f479a9529c0eac10013043
DIFF: https://github.com/llvm/llvm-project/commit/5cd074fa57c2a22312f479a9529c0eac10013043.diff

LOG: [mlir] Add ReifyRankedShapedTypeOpInterface to tosa::TransposeOp (#88890)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/MemRef/resolve-dim-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index dde17e2dc8924d..138deb1e773b30 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1502,7 +1502,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
@@ -1652,7 +1652,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -1708,7 +1708,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 //===----------------------------------------------------------------------===//
 // Operator: transpose
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose"> {
+def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
+                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "Transpose operator";
 
   let description = [{
@@ -1835,9 +1836,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 
     | Mode                     | Input   | Output  |
     |--------------------------|---------|---------|
-    | signed 8 to bool         | int8    | Boolean | 
-    | signed 16 to bool        | int16   | Boolean | 
-    | signed 32 to bool        | int32   | Boolean | 
+    | signed 8 to bool         | int8    | Boolean |
+    | signed 16 to bool        | int16   | Boolean |
+    | signed 32 to bool        | int32   | Boolean |
     | bool to 8                | Boolean | int8    |
     | bool to 16               | Boolean | int16   |
     | bool to 32               | Boolean | int32   |
@@ -1851,8 +1852,8 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
     | float to signed 16       | float   | int16   |
     | signed 8 to float        | int8    | float   |
     | signed 16 to float       | int16   | float   |
-    | float 32 to float 64     | float32 | float64 | 
-    | float 64 to float 32     | float64 | float32 | 
+    | float 32 to float 64     | float32 | float64 |
+    | float 64 to float 32     | float64 | float32 |
   }];
 
   let arguments = (ins

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 10e6016a1ed431..99b0db14c1427c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1119,6 +1119,32 @@ LogicalResult tosa::TransposeOp::verify() {
   return success();
 }
 
+/// Reifies the result shape: result dimension `i` has the size of input
+/// dimension `perms[i]`. Fails if getConstantPerms() cannot provide `perms`.
+LogicalResult TransposeOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  SmallVector<int64_t> transposePerms;
+  if (getConstantPerms(transposePerms).failed())
+    return failure();
+
+  Value input = getInput1();
+  auto inputType = llvm::cast<TensorType>(input.getType());
+  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
+  // Visit result dimensions in order instead of indexing by permutation value.
+  for (size_t dim = 0, e = transposePerms.size(); dim != e; ++dim) {
+    int64_t dimInInput = transposePerms[dim];
+    if (inputType.isDynamicDim(dimInInput))
+      returnedDims[dim] =
+          builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
+              .getResult();
+    else
+      returnedDims[dim] =
+          builder.getIndexAttr(inputType.getDimSize(dimInInput));
+  }
+  reifiedReturnShapes.emplace_back(std::move(returnedDims));
+  return success();
+}
+
 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     GatherOp::Adaptor adaptor,

diff  --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 18e9a9d02e1081..40f88de01b8bd7 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -25,3 +25,31 @@ func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
   %0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @dynamic_dim_of_transpose_op(
+//  CHECK-SAME:                                   %[[arg:.*]]: tensor<1x2x?x8xi8>) -> index {
+//  CHECK-NEXT:           %[[c2:.*]] = arith.constant 2
+//  CHECK-NEXT:           tensor.dim %[[arg]], %[[c2]]
+//  CHECK-NEXT:           return
+func.func @dynamic_dim_of_transpose_op(%arg0: tensor<1x2x?x8xi8>) -> index {
+  %perms = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  %transposed = tosa.transpose %arg0, %perms : (tensor<1x2x?x8xi8>, tensor<4xi32>) -> tensor<1x8x2x?xi8>
+  %result_dim_idx = arith.constant 3 : index
+  %dim_size = tensor.dim %transposed, %result_dim_idx : tensor<1x8x2x?xi8>
+  return %dim_size : index
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @static_dim_of_transpose_op(
+//  CHECK:           arith.constant 100 : index
+//  CHECK:           return
+func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
+  %perms = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  %transposed = tosa.transpose %arg0, %perms : (tensor<1x100x?x8xi8>, tensor<4xi32>) -> tensor<1x8x100x?xi8>
+  %result_dim_idx = arith.constant 2 : index
+  %dim_size = tensor.dim %transposed, %result_dim_idx : tensor<1x8x100x?xi8>
+  return %dim_size : index
+}


        


More information about the Mlir-commits mailing list