[Mlir-commits] [mlir] a99e06a - [mlir][Linalg] Avoid generating illegal operations during elementwise fusion.

llvmlistbot at llvm.org
Thu Jan 20 23:44:53 PST 2022


Author: MaheshRavishankar
Date: 2022-01-20T23:43:50-08:00
New Revision: a99e06aa869b44588a18a423f58e0ab30c292d8e

URL: https://github.com/llvm/llvm-project/commit/a99e06aa869b44588a18a423f58e0ab30c292d8e
DIFF: https://github.com/llvm/llvm-project/commit/a99e06aa869b44588a18a423f58e0ab30c292d8e.diff

LOG: [mlir][Linalg] Avoid generating illegal operations during elementwise fusion.

In some cases, elementwise fusion can produce an illegal operation: after
fusion, the range of some loops can no longer be computed from the shapes
of the fused operation's operands. Check for this case and abort the
fusion when it happens.
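The legality condition can be illustrated with a simplified model. In MLIR, `getShapesToLoopsMap()` inverts the concatenated "loops to shapes" map of all operands (via `inversePermutation`) and yields a null map when some loop dimension never appears as a bare dimension in any operand's indexing map. The sketch below mimics that check under stated assumptions: `IndexingMap` (a list of result entries, each a loop-dimension index or `-1` for any non-dimension expression) and `shapesToLoops` are hypothetical names for illustration only, not MLIR APIs.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical simplified indexing map: each result is either a plain loop
// dimension (its index, >= 0) or some other affine expression (-1).
using IndexingMap = std::vector<int>;

// Concatenate the results of all indexing maps (the "loops to shapes" map)
// and, for every loop dimension, record the first shape position that is
// exactly that dimension. If some loop has no such position, its range cannot
// be derived from operand shapes, so return std::nullopt (analogous to
// inversePermutation returning a null AffineMap).
std::optional<std::vector<int>>
shapesToLoops(const std::vector<IndexingMap> &maps, int numLoops) {
  assert(numLoops >= 0 && "loop count must be non-negative");
  std::vector<int> loopToShapePos(numLoops, -1);
  int pos = 0;
  for (const IndexingMap &map : maps) {
    for (int result : map) {
      // Only bare dimensions can define a loop's range; keep the first match.
      if (result >= 0 && result < numLoops && loopToShapePos[result] == -1)
        loopToShapePos[result] = pos;
      ++pos;
    }
  }
  for (int p : loopToShapePos)
    if (p == -1)
      return std::nullopt; // some loop is unconstrained: illegal op
  return loopToShapePos;
}
```

Applied to the `@illegal_fusion` test below: the consumer alone has maps `(d0, d1) -> (d0)` and `(d0, d1) -> (d1)`, so both loops are covered (`{IndexingMap{0}, IndexingMap{1}}` with 2 loops succeeds). Assuming the fused op keeps only the consumer's output operand (the producer has no `ins`), only `(d0, d1) -> (d1)` remains (`{IndexingMap{1}}` with 2 loops), `d0` is unconstrained, and the check fails, which is why fusion is aborted.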

Differential Revision: https://reviews.llvm.org/D117602

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 33286258543e5..be34ef8bbd625 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -318,6 +318,13 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
       consumer.iterator_types(),
       /*doc=*/nullptr,
       /*library_call=*/nullptr);
+  if (!fusedOp.getShapesToLoopsMap()) {
+    // Fused op has invalid indexing maps. Typically this means something is off
+    // in the input, but going ahead here would result in verification errors.
+    // So clean up and abort.
+    rewriter.eraseOp(fusedOp);
+    return llvm::None;
+  }
 
   // Construct an AffineMap from consumer loops to producer loops.
   // consumer loop -> tensor index

diff  --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 6ae9e15543e1c..3f68820b18cc7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -945,3 +945,33 @@ func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> ten
   } -> tensor<?xf32>
   return %8 : tensor<?xf32>
 }
+
+// -----
+
+func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
+  %c1_i32 = arith.constant 1 : i32
+  %0 = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        outs(%arg0 : tensor<5000xi64>) {
+        ^bb0(%arg3: i64):  // no predecessors
+          %22 = linalg.index 0 : index
+          %23 = arith.index_cast %22 : index to i64
+          linalg.yield %23 : i64
+        } -> tensor<5000xi64>
+  %1 = linalg.init_tensor [5000] : tensor<5000xi32>
+  %2 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%0 : tensor<5000xi64>) outs(%1 : tensor<5000xi32>) {
+        ^bb0(%arg3: i64, %arg5: i32):  // no predecessors
+          %22 = arith.index_cast %arg3 : i64 to index
+          %23 = tensor.extract %arg1[%22] : tensor<5000xi32>
+          linalg.yield %23 : i32
+        } -> tensor<5000xi32>
+  return %2 : tensor<5000xi32>
+}
+// CHECK-LABEL: func @illegal_fusion(
+//       CHECK:   %[[PRODUCER:.+]] = linalg.generic
+//       CHECK:    linalg.generic
+//   CHECK-SAME:       ins(%[[PRODUCER]]
