[Mlir-commits] [mlir] 93bbcff - [mlir][Transform] Make FuseIntoContainingOp support rank-reducing extract slices

Nicolas Vasilache llvmlistbot at llvm.org
Mon Dec 12 12:55:16 PST 2022


Author: Nicolas Vasilache
Date: 2022-12-12T12:55:08-08:00
New Revision: 93bbcffc7e9dbb30d5cd9002bc136dd3d7df950d

URL: https://github.com/llvm/llvm-project/commit/93bbcffc7e9dbb30d5cd9002bc136dd3d7df950d
DIFF: https://github.com/llvm/llvm-project/commit/93bbcffc7e9dbb30d5cd9002bc136dd3d7df950d.diff

LOG: [mlir][Transform] Make FuseIntoContainingOp support rank-reducing extract slices

This fixes an issue where rank-reducing extract slices and fusion would not interoperate properly.

Differential Revision: https://reviews.llvm.org/D139844
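
To illustrate the interplay being fixed (a sketch distilled from the new test added below; names are illustrative): tensor.extract_slice may drop unit dimensions from its result type, whereas tiling the producer into the containing op naturally yields the non-rank-reduced shape, so the two types must be reconciled when the slice is replaced.

  // Rank-reducing slice consumed inside the containing op:
  %s = tensor.extract_slice %o[%i] [1] [1] : tensor<?xf32> to tensor<f32>
  // Tiling the fused producer for this slice yields a tensor<1xf32>;
  // replacing %s with that value directly would mismatch types, so an extra
  // rank-reducing extract_slice from tensor<1xf32> to tensor<f32> is built.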

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 55f3695c4d1b8..7088b0c012f06 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -435,6 +435,15 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     /// Return the dimensions of the source that are dropped in the
     /// result when the result is rank-reduced.
     llvm::SmallBitVector getDroppedDims();
+
+    /// Given a `value`, asserted to be of RankedTensorType, build an
+    /// ExtractSliceOp that results in a rank-reducing extract to the desired
+    /// tensor shape and return the new value created.
+    /// If the shape of `value` is already the `desiredShape`, just return
+    /// `value`.
+    /// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
+    static FailureOr<Value> rankReduceIfNeeded(
+      OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
   }];
 
   let hasCanonicalizer = 1;
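
A usage sketch of the helper declared above (variable names, shapes, and error handling are illustrative, not taken from the patch):

  // Given an OpBuilder b, a Location loc, and a Value v of some
  // RankedTensorType, e.g. tensor<1x4xf32>, request the rank-reduced
  // shape {4}. The helper returns v unchanged if it already has that shape,
  // builds a rank-reducing tensor.extract_slice otherwise, and fails if the
  // shape cannot be reached by dropping unit dims.
  FailureOr<Value> reduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      b, loc, v, /*desiredShape=*/{4});
  if (failed(reduced))
    return failure();
  Value result = *reduced;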

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c8dd269029026..dc6baddadf433 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -299,7 +300,14 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
 
   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
+  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+      rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+      sliceOpToTile->getResult(0)
+          .getType()
+          .cast<RankedTensorType>()
+          .getShape());
+  assert(succeeded(maybeRankReduced) && "unexpected shape");
+  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
   return fusedOp;
 }
 
@@ -399,7 +407,14 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
 
   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
+  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+      rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+      sliceOpToTile->getResult(0)
+          .getType()
+          .cast<RankedTensorType>()
+          .getShape());
+  assert(succeeded(maybeRankReduced) && "unexpected shape");
+  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
 
   // Replace the use in containingOp.
   rewriter.updateRootInPlace(containingOp, [&]() {
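
For reference, after this change the replacement at these call sites produces IR of the following form (mirroring the CHECK lines of the new test added below; value names are illustrative):

  %slice = tensor.extract_slice %o[%i] [1] [1] : tensor<?xf32> to tensor<1xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<1xf32>) -> tensor<1xf32>
  // Rank-reducing slice inserted by rankReduceIfNeeded to replace the
  // original tensor<f32>-typed extract_slice:
  %res = tensor.extract_slice %fill[0] [1] [1] : tensor<1xf32> to tensor<f32>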

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f5cb9fec082f8..3578db643c98b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -17,7 +17,9 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Support/MathExtras.h"
@@ -1754,6 +1756,23 @@ llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
   return droppedDims;
 }
 
+FailureOr<Value>
+ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
+                                   ArrayRef<int64_t> desiredShape) {
+  auto sourceTensorType = value.getType().dyn_cast<RankedTensorType>();
+  assert(sourceTensorType && "not a ranked tensor type");
+  auto sourceShape = sourceTensorType.getShape();
+  if (sourceShape.equals(desiredShape))
+    return value;
+  auto maybeRankReductionMask =
+      mlir::computeRankReductionMask(sourceShape, desiredShape);
+  if (!maybeRankReductionMask)
+    return failure();
+  return createCanonicalRankReducingExtractSliceOp(
+      b, loc, value,
+      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
+}
+
 LogicalResult ExtractSliceOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   reifiedReturnShapes.resize(1);
@@ -2375,7 +2394,6 @@ struct InsertSliceOpSourceCastInserter final
         insertSliceOp, cast, insertSliceOp.getDest(),
         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
         insertSliceOp.getMixedStrides());
-    cast.getDefiningOp()->getParentOfType<ModuleOp>().dump();
     return success();
   }
 };
@@ -2475,8 +2493,7 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
 
   SmallVector<int64_t, 4> inferredShape;
   for (auto i : llvm::seq<unsigned>(0, rank)) {
-    if (sourceType.isDynamicDim(i) ||
-        staticLow[i] == ShapedType::kDynamic ||
+    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
         staticHigh[i] == ShapedType::kDynamic) {
       inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                   : resultShape[i]);
@@ -2525,8 +2542,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
   // This will grow staticLow and staticHigh with 1 value. If the config is
   // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
   // value as well.
-  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
-                             ShapedType::kDynamic);
+  dispatchIndexOpFoldResults(low, dynamicLow, staticLow, ShapedType::kDynamic);
   dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
                              ShapedType::kDynamic);
   if (!resultType) {
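
The helper relies on mlir::computeRankReductionMask to decide which dimensions can be dropped. As a standalone illustration of that idea (a simplified approximation for exposition, not the actual MLIR implementation):

  #include <cstdint>
  #include <optional>
  #include <set>
  #include <vector>

  // Dims of `source` that do not line up with `desired` must be static unit
  // dims; they form the set of dropped dimensions. Returns std::nullopt if
  // `desired` cannot be reached this way.
  static std::optional<std::set<unsigned>>
  computeDroppedUnitDims(const std::vector<int64_t> &source,
                         const std::vector<int64_t> &desired) {
    std::set<unsigned> dropped;
    size_t d = 0; // next unmatched index into `desired`
    for (size_t s = 0; s < source.size(); ++s) {
      if (d < desired.size() && source[s] == desired[d]) {
        ++d; // dimension is kept
        continue;
      }
      if (source[s] != 1) // only static unit dims may be dropped
        return std::nullopt;
      dropped.insert(static_cast<unsigned>(s));
    }
    if (d != desired.size()) // some desired dims were never matched
      return std::nullopt;
    return dropped;
  }

  // Example: source = {1, 4, 1, 8}, desired = {4, 8} yields dropped dims {0, 2}.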

diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 141e8f59b5a21..7424e08f5338c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -96,6 +96,52 @@ module {
 
 // -----
 
+module {
+  func.func @foo(%0: tensor<f32>) -> tensor<f32> {
+    return %0: tensor<f32>
+  }
+
+  // CHECK-LABEL: func.func @fuse_tileable_op_rank_reducing
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_rank_reducing(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+    // CHECK: scf.foreach_thread {{.*}} -> (tensor<?xf32>) {
+    %2 = scf.foreach_thread (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+      %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>
+      
+      // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
+      // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
+      // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
+      // CHECK: func.call @foo(%{{.*}}) : (tensor<f32>) -> tensor<f32>
+      %7 = func.call @foo(%5) : (tensor<f32>) -> tensor<f32>
+
+      scf.foreach_thread.perform_concurrently {
+      // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [1] [1] : tensor<f32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%arg3] [1] [1] : tensor<f32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+    // linalg.fill is tileable. The op is tiled and fused.
+    transform.structured.fuse_into_containing_op %0 into %1
+  }
+}
+
+// -----
+
 #map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>


        

