[Mlir-commits] [mlir] e5a315f - [mlir][Linalg] Disallow ops with index semantics in `PushExpandingReshape`.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 25 10:37:38 PST 2022
Author: MaheshRavishankar
Date: 2022-01-25T10:37:30-08:00
New Revision: e5a315f57acf5580aa8819123300d90b4f7a160a
URL: https://github.com/llvm/llvm-project/commit/e5a315f57acf5580aa8819123300d90b4f7a160a
DIFF: https://github.com/llvm/llvm-project/commit/e5a315f57acf5580aa8819123300d90b4f7a160a.diff
LOG: [mlir][Linalg] Disallow ops with index semantics in `PushExpandingReshape`.
This pattern is not written to handle operations with `linalg.index`
operations in their body, i.e. operations that have index semantics.
Differential Revision: https://reviews.llvm.org/D117856
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index be34ef8bbd625..aaa5d4c386208 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -994,7 +994,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Only apply to elementwise linalg on tensor.
- if (!genericOp.hasTensorSemantics() ||
+ if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
genericOp.getNumParallelLoops() != genericOp.getNumLoops())
return failure();
// Only support identity output maps. It could be extended to permuations if
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 9e96c98e7850b..0c02ff8c54d1f 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -124,3 +124,30 @@ func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
// CHECK: tensor.expand_shape %[[OP]]
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>
+
+// -----
+
+func @generic_op_index_semantics(%A: tensor<?x16xi64>, %B: tensor<16xi64>, %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
+ %0 = tensor.expand_shape %A [[0, 1], [2]]
+ : tensor<?x16xi64> into tensor<?x112x16xi64>
+ %2 = linalg.generic {indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0, %B : tensor<?x112x16xi64>, tensor<16xi64>)
+ outs(%init : tensor<?x112x16xi64>) {
+ ^bb0(%arg1: i64, %arg2: i64, %arg3: i64): // no predecessors
+ %index = linalg.index 0 : index
+ %1 = arith.index_cast %index : index to i64
+ %add = arith.addi %arg1, %1 : i64
+ %s = arith.subi %add, %arg2 : i64
+ linalg.yield %s : i64
+ } -> tensor<?x112x16xi64>
+ return %2 : tensor<?x112x16xi64>
+}
+// CHECK: func @generic_op_index_semantics
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x16xi64>
+// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[RESHAPE]]
+// CHECK: return %[[RESULT]]
More information about the Mlir-commits
mailing list