[Mlir-commits] [mlir] 3f89e33 - [mlir] add tensor.cast(pad_tensor) -> pad_tensor canonicalizer

Alex Zinenko llvmlistbot at llvm.org
Fri Sep 24 03:03:54 PDT 2021


Author: Alex Zinenko
Date: 2021-09-24T12:03:47+02:00
New Revision: 3f89e339bb185726a2a3eb127ac59c813b52c6fe

URL: https://github.com/llvm/llvm-project/commit/3f89e339bb185726a2a3eb127ac59c813b52c6fe
DIFF: https://github.com/llvm/llvm-project/commit/3f89e339bb185726a2a3eb127ac59c813b52c6fe.diff

LOG: [mlir] add tensor.cast(pad_tensor) -> pad_tensor canonicalizer

This canonicalization pattern complements the pad_tensor(tensor.cast) one by
propagating static type information where possible. It contributes to the
feasibility of pad hoisting.
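
To make the fold concrete, the transformation in IR terms (a sketch that
mirrors the cast_of_pad_more_static test added below; value names and shapes
are illustrative, not part of the patch):

    %0 = linalg.pad_tensor %arg0 low[%l, %l] high[0, 0] {
    ^bb0(%i: index, %j: index):
      linalg.yield %cst : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
    %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<32x32xf32>

canonicalizes into a single pad carrying the more static result type:

    %0 = linalg.pad_tensor %arg0 low[%l, %l] high[0, 0] {
    ^bb0(%i: index, %j: index):
      linalg.yield %cst : f32
    } : tensor<?x?xf32> to tensor<32x32xf32>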

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110343

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
    mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index f374a0613f7a1..e8df979ac9cfb 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -53,6 +53,10 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
 namespace mlir {
 namespace tensor {
 
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool preservesStaticInformation(Type source, Type target);
+
 /// Determines whether tensor::CastOp casts to a more dynamic version of the
 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
 /// implement canonicalization patterns for ops in different dialects that may

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 75e4a1c91bcda..dfa4df8f58d69 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1482,11 +1482,41 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
     return success();
   }
 };
+
+// Fold a CastOp consuming the result of a PadTensorOp into the latter if it
+// adds static information.
+struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (!padTensorOp.result().hasOneUse())
+      return failure();
+    auto tensorCastOp =
+        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
+    if (!tensorCastOp)
+      return failure();
+    if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
+                                            tensorCastOp.dest().getType()))
+      return failure();
+
+    auto replacementOp = rewriter.create<PadTensorOp>(
+        padTensorOp.getLoc(), tensorCastOp.dest().getType(),
+        padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
+        padTensorOp.static_low(), padTensorOp.static_high());
+    replacementOp.region().takeBody(padTensorOp.region());
+
+    rewriter.replaceOp(padTensorOp, replacementOp.result());
+    rewriter.replaceOp(tensorCastOp, replacementOp.result());
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
+  results.add<FoldTargetTensorCast>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it is constant. In this context,
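
Note: FoldTargetTensorCast intentionally bails out unless the cast is the only
user of the pad, since rewriting the pad's result type in place would otherwise
change the type seen by the remaining users. A sketch of a case the pattern
leaves alone (illustrative names and shapes, not from the patch):

    func @pad_with_two_uses(%arg0: tensor<?x?xf32>, %l: index)
        -> (tensor<?x?xf32>, tensor<32x32xf32>) {
      %cst = constant 0.000000e+00 : f32
      // %0 has two uses, so FoldTargetTensorCast does not fire and the
      // tensor.cast below is kept.
      %0 = linalg.pad_tensor %arg0 low[%l, %l] high[0, 0] {
      ^bb0(%i: index, %j: index):
        linalg.yield %cst : f32
      } : tensor<?x?xf32> to tensor<?x?xf32>
      %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<32x32xf32>
      return %0, %1 : tensor<?x?xf32>, tensor<32x32xf32>
    }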

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 14ce6c104d44f..2a55223ca9a8b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -31,6 +31,34 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
 // CastOp
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
+  auto sourceType = source.dyn_cast<RankedTensorType>();
+  auto targetType = target.dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !targetType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != targetType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != targetType.getRank())
+    return false;
+
+  // A static dimension in the source must not become dynamic in the target.
+  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
+    if (!ShapedType::isDynamic(std::get<0>(t)) &&
+        ShapedType::isDynamic(std::get<1>(t)))
+      return false;
+  }
+
+  return true;
+}
+
 /// Determines whether tensor::CastOp casts to a more dynamic version of the
 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
 /// implement canonicalization patterns for ops in different dialects that may
@@ -57,30 +85,10 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
   if (!castOp)
     return false;
 
-  RankedTensorType sourceType =
-      castOp.source().getType().dyn_cast<RankedTensorType>();
-  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
-
-  // Requires RankedTensorType.
-  if (!sourceType || !resultType)
-    return false;
-
-  // Requires same elemental type.
-  if (sourceType.getElementType() != resultType.getElementType())
-    return false;
-
-  // Requires same rank.
-  if (sourceType.getRank() != resultType.getRank())
-    return false;
-
-  // If cast is towards more static sizes along any dimension, don't fold.
-  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
-    if (ShapedType::isDynamic(std::get<0>(t)) &&
-        !ShapedType::isDynamic(std::get<1>(t)))
-      return false;
-  }
-
-  return true;
+  // Can fold if the source of the cast has at least as much static
+  // information as its result.
+  return preservesStaticInformation(castOp.getType(),
+                                    castOp.source().getType());
 }
 
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
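
Note: in tensor.cast terms, the refactored canFoldIntoConsumerOp now returns
true exactly when the cast result is no more static than its source.
Illustrative types, not from the patch:

    // Foldable into a consumer: the source carries at least as much static
    // information as the result, so a consumer can use the source directly.
    %a = tensor.cast %t0 : tensor<32x32xf32> to tensor<?x?xf32>
    // Not foldable: the cast adds static information that a consumer of the
    // dynamically-typed source could not assume.
    %b = tensor.cast %t1 : tensor<?x?xf32> to tensor<32x32xf32>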

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index fce08a1e04dca..42d640a60246c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -696,6 +696,39 @@ func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
 
 // -----
 
+// CHECK-LABEL: @cast_of_pad_more_static
+func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: %[[PAD:.*]] = linalg.pad_tensor
+  // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding] high[0, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK-NOT: tensor.cast
+  %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
+  // CHECK: return %[[PAD]]
+  return %casted : tensor<32x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_of_pad_less_static
+func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: linalg.pad_tensor
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    linalg.yield %cst : f32
+  } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast
+  %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
+  // CHECK: return %[[CAST]]
+  return %casted : tensor<?x32x32xf32>
+}
+
+// -----
+
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
     %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index

diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
index 13f12d83133be..b00855581efb7 100644
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -140,8 +140,7 @@ func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
 //       CHECK:   } else {
 //       CHECK:     %[[SUBTENSOR:.*]] = tensor.extract_slice %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
 //       CHECK:     %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
-//       CHECK:     %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
-//       CHECK:     scf.yield %[[CAST]]
+//       CHECK:     scf.yield %[[PADTENSOR]]
 //       CHECK:   }
 //       CHECK:   return %[[RESULT]]
 func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {

diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index c1a761ce1c425..4aef50e6c96ba 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -289,7 +289,6 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
 //     CHECK:       else
 //     CHECK:         tensor.extract_slice
 //     CHECK:         linalg.pad_tensor
-//     CHECK:         tensor.cast
 //     CHECK:       tensor.extract_slice
 //     CHECK:       tensor.extract_slice
 //     CHECK:       linalg.generic

diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
index 20615f27d1442..5556699cdae66 100644
--- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -111,8 +111,7 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
 //       TILE1:     else
 //       TILE1:       %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
 //       TILE1:       %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
-//       TILE1:       %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
-//       TILE1:       scf.yield %[[CAST]] : tensor<14x3xf32>
+//       TILE1:       scf.yield %[[PAD]] : tensor<14x3xf32>
 //       TILE1:     %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
 //       TILE1:     scf.yield %[[R3]] : tensor<14x15xf32>
 //       TILE1:   return %[[RESULT]] : tensor<14x15xf32>


        

