[Mlir-commits] [mlir] b2505ca - [MLIR] Allow ShapedTypeComponents with attribute in inferReturnTensorTypes

Jacques Pienaar llvmlistbot at llvm.org
Wed Dec 7 17:21:02 PST 2022


Author: smit-hinsu
Date: 2022-12-07T17:20:56-08:00
New Revision: b2505ca2ece3174193f19599a6677b5d0257654a

URL: https://github.com/llvm/llvm-project/commit/b2505ca2ece3174193f19599a6677b5d0257654a
DIFF: https://github.com/llvm/llvm-project/commit/b2505ca2ece3174193f19599a6677b5d0257654a.diff

LOG: [MLIR] Allow ShapedTypeComponents with attribute in inferReturnTensorTypes

Originally, inferReturnTensorTypes didn't support shaped type components
carrying an attribute, simply because there was no motivating use case.
This change removes that limitation and uses the attribute to set the
encoding on the inferred RankedTensorType.

Updated the existing test to set the result's encoding attribute from the
first operand, if available.
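
For illustration only (not part of the original commit message): a minimal
sketch of how an op's shape inference can forward an operand's encoding
through ShapedTypeComponents after this change. The helper name
buildComponentsWithEncoding is hypothetical; the ShapedTypeComponents and
RankedTensorType APIs it uses are the ones touched by this commit.

// Hedged sketch: propagate an operand's encoding into the inferred result.
// buildComponentsWithEncoding is a hypothetical helper, not an MLIR API.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static ShapedTypeComponents
buildComponentsWithEncoding(ShapedType operandTy, Type elementTy) {
  // Only ranked tensors carry an encoding; leave the attribute null otherwise.
  Attribute encoding;
  if (auto rankedTy = operandTy.dyn_cast<RankedTensorType>())
    encoding = rankedTy.getEncoding();

  if (operandTy.hasRank())
    // inferReturnTensorTypes now builds
    //   RankedTensorType::get(dims, elementTy, encoding)
    // from these components instead of asserting on the attribute.
    return ShapedTypeComponents(llvm::to_vector(operandTy.getShape()),
                                elementTy, encoding);

  // Unranked results still must not carry an attribute.
  return ShapedTypeComponents(elementTy);
}

With components built this way, mlir::detail::inferReturnTensorTypes carries
the operand's encoding onto the result type, as in the updated test below
where tensor<10xf32, !test.smpla> operands infer tensor<10xi17, !test.smpla>.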

Signed-off-by: Smit Hinsu <smittvhinsu at gmail.com>

Differential Revision: https://reviews.llvm.org/D139271

Added: 
    

Modified: 
    mlir/lib/Interfaces/InferTypeOpInterface.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/mlir-tblgen/return-types.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index d703f64910bc3..34200eb0daebf 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -187,15 +187,17 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
                              retComponents)))
     return failure();
   for (const auto &shapeAndType : retComponents) {
-    assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
-    assert(shapeAndType.getElementType() &&
-           "element type required to construct tensor");
-    if (shapeAndType.hasRank())
-      inferredReturnTypes.push_back(RankedTensorType::get(
-          shapeAndType.getDims(), shapeAndType.getElementType()));
-    else
+    Type element_ty = shapeAndType.getElementType();
+    assert(element_ty && "element type required to construct tensor");
+
+    Attribute attr = shapeAndType.getAttribute();
+    if (shapeAndType.hasRank()) {
       inferredReturnTypes.push_back(
-          UnrankedTensorType::get(shapeAndType.getElementType()));
+          RankedTensorType::get(shapeAndType.getDims(), element_ty, attr));
+    } else {
+      assert(attr == nullptr && "attribute not supported");
+      inferredReturnTypes.push_back(UnrankedTensorType::get(element_ty));
+    }
   }
   return success();
 }

diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 685be113aec55..5fcfbf26aed26 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1180,7 +1180,11 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
   int64_t dim =
       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
   auto type = IntegerType::get(context, 17);
-  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
+
+  Attribute encoding;
+  if (auto ranked_ty = sval.dyn_cast<RankedTensorType>())
+    encoding = ranked_ty.getEncoding();
+  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
   return success();
 }
 

diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir
index cf859fe9e8f07..39fb44f27695d 100644
--- a/mlir/test/mlir-tblgen/return-types.mlir
+++ b/mlir/test/mlir-tblgen/return-types.mlir
@@ -3,19 +3,19 @@
 // CHECK-LABEL: testCreateFunctions
 // This function tests invoking the create method with different inference
 // methods. The attributes of the ops inside are used to test creation.
-func.func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
+func.func @testCreateFunctions(%arg0 : tensor<10xf32, !test.smpla>, %arg1 : tensor<20xi32>) {
 // CHECK: "test.no_attributes"
-  %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xf32, !test.smpla>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi17>
+// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xi17, !test.smpla>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10xi17>
+// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<20xi32>) -> tensor<10xi17, !test.smpla>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<10xf32, !test.smpla>) -> tensor<20xi17>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
 // CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi17>
 // CHECK: "test.op_with_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xf32, !test.smpla>
 // CHECK: "test.op_with_infer_type_if"
 // CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi32>
   return
