[Mlir-commits] [mlir] 0c4e538 - [mlir][Linalg] Add an InitTensor -> DimOp canonicalization pattern.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jul 7 01:46:41 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-07T08:44:54Z
New Revision: 0c4e538d8fdaf66b4eb9f361f156d068e4f90abd
URL: https://github.com/llvm/llvm-project/commit/0c4e538d8fdaf66b4eb9f361f156d068e4f90abd
DIFF: https://github.com/llvm/llvm-project/commit/0c4e538d8fdaf66b4eb9f361f156d068e4f90abd.diff
LOG: [mlir][Linalg] Add an InitTensor -> DimOp canonicalization pattern.
Differential Revision: https://reviews.llvm.org/D105537
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 18c55f4019cab..62e66a421d5c5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -799,11 +799,30 @@ struct FoldInitTensorWithTensorReshapeOp
return success();
}
};
+
+/// Canonicalizes `tensor.dim %t, %c` where `%t` is produced by a
+/// `linalg.init_tensor` op and `%c` is a constant index:
+///   - a dynamic dimension is replaced by the matching size operand of the
+///     init_tensor op;
+///   - a static dimension is replaced by `constant <size> : index`, where
+///     <size> is the extent of that dimension in the result type.
+/// The pattern does not apply when the index is not a constant or lies
+/// outside [0, rank).
+struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
+    auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
+    if (!initTensorOp || !maybeConstantIndex)
+      return failure();
+
+    int64_t index = *maybeConstantIndex;
+    auto initTensorType =
+        initTensorOp.getResult().getType().cast<ShapedType>();
+    // An out-of-range constant index would make isDynamicSize /
+    // getDynamicSize read past `static_sizes` and the size operands; leave
+    // such (ill-formed) dims alone instead of asserting or mis-folding.
+    if (index < 0 || index >= initTensorType.getRank())
+      return failure();
+
+    if (initTensorOp.isDynamicSize(index)) {
+      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
+      return success();
+    }
+    // Static dimension: materialize the extent of the dimension, not the
+    // dimension index that was queried.
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+        dimOp, initTensorType.getDimSize(index));
+    return success();
+  }
+};
} // namespace
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldInitTensorWithExtractSliceOp,
+ results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
ReplaceStaticShapeDims>(context);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 350a3cbb6842c..864c79a357561 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -540,13 +540,10 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @init_tensor_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C2:.+]] = constant 2
-// CHECK: %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]]
-// CHECK: %[[D0:.+]] = tensor.dim %[[INIT1]], %[[C2]]
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
-// CHECK: return %[[INIT2]]
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[D]], 7]
+// CHECK-NEXT: return %[[INIT]]
// -----
@@ -558,13 +555,10 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @init_tensor_reshape_collapse
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C4:.+]] = constant 4
-// CHECK: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7]
-// CHECK: %[[D0:.+]] = tensor.dim %[[INIT1]], %[[C4]]
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
-// CHECK: return %[[INIT2]]
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [6, 5, %[[D]]]
+// CHECK-NEXT: return %[[INIT]]
// -----
@@ -873,3 +867,26 @@ func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<
return %0 : tensor<2x3x4xf32>
}
+// -----
+
+func private @some_use(%i : index, %j : index)
+
+// Queries both dimensions of an init_tensor result. The dynamic extent must
+// be forwarded from the size operand %i, and the static extent must become a
+// constant; afterwards neither the dim queries nor the now-dead init_tensor
+// may remain in the output.
+// CHECK-LABEL: func @init_canonicalize
+// CHECK-SAME: %[[I:.*]]: index
+func @init_canonicalize(%i : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // CHECK-NOT: init_tensor
+  %init = linalg.init_tensor [%i, 42] : tensor<?x42xf32>
+
+  // CHECK-NOT: tensor.dim
+  %dyn_extent = tensor.dim %init, %c0 : tensor<?x42xf32>
+  %static_extent = tensor.dim %init, %c1 : tensor<?x42xf32>
+
+  // CHECK: %[[c42:.*]] = constant 42 : index
+  // CHECK: call @some_use(%[[I]], %[[c42]])
+  call @some_use(%dyn_extent, %static_extent) : (index, index) -> ()
+
+  return
+}
More information about the Mlir-commits
mailing list