[Mlir-commits] [mlir] 346b9d1 - [mlir][Linalg] Canonicalize TensorCastOp away when it feeds a LinalgOp.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Oct 5 07:50:18 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-05T14:48:21Z
New Revision: 346b9d17720a0ccd920cd02b81811a4d2ddc67d6

URL: https://github.com/llvm/llvm-project/commit/346b9d17720a0ccd920cd02b81811a4d2ddc67d6
DIFF: https://github.com/llvm/llvm-project/commit/346b9d17720a0ccd920cd02b81811a4d2ddc67d6.diff

LOG: [mlir][Linalg] Canonicalize TensorCastOp away when it feeds a LinalgOp.

This canonicalization is the tensor counterpart of the existing MemRefCastOp -> LinalgOp folding.

This is needed to properly canonicalize the IR produced by Linalg tiling on tensors.
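
For illustration (mirroring the test added in this revision), the pattern rewrites:

  %ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
  %tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
  %tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
                init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>

into a matmul on the more statically shaped operands:

  %0 = linalg.matmul ins(%a, %b: tensor<3x4xf32>, tensor<4x?xf32>)
                init(%c: tensor<3x?xf32>) -> tensor<3x?xf32>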

Differential Revision: https://reviews.llvm.org/D88729

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index f51f7b913027..44c6b77ee404 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -404,6 +404,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return getInitTensors()[i];
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of input, output buffer and init tensor operands.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumShapedOperands",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range over inputs, output buffers and init tensors.
@@ -414,7 +427,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         auto range = this->getOperation()->getOperands();
-        return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+        return {range.begin(), range.begin() + getNumShapedOperands()};
       }]
     >,
     InterfaceMethod<
@@ -621,6 +634,27 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }]
     >
   ];
+
+  let extraClassDeclaration = [{
+    /// Returns all the operands past the inputs, output_buffers and
+    /// init_tensors operands. Asserts that these operands are value types to
+    /// allow transformations like tiling to just use the values when cloning
+    /// `linalgOp`.
+    SmallVector<Value, 4> getAssumedNonShapedOperands() {
+      unsigned numShapedOperands = getNumInputsAndOutputs();
+      unsigned nExtraOperands =
+        getOperation()->getNumOperands() - numShapedOperands;
+      SmallVector<Value, 4> res;
+      res.reserve(nExtraOperands);
+      for (unsigned i = 0; i < nExtraOperands; ++i) {
+        res.push_back(getOperation()->getOperand(numShapedOperands + i));
+        assert((res.back().getType().isSignlessIntOrIndexOrFloat() ||
+                res.back().getType().isa<VectorType>()) &&
+               "expected scalar or vector type");
+      }
+      return res;
+    }
+  }];
 }
 
 #endif // LINALG_IR_STRUCTURED_OPS_INTERFACE

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index fbe735e31cff..409f54384aca 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -350,6 +350,31 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
 /// ```
 bool canFoldIntoConsumerOp(MemRefCastOp castOp);
 
+/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
+/// Determines whether TensorCastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor_cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor_cast operations. Such foldable tensor_cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all of the following conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the source tensor type has more static information than the result type.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool canFoldIntoConsumerOp(TensorCastOp castOp);
+
 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
 /// comparison predicates.
 bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 69c979ae9e38..ab7b599dffba 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3334,7 +3334,7 @@ def TensorCastOp : CastOp<"tensor_cast"> {
     ```
   }];
 
-  let arguments = (ins AnyTensor);
+  let arguments = (ins AnyTensor:$source);
   let results = (outs AnyTensor);
 
   let extraClassDeclaration = [{

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9cdb3391f4a..26aa75955e3c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"
 
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
@@ -1498,12 +1499,65 @@ struct EraseDeadLinalgOp : public RewritePattern {
     return failure();
   }
 };
+
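+/// Canonicalization pattern that folds TensorCastOps feeding a LinalgOp:
+/// when a shaped operand is produced by a tensor_cast that only erases
+/// static shape information, the LinalgOp is cloned directly on the cast's
+/// source (see `canFoldIntoConsumerOp`).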
+struct FoldTensorCastOp : public RewritePattern {
+  FoldTensorCastOp(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+
+    // If no operand comes from a TensorCastOp that can be folded, fail.
+    bool hasTensorCastOperand =
+        llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
+          if (v.isa<BlockArgument>())
+            return false;
+          auto castOp = v.getDefiningOp<TensorCastOp>();
+          return castOp && canFoldIntoConsumerOp(castOp);
+        });
+    if (!hasTensorCastOperand)
+      return failure();
+
+    SmallVector<Type, 4> newResultTypes;
+    newResultTypes.reserve(op->getNumResults());
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(op->getNumOperands());
+    // Inputs may fold.
+    for (Value v : linalgOp.getInputs()) {
+      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+      newOperands.push_back(
+          canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
+    }
+    // Output buffers are memrefs; they don't fold.
+    newOperands.append(linalgOp.getOutputBuffers().begin(),
+                       linalgOp.getOutputBuffers().end());
+    // Init tensors may fold, in which case the resultType must also change.
+    for (Value v : linalgOp.getInitTensors()) {
+      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+      bool fold = canFoldIntoConsumerOp(tensorCastOp);
+      newOperands.push_back(fold ? tensorCastOp.source() : v);
+      newResultTypes.push_back(newOperands.back().getType());
+    }
+    auto extraOperands = linalgOp.getAssumedNonShapedOperands();
+    newOperands.append(extraOperands.begin(), extraOperands.end());
+    // Clone op.
+    Operation *newOp =
+        linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
+    rewriter.replaceOp(op, newOp->getResults());
+
+    return success();
+  }
+};
 } // namespace
 
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
                                         MLIRContext *context) {                \
     results.insert<EraseDeadLinalgOp>();                                       \
+    results.insert<FoldTensorCastOp>();                                        \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a4d739135aea..f2823c564cce 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3157,6 +3157,60 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
   return true;
 }
 
+/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
+/// Determines whether TensorCastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor_cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor_cast operations. Such foldable tensor_cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all of the following conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the source tensor type has more static information than the result type.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
+  if (!castOp)
+    return false;
+
+  RankedTensorType sourceType =
+      castOp.source().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !resultType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != resultType.getRank())
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
+        return false;
+  }
+
+  return true;
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref_cast past its consuming subview when

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5e0890fa4bb5..cf86a97f4fcd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -259,3 +259,23 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
 //       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
 //   CHECK-NOT:   linalg.tensor_reshape
 //       CHECK:   return %[[CST]]
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast(
+func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
+  -> tensor<3x?xf32>
+{
+  %ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
+  %tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
+  %tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
+
+  //      CHECK:  linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
+  // CHECK-SAME:    init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
+  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+               init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  %1 = tensor_cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
+
+  return %1: tensor<3x?xf32>
+}
