[Mlir-commits] [mlir] f100bed - [mlir][linalg] Insert a cast for identity linalg.generics when the types don't match
Benjamin Kramer
llvmlistbot at llvm.org
Tue Jan 18 14:44:38 PST 2022
Author: Benjamin Kramer
Date: 2022-01-18T23:44:14+01:00
New Revision: f100bedb036276f6a1c73b6467fe7585c1b20292
URL: https://github.com/llvm/llvm-project/commit/f100bedb036276f6a1c73b6467fe7585c1b20292
DIFF: https://github.com/llvm/llvm-project/commit/f100bedb036276f6a1c73b6467fe7585c1b20292.diff
LOG: [mlir][linalg] Insert a cast for identity linalg.generics when the types don't match
This can happen when the result has different dynamic dimensions than
the input.
Differential Revision: https://reviews.llvm.org/D117498
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a74ee5a9c459..00a120fb4518 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -857,12 +857,19 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
// Get the argument number of the returned values. That is the operand
// number to use for replacing uses of this operation.
SmallVector<Value> returnedArgs;
- for (Value yieldVal : yieldOp.values()) {
- auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+ for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
+ auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
if (!yieldArg || yieldArg.getOwner() != &body)
return failure();
unsigned argumentNumber = yieldArg.getArgNumber();
- returnedArgs.push_back(genericOp->getOperand(argumentNumber));
+ Value returnedArg = genericOp->getOperand(argumentNumber);
+ Type resultType = genericOp->getResult(yieldVal.index()).getType();
+ // The input can have a different type than the result, e.g. a dynamic
+ // input dimension can be turned into a static output dimension.
+ if (returnedArg.getType() != resultType)
+ returnedArg = rewriter.create<tensor::CastOp>(genericOp.getLoc(),
+ resultType, returnedArg);
+ returnedArgs.push_back(returnedArg);
}
if (returnedArgs.size() != genericOp->getNumResults())
return failure();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index a6913d6f06e2..0f1853dc324a 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -179,6 +179,27 @@ func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
// -----
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
+ -> tensor<1x2x3xf32> {
+ %out = linalg.init_tensor [1, 2, 3] : tensor<1x2x3xf32>
+ %g = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } ins(%arg0 : tensor<?x?x?xf32>)
+ outs(%out : tensor<1x2x3xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ } -> (tensor<1x2x3xf32>)
+ return %g : tensor<1x2x3xf32>
+}
+// CHECK-LABEL: func @remove_no_op_mismatched_types
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
+// CHECK: return %[[CAST]]
+
+// -----
+
#map = affine_map<(d0, d1) -> (d0, d1)>
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
More information about the Mlir-commits mailing list