[Mlir-commits] [mlir] 35dab90 - [linalg] When removing noop linalg.generics, check that inserting a cast is valid

Benjamin Kramer llvmlistbot at llvm.org
Tue Mar 29 14:06:19 PDT 2022


Author: Benjamin Kramer
Date: 2022-03-29T23:05:54+02:00
New Revision: 35dab904c09b58f061c303b40b394c909ba84db6

URL: https://github.com/llvm/llvm-project/commit/35dab904c09b58f061c303b40b394c909ba84db6
DIFF: https://github.com/llvm/llvm-project/commit/35dab904c09b58f061c303b40b394c909ba84db6.diff

LOG: [linalg] When removing noop linalg.generics, check that inserting a cast is valid

linalg.generic can also take scalar operands instead of tensors, which
tensor.cast doesn't support. There is no easy way to cast between
scalars and tensors, so just keep the linalg.generic in those cases.
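
As a rough illustration of the guard described in this log message, here is a
minimal, self-contained sketch that does not use MLIR. ToyType and
canEraseIdentityGeneric are hypothetical names, and areCastCompatible here is a
simplified stand-in for tensor::CastOp::areCastCompatible: it only models
scalars and ranked tensors, and it assumes a cast is needed only when the
returned value's type differs from the result type.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR type: a scalar or a ranked tensor.
struct ToyType {
  std::string element;      // element type name, e.g. "f32"
  bool isTensor;            // false means a plain scalar
  std::vector<long> shape;  // -1 marks a dynamic dimension
};

// Exact type equality for the toy model.
bool sameType(const ToyType &a, const ToyType &b) {
  return a.element == b.element && a.isTensor == b.isTensor &&
         a.shape == b.shape;
}

// Simplified model of a tensor-to-tensor cast check (assumption: the real
// check also handles unranked tensors, which this sketch ignores).
bool areCastCompatible(const ToyType &from, const ToyType &to) {
  // A scalar on either side can never be handled by a tensor cast.
  if (!from.isTensor || !to.isTensor)
    return false;
  if (from.element != to.element || from.shape.size() != to.shape.size())
    return false;
  for (std::size_t i = 0; i < from.shape.size(); ++i) {
    long a = from.shape[i], b = to.shape[i];
    // Static dimensions must agree; a dynamic dimension matches anything.
    if (a != -1 && b != -1 && a != b)
      return false;
  }
  return true;
}

// Models the decision in the pattern: the identity op may be erased only if
// the returned value already has the result type or can be cast to it.
bool canEraseIdentityGeneric(const ToyType &returned, const ToyType &result) {
  if (sameType(returned, result))
    return true;  // no cast needed
  return areCastCompatible(returned, result);
}
```

With these definitions, an f32 scalar yielded into a tensor<f32> result is
rejected, so the op is kept, while a tensor<?x?xf32> value feeding a statically
shaped result of the same rank is still accepted.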

Differential Revision: https://reviews.llvm.org/D122575

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d62192d38edba..c72b52b4e1f08 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -836,9 +836,13 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
             sparse_tensor::getSparseTensorEncoding(resultType))
           returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
               genericOp.getLoc(), resultType, returnedArg);
-        else
+        else {
+          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
+                                                 resultType))
+            return failure();
           returnedArg = rewriter.create<tensor::CastOp>(
               genericOp.getLoc(), resultType, returnedArg);
+        }
       }
       returnedArgs.push_back(returnedArg);
     }

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index eee6ebc907563..56ce26778c971 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -175,6 +175,24 @@ func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
 
 // -----
 
+#map = affine_map<() -> ()>
+func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
+  %out = linalg.init_tensor [] : tensor<f32>
+  %g = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = []
+  } ins(%arg0 : f32)
+    outs(%out : tensor<f32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  } -> (tensor<f32>)
+  return %g : tensor<f32>
+}
+// CHECK-LABEL: func @cant_fold_to_tensor_cast
+//       CHECK:     linalg.generic
+
+// -----
+
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index