[Mlir-commits] [mlir] 5e0ded2 - [mlir][Standard] Canonicalize chains of tensor_cast operations

Stephan Herhut llvmlistbot at llvm.org
Thu Sep 17 07:50:56 PDT 2020


Author: Stephan Herhut
Date: 2020-09-17T16:50:38+02:00
New Revision: 5e0ded268929b87ddf2c5e077c9185554342f602

URL: https://github.com/llvm/llvm-project/commit/5e0ded268929b87ddf2c5e077c9185554342f602
DIFF: https://github.com/llvm/llvm-project/commit/5e0ded268929b87ddf2c5e077c9185554342f602.diff

LOG: [mlir][Standard] Canonicalize chains of tensor_cast operations

Adds a pattern that replaces a chain of two tensor_cast operations by a single tensor_cast operation if doing so will not remove constraints on the shapes.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b0aa9b9e3c76..2113dfeb4c08 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2997,6 +2997,8 @@ def TensorCastOp : CastOp<"tensor_cast"> {
     /// The result of a tensor_cast is always a tensor.
     TensorType getType() { return getResult().getType().cast<TensorType>(); }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0c86c87384d3..c0dc87210a3f 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3163,6 +3163,87 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
   return impl::foldCastOp(*this);
 }
 
+/// Compute a TensorType that has the joined shape knowledge of the two
+/// given TensorTypes. The element types need to match.
+static TensorType joinShapes(TensorType one, TensorType two) {
+  assert(one.getElementType() == two.getElementType());
+
+  if (!one.hasRank())
+    return two;
+  if (!two.hasRank())
+    return one;
+
+  int64_t rank = one.getRank();
+  if (rank != two.getRank())
+    return {};
+
+  SmallVector<int64_t, 4> join;
+  join.reserve(rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    if (one.isDynamicDim(i)) {
+      join.push_back(two.getDimSize(i));
+      continue;
+    }
+    if (two.isDynamicDim(i)) {
+      join.push_back(one.getDimSize(i));
+      continue;
+    }
+    if (one.getDimSize(i) != two.getDimSize(i))
+      return {};
+    join.push_back(one.getDimSize(i));
+  }
+  return RankedTensorType::get(join, one.getElementType());
+}
+
+namespace {
+
+/// Replaces chains of two tensor_cast operations by a single tensor_cast
+/// operation if doing so does not remove runtime constraints.
+struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
+  using OpRewritePattern<TensorCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorCastOp tensorCast,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCastOperand =
+        tensorCast.getOperand().getDefiningOp<TensorCastOp>();
+
+    if (!tensorCastOperand)
+      return failure();
+
+    auto sourceType =
+        tensorCastOperand.getOperand().getType().cast<TensorType>();
+    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
+    auto resultType = tensorCast.getType().cast<TensorType>();
+
+    // We can remove the intermediate cast if joining all three produces the
+    // same result as just joining the source and result shapes.
+    auto firstJoin =
+        joinShapes(joinShapes(sourceType, intermediateType), resultType);
+
+    // The join might not exist if the cast sequence would fail at runtime.
+    if (!firstJoin)
+      return failure();
+
+    // The newJoin always exists if the above join exists, it might just contain
+    // less information. If so, we cannot drop the intermediate cast, as doing
+    // so would remove runtime checks.
+    auto newJoin = joinShapes(sourceType, resultType);
+    if (firstJoin != newJoin)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
+                                              tensorCastOperand.getOperand());
+    return success();
+  }
+};
+
+} // namespace
+
+void TensorCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ChainedTensorCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Helpers for Tensor[Load|Store]Op
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 320418545893..3603c473a1fd 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1062,3 +1062,51 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
   return %0 : tensor<3x?x?x7x?xindex>
 }
 
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
+func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
+  // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
+  %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
+  %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : tensor<4x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_regain
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
+  %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
+  // CHECK-NEXT: return %[[IN]]
+  return %1 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_keep
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
+func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
+  // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+  %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
+  // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+  %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
+  // CHECK-NEXT: return %[[C2]]
+  return %1 : tensor<?x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_invalid
+// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
+func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
+  // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+  %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
+  // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+  %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
+  // CHECK-NEXT: return %[[C2]]
+  return %1 : tensor<8x4xi32>
+}


        


More information about the Mlir-commits mailing list