[Mlir-commits] [mlir] 92cf9f1 - [mlir][linalg] Cast back to the original type after making linalg.generic outputs more static
Benjamin Kramer
llvmlistbot at llvm.org
Thu Feb 24 04:46:11 PST 2022
Author: Benjamin Kramer
Date: 2022-02-24T13:35:54+01:00
New Revision: 92cf9f14814a5e8308c431095fb2205202445676
URL: https://github.com/llvm/llvm-project/commit/92cf9f14814a5e8308c431095fb2205202445676
DIFF: https://github.com/llvm/llvm-project/commit/92cf9f14814a5e8308c431095fb2205202445676.diff
LOG: [mlir][linalg] Cast back to the original type after making linalg.generic outputs more static
This codepath was entirely untested.
Differential Revision: https://reviews.llvm.org/D120473
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 319a1c318ff8a..a33945a873126 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -991,7 +991,7 @@ struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
Type oldType = oldResult.getType();
replacements.push_back(
(newType != oldType)
- ? rewriter.create<tensor::CastOp>(loc, newType, newResult)
+ ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
: newResult);
}
rewriter.replaceOp(genericOp, replacements);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 8a3f201f7cc26..c3405887431ff 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -780,3 +780,27 @@ func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
// CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_dest
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<1x?x?xf32>,
+func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index, %arg3: index, %arg4: index) -> tensor<?x?x?xf32> {
+ %0 = linalg.init_tensor [%arg2, %arg3, %arg4] : tensor<?x?x?xf32>
+ %1 = tensor.cast %arg1 : tensor<1x?x?xf32> to tensor<?x?x?xf32>
+ %2 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<1x?x?xf32>)
+ outs(%0 : tensor<?x?x?xf32>) {
+ ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+ %3 = arith.subf %arg5, %arg6 : f32
+ linalg.yield %3 : f32
+ } -> tensor<?x?x?xf32>
+ return %2 : tensor<?x?x?xf32>
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%{{.*}}, %[[ARG1]] : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<1x?x?xf32>)
+// CHECK: tensor.cast %[[GENERIC_OP]] : tensor<1x?x?xf32> to tensor<?x?x?xf32>
+}
More information about the Mlir-commits mailing list