[Mlir-commits] [mlir] [mlir][sparse] avoid incompatible linalg fuse-into-consumer (PR #86752)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 26 17:04:44 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-sparse
Author: Aart Bik (aartbik)
<details>
<summary>Changes</summary>
This fixes an "infinite" loop bug in which the incoming IR was repeatedly rewritten, adding an identical cast operation on every iteration. The test for compatible types must also take the tensor encoding into account: if the source and target encodings differ, naively fusing the producer into the consumer is invalid.
---
Full diff: https://github.com/llvm/llvm-project/pull/86752.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+4)
- (added) mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir (+47)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index dc8843aa4e1e13..38a9ad60bb7948 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -276,6 +276,10 @@ bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
if (sourceType.getRank() != targetType.getRank())
return false;
+ // Requires same encoding.
+ if (sourceType.getEncoding() != targetType.getEncoding())
+ return false;
+
// If cast is towards more static sizes along any dimension, don't fold.
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
if (!ShapedType::isDynamic(std::get<0>(t)) &&
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
new file mode 100644
index 00000000000000..bbc7f397e793fe
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --canonicalize --pre-sparsification-rewrite | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#sparse = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) ->
+ (d0 : compressed(nonunique),
+ d1 : singleton(nonunique, soa),
+ d2 : singleton(soa)),
+ posWidth = 64,
+ crdWidth = 64
+}>
+
+
+module {
+ //
+ // This IR should not end up in an infinite loop trying to fold
+ // the linalg producer into the tensor cast consumer (even though
+ // static sizes can fold, the different encodings cannot). The
+ // cast was sloppy to begin with (but it has been observed by
+ // external sources) and can be easily repaired by the sparsifier.
+ //
+ // CHECK-LABEL: func @avoid_fold
+ // CHECK: arith.constant
+ // CHECK: tensor.empty()
+ // CHECK: linalg.generic
+ // CHECK: sparse_tensor.convert
+ // CHECK: return
+ //
+ func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
+ %1 = tensor.empty() : tensor<10x20x30xf64>
+ %2 = linalg.generic { indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ }
+ ins (%0 : tensor<10x20x30xf64, #sparse>)
+ outs(%1 : tensor<10x20x30xf64>) {
+ ^bb0(%in: f64, %out: f64):
+ %cst = arith.constant 0.000000e+00 : f64
+ %4 = arith.cmpf ugt, %in, %cst : f64
+ %5 = arith.select %4, %in, %cst : f64
+ linalg.yield %5 : f64
+ } -> tensor<10x20x30xf64>
+ %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
+ return %cast : tensor<10x20x30xf64, #sparse>
+ }
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/86752
More information about the Mlir-commits mailing list