[Mlir-commits] [mlir] f100bed - [mlir][linalg] Insert a cast for identity linalg.generics when the types don't match

Benjamin Kramer llvmlistbot at llvm.org
Tue Jan 18 14:44:38 PST 2022


Author: Benjamin Kramer
Date: 2022-01-18T23:44:14+01:00
New Revision: f100bedb036276f6a1c73b6467fe7585c1b20292

URL: https://github.com/llvm/llvm-project/commit/f100bedb036276f6a1c73b6467fe7585c1b20292
DIFF: https://github.com/llvm/llvm-project/commit/f100bedb036276f6a1c73b6467fe7585c1b20292.diff

LOG: [mlir][linalg] Insert a cast for identity linalg.generics when the types don't match

This can happen when the result has different dynamic dimensions than
the input.

Differential Revision: https://reviews.llvm.org/D117498

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a74ee5a9c459..00a120fb4518 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -857,12 +857,19 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     // Get the argument number of the returned values. That is the operand
     // number to use for replacing uses of this operation.
     SmallVector<Value> returnedArgs;
-    for (Value yieldVal : yieldOp.values()) {
-      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+    for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
+      auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
       if (!yieldArg || yieldArg.getOwner() != &body)
         return failure();
       unsigned argumentNumber = yieldArg.getArgNumber();
-      returnedArgs.push_back(genericOp->getOperand(argumentNumber));
+      Value returnedArg = genericOp->getOperand(argumentNumber);
+      Type resultType = genericOp->getResult(yieldVal.index()).getType();
+      // The input can have a different type than the result, e.g. a dynamic
+      // input dimension can be turned into a static output dimension.
+      if (returnedArg.getType() != resultType)
+        returnedArg = rewriter.create<tensor::CastOp>(genericOp.getLoc(),
+                                                      resultType, returnedArg);
+      returnedArgs.push_back(returnedArg);
     }
     if (returnedArgs.size() != genericOp->getNumResults())
       return failure();

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index a6913d6f06e2..0f1853dc324a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -179,6 +179,27 @@ func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
 
 // -----
 
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
+  -> tensor<1x2x3xf32> {
+  %out = linalg.init_tensor [1, 2, 3] : tensor<1x2x3xf32>
+  %g = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0 : tensor<?x?x?xf32>)
+    outs(%out : tensor<1x2x3xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  } -> (tensor<1x2x3xf32>)
+  return %g : tensor<1x2x3xf32>
+}
+// CHECK-LABEL: func @remove_no_op_mismatched_types
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//       CHECK:     %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
+//       CHECK:     return %[[CAST]]
+
+// -----
+
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list