[Mlir-commits] [mlir] ca2ac2b - [MLIR][Linalg] Handle Attribute in InitTensorOp

Lorenzo Chelini llvmlistbot at llvm.org
Mon Jan 17 02:43:30 PST 2022


Author: Lorenzo Chelini
Date: 2022-01-17T11:43:19+01:00
New Revision: ca2ac2bb14624257d9bc0a53919dbbc5f447c7fc

URL: https://github.com/llvm/llvm-project/commit/ca2ac2bb14624257d9bc0a53919dbbc5f447c7fc
DIFF: https://github.com/llvm/llvm-project/commit/ca2ac2bb14624257d9bc0a53919dbbc5f447c7fc.diff

LOG: [MLIR][Linalg] Handle Attribute in InitTensorOp

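In some cases, the result type of an `InitTensorOp` may carry an encoding
attribute. However, the encoding attribute was not passed to `inferResultType`,
so the inferred type lacked it, no longer matched the declared result type, and
the verifier rejected the op. Therefore, propagate the encoding attribute to
`inferResultType`.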

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D117192

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a5c7561981927..631f7ca16fbde 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -67,7 +67,8 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
 
     // Infer the shape of the result tensor given the static shapes
     // and element type of the result tensor.
-    static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType);
+    static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType,
+                                Attribute encoding = {});
 
     // Return true if the size of the tensor is dynamic at `idx`
     bool isDynamicSize(unsigned idx) {

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 190de342f5c7a..2de5e23829b4d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -906,8 +906,8 @@ static LogicalResult verify(InitTensorOp op) {
     return op->emitError("expected ")
            << resultType.getRank() << " sizes values";
 
-  Type expectedType =
-      InitTensorOp::inferResultType(staticSizes, resultType.getElementType());
+  Type expectedType = InitTensorOp::inferResultType(
+      staticSizes, resultType.getElementType(), resultType.getEncoding());
   if (resultType != expectedType) {
     return op.emitError("specified type ")
            << resultType << " does not match the inferred type "
@@ -917,8 +917,8 @@ static LogicalResult verify(InitTensorOp op) {
 }
 
 Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
-                                   Type elementType) {
-  return RankedTensorType::get(staticSizes, elementType);
+                                   Type elementType, Attribute encoding) {
+  return RankedTensorType::get(staticSizes, elementType, encoding);
 }
 
 namespace {

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index f9559cbd5b96b..8e7ddd66832e1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -435,15 +435,18 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x
 
 // -----
 
+#attr = {"foo"}
 func @init_tensor(%arg0 : index, %arg1 : index)
 {
   %0 = linalg.init_tensor [3, 42] : tensor<3x42xf32>
   %1 = linalg.init_tensor [4, %arg0, %arg1, 5] : tensor<4x?x?x5xf32>
+  %2 = linalg.init_tensor [2, 2] : tensor<2x2xf32, #attr>
   return
 }
 // CHECK-LABEL: func @init_tensor
 //       CHECK:   linalg.init_tensor [3, 42] : tensor<3x42xf32>
 //       CHECK:   linalg.init_tensor [4, %{{.*}}, %{{.*}}, 5] : tensor<4x?x?x5xf32>
+//       CHECK:   linalg.init_tensor [2, 2] : tensor<2x2xf32, {foo}>
 
 // -----
 


        


More information about the Mlir-commits mailing list