[Mlir-commits] [mlir] ca2ac2b - [MLIR][Linalg] Handle Attribute in InitTensorOp
Lorenzo Chelini
llvmlistbot at llvm.org
Mon Jan 17 02:43:30 PST 2022
Author: Lorenzo Chelini
Date: 2022-01-17T11:43:19+01:00
New Revision: ca2ac2bb14624257d9bc0a53919dbbc5f447c7fc
URL: https://github.com/llvm/llvm-project/commit/ca2ac2bb14624257d9bc0a53919dbbc5f447c7fc
DIFF: https://github.com/llvm/llvm-project/commit/ca2ac2bb14624257d9bc0a53919dbbc5f447c7fc.diff
LOG: [MLIR][Linalg] Handle Attribute in InitTensorOp
In some cases, the result type of an InitTensorOp may carry an encoding
attribute. However, that encoding was not passed to `inferResultType`, so the
inferred type lacked it and the verifier rejected the op. Therefore, propagate
the result type's encoding attribute to `inferResultType`.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D117192
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a5c7561981927..631f7ca16fbde 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -67,7 +67,8 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
// Infer the shape of the result tensor given the static shapes
// and element type of the result tensor.
- static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
+ static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType,
+ Attribute encoding = {});
// Return true if the size of the tensor is dynamic at `idx`
bool isDynamicSize(unsigned idx) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 190de342f5c7a..2de5e23829b4d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -906,8 +906,8 @@ static LogicalResult verify(InitTensorOp op) {
return op->emitError("expected ")
<< resultType.getRank() << " sizes values";
- Type expectedType =
- InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
+ Type expectedType = InitTensorOp::inferResultType(
+ staticSizes, resultType.getElementType(), resultType.getEncoding());
if (resultType != expectedType) {
return op.emitError("specified type ")
<< resultType << " does not match the inferred type "
@@ -917,8 +917,8 @@ static LogicalResult verify(InitTensorOp op) {
}
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
- Type elementType) {
- return RankedTensorType::get(staticSizes, elementType);
+ Type elementType, Attribute encoding) {
+ return RankedTensorType::get(staticSizes, elementType, encoding);
}
namespace {
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index f9559cbd5b96b..8e7ddd66832e1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -435,15 +435,18 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
// -----
+#attr = {"foo"}
func @init_tensor(%arg0 : index, %arg1 : index)
{
%0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
%1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
+ %2 = linalg.init_tensor [2, 2] : tensor<2x2xf32, #attr>
return
}
// CHECK-LABEL: func @init_tensor
// CHECK: linalg.init_tensor [3, 42] : tensor<3x42xf32>
// CHECK: linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>
+// CHECK: linalg.init_tensor [2, 2] : tensor<2x2xf32, {foo}>
// -----
More information about the Mlir-commits mailing list