[Mlir-commits] [mlir] d7bc3b7 - [mlir][Linalg] Add missing check to canonicalization of GenericOp that are identity ops.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 15 13:55:51 PST 2021


Author: MaheshRavishankar
Date: 2021-01-15T13:55:35-08:00
New Revision: d7bc3b7ce23b664d6620cdc32370a8614523ca2f

URL: https://github.com/llvm/llvm-project/commit/d7bc3b7ce23b664d6620cdc32370a8614523ca2f
DIFF: https://github.com/llvm/llvm-project/commit/d7bc3b7ce23b664d6620cdc32370a8614523ca2f.diff

LOG: [mlir][Linalg] Add missing check to canonicalization of GenericOp that are identity ops.

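The operation is an identity if every value yielded by the operation is
an argument of the operation's own basic block. Add this missing check,
and also check that the number of yielded values matches the number of
results of the operation before replacing it.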

Differential Revision: https://reviews.llvm.org/D94819

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 30a6b9c0c371..fa98ed0cfbc9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2276,13 +2276,15 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
     SmallVector<Value, 4> returnedArgs;
     for (Value yieldVal : yieldOp.values()) {
       auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
-      if (!yieldArg)
+      if (!yieldArg || yieldArg.getOwner() != &body)
         return failure();
       unsigned argumentNumber = yieldArg.getArgNumber();
       if (argumentNumber < numIndexArgs)
         return failure();
       returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
     }
+    if (returnedArgs.size() != genericOp.getOperation()->getNumResults())
+      return failure();
     rewriter.replaceOp(genericOp, returnedArgs);
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index ca7f82c1b254..cc00b98d376c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -615,3 +615,56 @@ func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 //       CHECK:     return %[[ARG1]], %[[ARG0]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cst = constant 1.000000e+00 : f32
+  %0 = dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  br ^bb1(%cst : f32)
+
+^bb1(%arg1 : f32):
+  %3 = linalg.generic
+    {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3 : f32):
+      linalg.yield %arg1 : f32
+    } -> tensor<?x?xf32>
+  return %3 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @keep_not_noop
+//       CHECK:   %[[RESULT:.+]] = linalg.generic
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cst = constant 1.000000e+00 : f32
+  %0 = dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  br ^bb1(%cst : f32)
+
+^bb1(%arg2 : f32):
+  %3:2 = linalg.generic
+    {indexing_maps = [#map, #map, #map, #map],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32):
+      linalg.yield %arg2, %arg4 : f32, f32
+    } -> tensor<?x?xf32>, tensor<?x?xf32>
+  return %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @keep_not_noop
+//       CHECK:   %[[RESULT:.+]]:2 = linalg.generic
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1


        


More information about the Mlir-commits mailing list