[Mlir-commits] [mlir] 93bbcff - [mlir][Transform] Make FuseIntoContainingOp support rank-reducing extract slices
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Dec 12 12:55:16 PST 2022
Author: Nicolas Vasilache
Date: 2022-12-12T12:55:08-08:00
New Revision: 93bbcffc7e9dbb30d5cd9002bc136dd3d7df950d
URL: https://github.com/llvm/llvm-project/commit/93bbcffc7e9dbb30d5cd9002bc136dd3d7df950d
DIFF: https://github.com/llvm/llvm-project/commit/93bbcffc7e9dbb30d5cd9002bc136dd3d7df950d.diff
LOG: [mlir][Transform] Make FuseIntoContainingOp support rank-reducing extract slices
This fixes an issue where rank-reducing + fusion would not interop properly.
Differential Revision: https://reviews.llvm.org/D139844
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 55f3695c4d1b8..7088b0c012f06 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -435,6 +435,15 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// Return the dimensions of the source that are dropped in the
/// result when the result is rank-reduced.
llvm::SmallBitVector getDroppedDims();
+
+ /// Given a `value`, asserted to be of RankedTensorType, build an
+ /// ExtractSliceOp that results in a rank-reducing extract to the desired
+ /// tensor shape and return the new value created.
+ /// If the shape of `value` is already the `desiredShape`, just return
+ /// `value`.
+ /// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
+ static FailureOr<Value> rankReduceIfNeeded(
+ OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c8dd269029026..dc6baddadf433 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -299,7 +300,14 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
- rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
+ auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+ rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+ sliceOpToTile->getResult(0)
+ .getType()
+ .cast<RankedTensorType>()
+ .getShape());
+ assert(succeeded(maybeRankReduced) && "unexpected shape");
+ rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
return fusedOp;
}
@@ -399,7 +407,14 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
- rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
+ auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+ rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+ sliceOpToTile->getResult(0)
+ .getType()
+ .cast<RankedTensorType>()
+ .getShape());
+ assert(succeeded(maybeRankReduced) && "unexpected shape");
+ rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
// Replace the use in containingOp.
rewriter.updateRootInPlace(containingOp, [&]() {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f5cb9fec082f8..3578db643c98b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -17,7 +17,9 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Support/MathExtras.h"
@@ -1754,6 +1756,23 @@ llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
return droppedDims;
}
+FailureOr<Value>
+ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
+ ArrayRef<int64_t> desiredShape) {
+ auto sourceTensorType = value.getType().dyn_cast<RankedTensorType>();
+ assert(sourceTensorType && "not a ranked tensor type");
+ auto sourceShape = sourceTensorType.getShape();
+ if (sourceShape.equals(desiredShape))
+ return value;
+ auto maybeRankReductionMask =
+ mlir::computeRankReductionMask(sourceShape, desiredShape);
+ if (!maybeRankReductionMask)
+ return failure();
+ return createCanonicalRankReducingExtractSliceOp(
+ b, loc, value,
+ RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
+}
+
LogicalResult ExtractSliceOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1);
@@ -2375,7 +2394,6 @@ struct InsertSliceOpSourceCastInserter final
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());
- cast.getDefiningOp()->getParentOfType<ModuleOp>().dump();
return success();
}
};
@@ -2475,8 +2493,7 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
SmallVector<int64_t, 4> inferredShape;
for (auto i : llvm::seq<unsigned>(0, rank)) {
- if (sourceType.isDynamicDim(i) ||
- staticLow[i] == ShapedType::kDynamic ||
+ if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
staticHigh[i] == ShapedType::kDynamic) {
inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
: resultShape[i]);
@@ -2525,8 +2542,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
// This will grow staticLow and staticHigh with 1 value. If the config is
// dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
// value as well.
- dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
- ShapedType::kDynamic);
+ dispatchIndexOpFoldResults(low, dynamicLow, staticLow, ShapedType::kDynamic);
dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
ShapedType::kDynamic);
if (!resultType) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 141e8f59b5a21..7424e08f5338c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -96,6 +96,52 @@ module {
// -----
+module {
+ func.func @foo(%0: tensor<f32>) -> tensor<f32> {
+ return %0: tensor<f32>
+ }
+
+ // CHECK-LABEL: func.func @fuse_tileable_op_rank_reducing
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_op_rank_reducing(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+ %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+ // CHECK: scf.foreach_thread {{.*}} -> (tensor<?xf32>) {
+ %2 = scf.foreach_thread (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+ %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>
+
+ // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
+ // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
+ // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
+ // CHECK: func.call @foo(%{{.*}}) : (tensor<f32>) -> tensor<f32>
+ %7 = func.call @foo(%5) : (tensor<f32>) -> tensor<f32>
+
+ scf.foreach_thread.perform_concurrently {
+ // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [1] [1] : tensor<f32> into tensor<?xf32>
+ tensor.parallel_insert_slice %7 into %o[%arg3] [1] [1] : tensor<f32> into tensor<?xf32>
+ }
+ }
+ // CHECK: }
+ func.return %2 : tensor<?xf32>
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+ %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+ // linalg.fill is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ }
+}
+
+// -----
+
#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
More information about the Mlir-commits
mailing list