[Mlir-commits] [mlir] 6052b17 - [mlir][tensor] Add dim(expand_shape/collapse_shape) folding
Matthias Springer
llvmlistbot at llvm.org
Tue Nov 22 08:36:38 PST 2022
Author: Matthias Springer
Date: 2022-11-22T17:34:49+01:00
New Revision: 6052b17aabec2db8ad255eca5632cb128363c604
URL: https://github.com/llvm/llvm-project/commit/6052b17aabec2db8ad255eca5632cb128363c604
DIFF: https://github.com/llvm/llvm-project/commit/6052b17aabec2db8ad255eca5632cb128363c604.diff
LOG: [mlir][tensor] Add dim(expand_shape/collapse_shape) folding
Differential Revision: https://reviews.llvm.org/D138487
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7af19a7f6bb4e..14060075b2340 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1051,7 +1051,10 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
}]>
];
- let extraClassDeclaration = commonExtraClassDeclaration;
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ /// Return the index of the reassociation group (i.e., the source
+ /// dimension) that contains the given result dimension. `resultDim` must
+ /// be in [0, rank of the result type).
+ int64_t getCorrespondingSourceDim(int64_t resultDim);
+ }];
+
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index bf54d46065a73..e53879b618cc7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -908,9 +908,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
}
Optional<int64_t> DimOp::getConstantIndex() {
- if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
- return constantOp.getValue().cast<IntegerAttr>().getInt();
- return {};
+ // Delegate to the shared getConstantIntValue helper instead of matching an
+ // arith::ConstantOp defining op by hand.
+ return getConstantIntValue(getIndex());
}
Speculation::Speculatability DimOp::getSpeculatability() {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c5d7e42493af4..826c69e23f048 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
@@ -379,9 +380,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
}
Optional<int64_t> DimOp::getConstantIndex() {
- if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
- return constantOp.getValue().cast<IntegerAttr>().getInt();
- return {};
+ // Delegate to the shared getConstantIntValue helper instead of matching an
+ // arith::ConstantOp defining op by hand (same change as in MemRef DimOp).
+ return getConstantIntValue(getIndex());
}
Speculation::Speculatability DimOp::getSpeculatability() {
@@ -1302,6 +1301,15 @@ void ExpandShapeOp::getAsmResultNames(
setNameFn(getResult(), "expanded");
}
+/// Map a dimension of the expanded (result) tensor back to the source
+/// dimension it was expanded from, i.e., the position of the reassociation
+/// group that lists `resultDim`.
+int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
+  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
+         "invalid resultDim");
+  auto groups = getReassociationIndices();
+  for (int64_t srcDim = 0, numGroups = groups.size(); srcDim < numGroups;
+       ++srcDim) {
+    if (llvm::is_contained(groups[srcDim], resultDim))
+      return srcDim;
+  }
+  llvm_unreachable("could not find reassociation group");
+}
+
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
@@ -1470,6 +1478,87 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
}
};
+/// Rewrites `tensor.dim(tensor.expand_shape(%src), %c)`, where result
+/// dimension %c is dynamic, into an affine.apply on
+/// `tensor.dim(%src, srcDim)`. Here srcDim is the source dimension whose
+/// reassociation group contains %c:
+///
+///   result dim size = src dim size floordiv product(other sizes in group)
+///
+/// The other sizes in the group are static: an expand_shape may expand a
+/// source dimension into at most one dynamic result dimension.
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    Optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // A constant index outside [0, rank) (e.g. a negative constant) has
+    // undefined behavior at runtime. Bail out instead of reading the result
+    // shape out of bounds in isDynamicDim / getCorrespondingSourceDim.
+    TensorType resultType = expandShapeOp.getResultType();
+    if (*dim < 0 || *dim >= resultType.getRank())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Find reassociation group that contains this result dimension.
+    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
+
+    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
+    // ExpandShapeOp would be ambiguous.)
+    int64_t product = 1;
+    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
+    for (int64_t d : grp) {
+      if (d == *dim)
+        continue;
+      assert(!resultType.isDynamicDim(d) && "expected static dim");
+      product *= resultType.getDimSize(d);
+    }
+
+    // If another dimension of the group has static size 0, the source size
+    // is 0 and tells nothing about this dimension. A `floordiv 0` would also
+    // be undefined, so do not fold.
+    if (product == 0)
+      return failure();
+
+    // result dim size = src dim size / (product(other dims in reassoc group))
+    Value srcDimSz =
+        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+    AffineExpr expr;
+    bindSymbols(dimOp.getContext(), expr);
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, expr.floorDiv(product),
+                                               srcDimSz);
+    return success();
+  }
+};
+
+/// Rewrites `tensor.dim(tensor.collapse_shape(%src), %c)`, where result
+/// dimension %c is dynamic, into an affine.apply that multiplies the sizes
+/// of all source dimensions in the reassociation group of %c.
+struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    Optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // A constant index outside [0, rank) (e.g. a negative constant, or any
+    // index into a rank-0 result) has undefined behavior at runtime. Bail
+    // out instead of reading the result shape / reassociation out of bounds.
+    TensorType resultType = collapseShapeOp.getResultType();
+    if (*dim < 0 || *dim >= resultType.getRank())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Get reassociation group of the result dimension.
+    ReassociationIndices group =
+        collapseShapeOp.getReassociationIndices()[*dim];
+
+    // result dim size = product(dims in reassoc group). Symbol i of the
+    // affine map is bound to the size of the i-th source dim of the group.
+    SmallVector<Value> srcDimSizes;
+    srcDimSizes.reserve(group.size());
+    AffineExpr product;
+    for (const auto &it : llvm::enumerate(group)) {
+      srcDimSizes.push_back(rewriter.create<DimOp>(
+          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
+      AffineExpr sym = rewriter.getAffineSymbolExpr(it.index());
+      product = product ? product * sym : sym;
+    }
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, product, srcDimSizes);
+    return success();
+  }
+};
} // namespace
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1477,7 +1566,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ // Note: FoldDimOfExpandShape and FoldDimOfCollapseShape are rooted at
+ // tensor.dim (OpRewritePattern<DimOp>) and inspect the defining op of the
+ // dim's source operand; they are registered here alongside the reshape folds.
results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
FoldReshapeWithConstant<ExpandShapeOp>,
- FoldReshapeWithFromElements<ExpandShapeOp>>(context);
+ FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+ FoldDimOfCollapseShape>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 99e31c7c35964..c9e662f969d74 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1628,3 +1628,41 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens
%r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
return %r: tensor<2xf32>
}
+
+// -----
+
+// Result dim 2 belongs to reassociation group [1, 2, 3, 4, 5]. The other
+// (static) sizes of that group multiply to 1 * 5 * 1 * 8 = 40, so the dim
+// folds to dim(%t, 1) floordiv 40.
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
+// CHECK-LABEL: func @dim_of_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
+// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
+// CHECK: return %[[apply]]
+func.func @dim_of_expand_shape(%t: tensor<?x?xf32>) -> index {
+ %c2 = arith.constant 2 : index
+ %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]]
+ : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
+ %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
+ return %1 : index
+}
+
+// -----
+
+// Result dim 1 collapses source dims [1, 2, 3, 4]. The dynamic sizes of
+// source dims 1, 2 and 4 become symbols of the map. The static size 7 of
+// source dim 3 is folded into the map as a constant factor.
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
+// CHECK-LABEL: func @dim_of_collapse_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x7x?xf32>
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
+// CHECK-DAG: %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
+// CHECK-DAG: %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
+// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
+// CHECK: return %[[apply]]
+func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
+ %c1 = arith.constant 1 : index
+ %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
+ : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+ %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %1 : index
+}
More information about the Mlir-commits mailing list