[Mlir-commits] [mlir] b2505ca - [MLIR] Allow ShapedTypeComponents with attribute in inferReturnTensorTypes
Jacques Pienaar
llvmlistbot at llvm.org
Wed Dec 7 17:21:02 PST 2022
Author: smit-hinsu
Date: 2022-12-07T17:20:56-08:00
New Revision: b2505ca2ece3174193f19599a6677b5d0257654a
URL: https://github.com/llvm/llvm-project/commit/b2505ca2ece3174193f19599a6677b5d0257654a
DIFF: https://github.com/llvm/llvm-project/commit/b2505ca2ece3174193f19599a6677b5d0257654a.diff
LOG: [MLIR] Allow ShapedTypeComponents with attribute in inferReturnTensorTypes
Originally, inferReturnTensorTypes didn't support shaped type components
containing an attribute, simply because there was no motivating use case.
This change removes that limitation and uses the attribute as the encoding
of the inferred RankedTensorType (unranked results still may not carry an
attribute).
Updated the existing test so that the result's encoding attribute is taken
from the first operand's encoding, if it has one.
Signed-off-by: Smit Hinsu <smittvhinsu at gmail.com>
Differential Revision: https://reviews.llvm.org/D139271
Added:
Modified:
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/mlir-tblgen/return-types.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index d703f64910bc3..34200eb0daebf 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -187,15 +187,17 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
retComponents)))
return failure();
for (const auto &shapeAndType : retComponents) {
- assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
- assert(shapeAndType.getElementType() &&
- "element type required to construct tensor");
- if (shapeAndType.hasRank())
- inferredReturnTypes.push_back(RankedTensorType::get(
- shapeAndType.getDims(), shapeAndType.getElementType()));
- else
+ Type element_ty = shapeAndType.getElementType();
+ assert(element_ty && "element type required to construct tensor");
+
+ Attribute attr = shapeAndType.getAttribute();
+ if (shapeAndType.hasRank()) {
inferredReturnTypes.push_back(
- UnrankedTensorType::get(shapeAndType.getElementType()));
+ RankedTensorType::get(shapeAndType.getDims(), element_ty, attr));
+ } else {
+ assert(attr == nullptr && "attribute not supported");
+ inferredReturnTypes.push_back(UnrankedTensorType::get(element_ty));
+ }
}
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 685be113aec55..5fcfbf26aed26 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1180,7 +1180,11 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
int64_t dim =
sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
auto type = IntegerType::get(context, 17);
- inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
+
+ Attribute encoding;
+ if (auto ranked_ty = sval.dyn_cast<RankedTensorType>())
+ encoding = ranked_ty.getEncoding();
+ inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
return success();
}
diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir
index cf859fe9e8f07..39fb44f27695d 100644
--- a/mlir/test/mlir-tblgen/return-types.mlir
+++ b/mlir/test/mlir-tblgen/return-types.mlir
@@ -3,19 +3,19 @@
// CHECK-LABEL: testCreateFunctions
// This function tests invoking the create method with different inference
// methods. The attributes of the ops inside are used to test creation.
-func.func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
+func.func @testCreateFunctions(%arg0 : tensor<10xf32, !test.smpla>, %arg1 : tensor<20xi32>) {
// CHECK: "test.no_attributes"
- %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+ %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xf32, !test.smpla>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi17>
+// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xi17, !test.smpla>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10xi17>
+// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<20xi32>) -> tensor<10xi17, !test.smpla>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<10xf32, !test.smpla>) -> tensor<20xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi17>
// CHECK: "test.op_with_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+// CHECK-SAME: (tensor<10xf32, !test.smpla>, tensor<10xf32, !test.smpla>) -> tensor<10xf32, !test.smpla>
// CHECK: "test.op_with_infer_type_if"
// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi32>
return
More information about the Mlir-commits
mailing list