[Mlir-commits] [mlir] 589eac6 - [mlir] Add canonicalizations for op -> tensor.cast folding.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue Mar 8 10:27:16 PST 2022


Author: Mahesh Ravishankar
Date: 2022-03-08T18:26:55Z
New Revision: 589eac6524d6ba080a51107757c1f356c365d047

URL: https://github.com/llvm/llvm-project/commit/589eac6524d6ba080a51107757c1f356c365d047
DIFF: https://github.com/llvm/llvm-project/commit/589eac6524d6ba080a51107757c1f356c365d047.diff

LOG: [mlir] Add canonicalizations for op -> tensor.cast folding.

A `tensor.cast` consumer can be folded with its producer. This is
beneficial only if the result of the tensor cast is more static than
the source. This patch adds a utility function to check that this is
the case, and adds a couple of canonicalization patterns that fold an
operation with `tensor.cast` consumers.
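
For example (a sketch mirroring the `fold_init_tensor_with_cast` test added
in this patch; `%arg0` is an arbitrary dynamic size):

  %0 = linalg.init_tensor [%arg0, 12] : tensor<?x12xf32>
  %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>

canonicalizes to

  %0 = linalg.init_tensor [1, 12] : tensor<1x12xf32>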

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D120950

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 0f896df15119a..be6f8c323cbe7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -85,6 +85,9 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
           });
     }
 
+    // Return both static and dynamic sizes as a list of `OpFoldResult`.
+    SmallVector<OpFoldResult> getMixedSizes();
+
     // Return the Value of the dynamic size of the tensor at dimension
     // `idx`. Asserts that the shape is dynamic at that `idx.
     Value getDynamicSize(unsigned idx) {

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 7b0d62e502e58..e623b8cb03116 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -85,6 +85,28 @@ bool preservesStaticInformation(Type source, Type target);
 /// ```
 bool canFoldIntoConsumerOp(CastOp castOp);
 
+/// Determines whether the tensor::CastOp casts to a more static version of the
+/// source tensor. This is useful to fold into a producing op and to implement
+/// canonicalization patterns with the `tensor.cast` op as the root, but with
+/// the producer coming from a different dialect. Returns true when all of the
+/// following conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the result type has more static information than the source.
+///
+/// Example:
+/// ```mlir
+///   %1 = producer ... : tensor<?x?xf32>
+///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
+/// ```
+///
+/// can be canonicalized to:
+///
+/// ```mlir
+///   %2 = producer ... : tensor<8x16xf32>
+/// ```
+/// Not all ops can be canonicalized this way, but for those that can, this
+/// method provides a check that the canonicalization is worthwhile.
+bool canFoldIntoProducerOp(CastOp castOp);
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 53ff45a531049..010695172518c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1072,6 +1072,21 @@ Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
   return RankedTensorType::get(staticSizes, elementType, encoding);
 }
 
+SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() {
+  SmallVector<OpFoldResult> mixedSizes;
+  mixedSizes.reserve(getType().getRank());
+  unsigned dynamicValIndex = 0;
+  for (Attribute attr : static_sizes()) {
+    auto intAttr = attr.cast<IntegerAttr>();
+    if (!ShapedType::isDynamic(intAttr.getInt())) {
+      mixedSizes.push_back(intAttr);
+      continue;
+    }
+    mixedSizes.push_back(sizes()[dynamicValIndex++]);
+  }
+  return mixedSizes;
+}
+
 namespace {
 /// Change the type of the result of a `linalg.init_tensor` by making the result
 /// type statically sized along dimension that in the original operation where
@@ -1193,11 +1208,86 @@ struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
+
+/// Canonicalize
+///
+/// ```mlir
+///   %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
+/// ```
+///
+/// into
+///
+/// ```mlir
+///   %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
+/// ```
+///
+/// This assumes the input program is correct in terms of its shape. So it
+/// is safe to assume that `%d0` is in fact 4. If that were not the case, the
+/// input program is wrong to begin with, so it is undefined behavior anyway
+/// (i.e. this optimization can still trigger without violating program
+/// semantics).
+struct FoldInitTensorWithTensorCastOp
+    : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    if (!canFoldIntoProducerOp(castOp))
+      return failure();
+    auto producer = castOp.source().getDefiningOp<InitTensorOp>();
+    if (!producer)
+      return failure();
+
+    auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
+    SmallVector<OpFoldResult> newMixedSizes;
+    newMixedSizes.reserve(currMixedSizes.size());
+    assert(resultShape.size() == currMixedSizes.size() &&
+           "mismatch in result shape and sizes of init_tensor op");
+    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
+      int64_t newDim = std::get<0>(it);
+      OpFoldResult currDim = std::get<1>(it);
+      // Case 1: The init tensor dim is static. Check that the tensor cast
+      // result dim matches.
+      if (auto attr = currDim.dyn_cast<Attribute>()) {
+        if (ShapedType::isDynamic(newDim) ||
+            newDim != attr.cast<IntegerAttr>().getInt()) {
+          // Something is off: the cast result shape cannot be more dynamic than
+          // the init tensor result shape (enforced by `canFoldIntoProducerOp`).
+          // Abort for now.
+          return rewriter.notifyMatchFailure(
+              producer, "mismatch in static value of shape of init "
+                        "tensor result and cast result");
+        }
+        newMixedSizes.push_back(attr);
+        continue;
+      }
+
+      // Case 2: The tensor cast shape is static, but init tensor result shape
+      // is dynamic.
+      if (!ShapedType::isDynamic(newDim)) {
+        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
+        continue;
+      }
+
+      // Case 3: The tensor cast shape is dynamic and init tensor result shape
+      // is dynamic. Use the dynamic value from the init tensor op.
+      newMixedSizes.push_back(currDim);
+    }
+
+    rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes,
+                                              resultType.getElementType());
+    return success();
+  }
+};
+
 } // namespace
 
 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
-  results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
+  results.add<FoldInitTensorWithTensorCastOp, FoldInitTensorWithDimOp,
+              FoldInitTensorWithExtractSliceOp,
               FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
               FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
               ReplaceStaticShapeDims>(context);
@@ -1608,7 +1698,7 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
   }
 };
 
-struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
+struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
   LogicalResult matchAndRewrite(LinalgOp op,
@@ -1664,6 +1754,63 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
   }
 };
 
+/// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` result
+/// type is more static than the linalg op result type.
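+///
+/// An illustrative sketch (mirroring the matmul test added below; the value
+/// names are only for exposition):
+///
+/// ```mlir
+///   %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
+///       outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
+///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+/// ```
+///
+/// becomes
+///
+/// ```mlir
+///   %c_cast = tensor.cast %c : tensor<?x?xf32> to tensor<4x8xf32>
+///   %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
+///       outs(%c_cast : tensor<4x8xf32>) -> tensor<4x8xf32>
+/// ```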
+struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    if (!tensor::canFoldIntoProducerOp(castOp))
+      return failure();
+    auto linalgOp = castOp.source().getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(linalgOp);
+
+    Location loc = linalgOp.getLoc();
+    OpResult resultValue = castOp.source().cast<OpResult>();
+    unsigned resultNumber = resultValue.getResultNumber();
+    auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
+    // Replace the `outs` operand for this result with a `tensor.cast`. This
+    // cast now goes from a more dynamic shape to a less dynamic shape. If the
+    // producer of this cast, i.e. the producer of the out operand, is also an
+    // operation that folds with a `tensor.cast` consumer (like this pattern),
+    // the cast will continue to propagate as far up the stack as it can go.
+    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+    Value newOperand =
+        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
+    SmallVector<Value> newOperands = linalgOp.getInputOperands();
+    SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
+    outputOperands[resultNumber] = newOperand;
+    newOperands.append(outputOperands.begin(), outputOperands.end());
+
+    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
+                                  linalgOp->result_type_end());
+    resultTypes[resultNumber] = resultType;
+    Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
+
+    if (!resultValue.hasOneUse()) {
+      SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
+      // Create a tensor.cast operation back to the original type.
+      Value castBack = rewriter.create<tensor::CastOp>(
+          loc, resultValue.getType(), newOp->getResult(resultNumber));
+      results[resultNumber] = castBack;
+      // Replace all uses except the use in the cast op that is matched by the
+      // pattern. Note that this cast is from a more static shape to a more
+      // dynamic shape. These are expected to be pulled into their consumers.
+      rewriter.replaceOpWithIf(linalgOp, results,
+                               [&castOp](OpOperand &use) -> bool {
+                                 return use.getOwner() != castOp.getOperation();
+                               });
+    }
+    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
+    return success();
+  }
+};
+
 } // namespace
 
 #define LINALGOP_FOLDERS(XXX)                                                  \
@@ -1684,7 +1831,8 @@ LINALGOP_FOLDERS(GenericOp)
 
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
+  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
+              FoldTensorCastProducerOp>(getContext());
 }
 
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0c5d182c82d3e..e93d619bce74e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -98,6 +98,33 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
                                     castOp.source().getType());
 }
 
+/// Determines whether the tensor::CastOp casts to a more static version of the
+/// source tensor. This is useful to fold into a producing op and to implement
+/// canonicalization patterns with the `tensor.cast` op as the root, but with
+/// the producer coming from a different dialect. Returns true when all of the
+/// following conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the result type has more static information than the source.
+///
+/// Example:
+/// ```mlir
+///   %1 = producer ... : tensor<?x?xf32>
+///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
+/// ```
+///
+/// can be canonicalized to:
+///
+/// ```mlir
+///   %2 = producer ... : tensor<8x16xf32>
+/// ```
+/// Not all ops can be canonicalized this way, but for those that can, this
+/// method provides a check that the canonicalization is worthwhile.
+bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
+  if (!castOp)
+    return false;
+  return preservesStaticInformation(castOp.source().getType(),
+                                    castOp.getType());
+}
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index a2be638fd8078..b24a3e78e32b1 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -244,6 +244,17 @@ func @fold_init_tensor_with_slice
 
 // -----
 
+func @fold_init_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
+  %0 = linalg.init_tensor [%arg0, 12] : tensor<?x12xf32>
+  %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
+  return %1 : tensor<1x12xf32>
+}
+//      CHECK: func @fold_init_tensor_with_cast(%[[ARG0:.+]]: index)
+//      CHECK:   %[[T0:.+]] = linalg.init_tensor [1, 12] : tensor<1x12xf32>
+//      CHECK:   return %[[T0]] : tensor<1x12xf32>
+
+// -----
+
 #accesses = [
   affine_map<(i, j) -> (i, j)>
 ]
@@ -747,3 +758,23 @@ func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a: ten
   %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   return %2: tensor<8x384x384xf32>
 }
+
+// -----
+
+func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> (tensor<4x8xf32>, tensor<?x?xf32>) {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+  return %1, %0 : tensor<4x8xf32>, tensor<?x?xf32>
+}
+//       CHECK: func @fold_linalgop_with_cast_consumer(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//       CHECK:  %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32>
+//       CHECK:  %[[MATMUL:.+]] = linalg.matmul
+//  CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]] :
+//  CHECK-SAME:      outs(%[[OUT_CAST]] :
+//       CHECK:  %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
+//       CHECK:  return %[[MATMUL]], %[[RESULT_CAST]]

diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index ac94261a153f0..716e38a32e03c 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -47,7 +47,7 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
 //       CHECK: scf.for %[[I:[0-9a-z]*]]
 //       CHECK:   %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
 //       CHECK:   %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
-//  CHECK-NEXT:   scf.for %[[J:[0-9a-z]*]]
+//       CHECK:   scf.for %[[J:[0-9a-z]*]]
 //  CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
 //   CHECK-DAG:       %[[stB1:.*]] = tensor.extract_slice %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1]  : tensor<?x?xf32> to tensor<4x3xf32>
 //   CHECK-DAG:       %[[stF:.*]] = tensor.extract_slice %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1]  : tensor<?x?xf32> to tensor<2x3xf32>
@@ -56,9 +56,9 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
 //       CHECK:       %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]]
 //       CHECK:       %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
 //       CHECK:       %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [%[[sizeA0]], %[[sizeB1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
-//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[stC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
-//       CHECK:       %[[CAST:.*]] = tensor.cast %[[stD]] : tensor<?x?xf32> to tensor<2x4xf32>
-//  CHECK-NEXT:       %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
+//   CHECK-DAG:       %[[castC:.+]] = tensor.cast %[[stC]] : tensor<?x?xf32> to tensor<2x4xf32>
+//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[castC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
+//  CHECK-NEXT:       %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
 //  CHECK-NEXT:       tensor.insert_slice %[[stG]] into %[[RES]][%[[I]], %[[J]]]
 
 // -----


        

