[Mlir-commits] [mlir] a99e06a - [mlir][Linalg] Avoid generating illegal operations during elementwise fusion.
llvmlistbot at llvm.org
Thu Jan 20 23:44:53 PST 2022
Author: MaheshRavishankar
Date: 2022-01-20T23:43:50-08:00
New Revision: a99e06aa869b44588a18a423f58e0ab30c292d8e
URL: https://github.com/llvm/llvm-project/commit/a99e06aa869b44588a18a423f58e0ab30c292d8e
DIFF: https://github.com/llvm/llvm-project/commit/a99e06aa869b44588a18a423f58e0ab30c292d8e.diff
LOG: [mlir][Linalg] Avoid generating illegal operations during elementwise fusion.
In some cases, elementwise fusion can produce an illegal operation: after
fusion, the range of some of the loops may no longer be computable from the
shapes of the fused op's operands. Check for this case and abort the fusion
if it happens.
Differential Revision: https://reviews.llvm.org/D117602
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 33286258543e5..be34ef8bbd625 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -318,6 +318,13 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
+ if (!fusedOp.getShapesToLoopsMap()) {
+ // Fused op has invalid indexing maps. Typically this means something is off
+ // in the input, but going ahead here would result in verification errors.
+ // So cleanup and abort.
+ rewriter.eraseOp(fusedOp);
+ return llvm::None;
+ }
// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 6ae9e15543e1c..3f68820b18cc7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -945,3 +945,33 @@ func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> ten
} -> tensor<?xf32>
return %8 : tensor<?xf32>
}
+
+// -----
+
+func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
+ %c1_i32 = arith.constant 1 : i32
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ outs(%arg0 : tensor<5000xi64>) {
+ ^bb0(%arg3: i64): // no predecessors
+ %22 = linalg.index 0 : index
+ %23 = arith.index_cast %22 : index to i64
+ linalg.yield %23 : i64
+ } -> tensor<5000xi64>
+ %1 = linalg.init_tensor [5000] : tensor<5000xi32>
+ %2 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : tensor<5000xi64>) outs(%1 : tensor<5000xi32>) {
+ ^bb0(%arg3: i64, %arg5: i32): // no predecessors
+ %22 = arith.index_cast %arg3 : i64 to index
+ %23 = tensor.extract %arg1[%22] : tensor<5000xi32>
+ linalg.yield %23 : i32
+ } -> tensor<5000xi32>
+ return %2 : tensor<5000xi32>
+}
+// CHECK-LABEL: func @illegal_fusion(
+// CHECK: %[[PRODUCER:.+]] = linalg.generic
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[PRODUCER]]