[Mlir-commits] [mlir] [mlir][sparse] avoid incompatible linalg fuse-into-consumer (PR #86752)
Aart Bik
llvmlistbot at llvm.org
Tue Mar 26 17:04:16 PDT 2024
https://github.com/aartbik created https://github.com/llvm/llvm-project/pull/86752
This fixes an "infinite" loop bug, where the incoming IR was repeatedly rewritten while adding identical cast operations. The test for compatible types should include the notion of an encoding. If the source and target encodings differ, then a
naive fusion into the consumer is invalid.
From 4e8f4eccea1f7d9eb2c18c1d96b3adace7bffb9d Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 26 Mar 2024 17:02:57 -0700
Subject: [PATCH] [mlir][sparse] avoid incompatible linalg fuse-into-consumer
This fixes an "infinite" loop bug, where the incoming
IR was repeatedly rewritten while adding identical cast
operations. The test for compatible types should include
the notion of an encoding. If it differs, then a
naive fusion into the consumer is invalid.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++
.../SparseTensor/no_fold_into_consumer.mlir | 47 +++++++++++++++++++
2 files changed, 51 insertions(+)
create mode 100644 mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index dc8843aa4e1e13..38a9ad60bb7948 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -276,6 +276,10 @@ bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
if (sourceType.getRank() != targetType.getRank())
return false;
+ // Requires same encoding.
+ if (sourceType.getEncoding() != targetType.getEncoding())
+ return false;
+
// If cast is towards more static sizes along any dimension, don't fold.
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
if (!ShapedType::isDynamic(std::get<0>(t)) &&
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
new file mode 100644
index 00000000000000..bbc7f397e793fe
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --canonicalize --pre-sparsification-rewrite | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#sparse = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) ->
+ (d0 : compressed(nonunique),
+ d1 : singleton(nonunique, soa),
+ d2 : singleton(soa)),
+ posWidth = 64,
+ crdWidth = 64
+}>
+
+
+module {
+ //
+ // This IR should not end up in an infinite loop trying to fold
+ // the linalg producer into the tensor cast consumer (even though
+ // static sizes can fold, the different encodings cannot). The
+ // cast was sloppy to begin with (but it has been observed by
+ // external sources) and can be easily repaired by the sparsifier.
+ //
+ // CHECK-LABEL: func @avoid_fold
+ // CHECK: arith.constant
+ // CHECK: tensor.empty()
+ // CHECK: linalg.generic
+ // CHECK: sparse_tensor.convert
+ // CHECK: return
+ //
+ func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
+ %1 = tensor.empty() : tensor<10x20x30xf64>
+ %2 = linalg.generic { indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ }
+ ins (%0 : tensor<10x20x30xf64, #sparse>)
+ outs(%1 : tensor<10x20x30xf64>) {
+ ^bb0(%in: f64, %out: f64):
+ %cst = arith.constant 0.000000e+00 : f64
+ %4 = arith.cmpf ugt, %in, %cst : f64
+ %5 = arith.select %4, %in, %cst : f64
+ linalg.yield %5 : f64
+ } -> tensor<10x20x30xf64>
+ %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
+ return %cast : tensor<10x20x30xf64, #sparse>
+ }
+}
+
More information about the Mlir-commits
mailing list