[Mlir-commits] [mlir] 35dab90 - [linalg] When removing noop linalg.generics, check that inserting a cast is valid
Benjamin Kramer
llvmlistbot at llvm.org
Tue Mar 29 14:06:19 PDT 2022
Author: Benjamin Kramer
Date: 2022-03-29T23:05:54+02:00
New Revision: 35dab904c09b58f061c303b40b394c909ba84db6
URL: https://github.com/llvm/llvm-project/commit/35dab904c09b58f061c303b40b394c909ba84db6
DIFF: https://github.com/llvm/llvm-project/commit/35dab904c09b58f061c303b40b394c909ba84db6.diff
LOG: [linalg] When removing noop linalg.generics, check that inserting a cast is valid
linalg.generic can also take scalars instead of tensors, which
tensor.cast doesn't support. We don't have an easy way to cast between
scalars and tensors, so just keep the linalg.generic in those cases.
Differential Revision: https://reviews.llvm.org/D122575
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d62192d38edba..c72b52b4e1f08 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -836,9 +836,13 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
sparse_tensor::getSparseTensorEncoding(resultType))
returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
genericOp.getLoc(), resultType, returnedArg);
- else
+ else {
+ if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
+ resultType))
+ return failure();
returnedArg = rewriter.create<tensor::CastOp>(
genericOp.getLoc(), resultType, returnedArg);
+ }
}
returnedArgs.push_back(returnedArg);
}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index eee6ebc907563..56ce26778c971 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -175,6 +175,24 @@ func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
// -----
+#map = affine_map<() -> ()>
+func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
+ %out = linalg.init_tensor [] : tensor<f32>
+ %g = linalg.generic {
+ indexing_maps = [#map, #map],
+ iterator_types = []
+ } ins(%arg0 : f32)
+ outs(%out : tensor<f32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ } -> (tensor<f32>)
+ return %g : tensor<f32>
+}
+// CHECK-LABEL: func @cant_fold_to_tensor_cast
+// CHECK: linalg.generic
+
+// -----
+
#map = affine_map<(d0, d1) -> (d0, d1)>
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
More information about the Mlir-commits mailing list