[Mlir-commits] [mlir] b0fc712 - [mlir][Linalg] Disable const -> linalg.generic when fused op is illegal.
llvmlistbot at llvm.org
Mon Apr 12 10:16:14 PDT 2021
Author: MaheshRavishankar
Date: 2021-04-12T10:15:54-07:00
New Revision: b0fc712b14ff5f9fbf56e6605fd6ae48ab017ec8
URL: https://github.com/llvm/llvm-project/commit/b0fc712b14ff5f9fbf56e6605fd6ae48ab017ec8
DIFF: https://github.com/llvm/llvm-project/commit/b0fc712b14ff5f9fbf56e6605fd6ae48ab017ec8.diff
LOG: [mlir][Linalg] Disable const -> linalg.generic when fused op is illegal.
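Fusing a constant with a linalg.generic operation can produce an illegal
fused operation: once the constant operand and its indexing map are dropped,
the remaining indexing maps may no longer determine every loop bound (for
example, when the constant is the only operand indexed by a reduction
dimension). Avoid such fusions.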
Differential Revision: https://reviews.llvm.org/D100272
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 713de7b22c4b7..a404cbd560f73 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1103,6 +1103,12 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
+ // Check if the operation shapes to loops map is computable.
+ if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+ return rewriter.notifyMatchFailure(
+ linalgOp, "fused op loop bound computation failed");
+ }
+
// The operands list is same as the linalgOp with the argument for
// constant index dropped.
SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 7983fe19a95a3..00d0995a25f6e 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -678,3 +678,26 @@ func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8
} -> tensor<1x8xindex>
return %1 : tensor<1x8xindex>
}
+
+// -----
+
+// CHECK-LABEL: func @no_fuse_constant_with_reduction
+func @no_fuse_constant_with_reduction() -> tensor<3xf32>
+{
+ // CHECK: %[[CONST:.+]] = constant {{.+}} : tensor<3x2xf32>
+ // CHECK: %[[RESULT:.+]] = linalg.generic
+ // CHECK-SAME: ins(%[[CONST]] : tensor<3x2xf32>)
+ // CHECK: return %[[RESULT]]
+ %three = constant dense<3.0> : tensor<3x2xf32>
+ %init = linalg.init_tensor [3] : tensor<3xf32>
+ %result = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%three : tensor<3x2xf32>) outs(%init : tensor<3xf32>) {
+ ^bb0(%arg0 : f32, %arg1 : f32):
+ %0 = addf %arg0, %arg1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<3xf32>
+ return %result : tensor<3xf32>
+}