[Mlir-commits] [mlir] c118fdc - [mlir] Remove incorrect folding for SubTensorInsertOp

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 3 13:58:23 PST 2021


Author: MaheshRavishankar
Date: 2021-03-03T13:58:05-08:00
New Revision: c118fdcd5970d66abf9cc3f0d09544b269b01cc8

URL: https://github.com/llvm/llvm-project/commit/c118fdcd5970d66abf9cc3f0d09544b269b01cc8
DIFF: https://github.com/llvm/llvm-project/commit/c118fdcd5970d66abf9cc3f0d09544b269b01cc8.diff

LOG: [mlir] Remove incorrect folding for SubTensorInsertOp

SubTensorInsertOp requires that its dest type and result type match.
Simply folding away the tensor.cast operation violates this and
produces verification errors during canonicalization. Also fix other
canonicalization methods that weren't inserting casts properly.

Differential Revision: https://reviews.llvm.org/D97800

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 46e5780e151f..4a2999aeaa37 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2427,7 +2427,19 @@ struct FoldTensorCastOp : public RewritePattern {
     // Clone op.
     Operation *newOp =
         linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
-    rewriter.replaceOp(op, newOp->getResults());
+    SmallVector<Value, 4> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
+      Value oldResult = std::get<0>(result);
+      Value newResult = std::get<1>(result);
+      if (newResult.getType() != oldResult.getType()) {
+        replacements.push_back(rewriter.create<tensor::CastOp>(
+            op->getLoc(), oldResult.getType(), newResult));
+      } else {
+        replacements.push_back(newResult);
+      }
+    }
+    rewriter.replaceOp(op, replacements);
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 536d71d89d4f..7a71f09adf63 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3305,7 +3305,11 @@ static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
 
 static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
                              SubTensorOp newOp) {
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+  Value replacement = newOp.getResult();
+  if (replacement.getType() != op.getType())
+    replacement =
+        rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), replacement);
+  rewriter.replaceOp(op, replacement);
 }
 
 /// Pattern to rewrite a subview op with constant arguments.
@@ -3789,11 +3793,10 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
 }
 
 OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
-  if (getSourceType() == getType() &&
+  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
+      getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
-  if (succeeded(tensor::foldTensorCast(*this)))
-    return this->source();
   return OpFoldResult();
 }
 
@@ -3847,9 +3850,9 @@ struct SubTensorInsertOpCastFolder final
     : public OpRewritePattern<SubTensorInsertOp> {
   using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SubTensorInsertOp subTensorOp,
+  LogicalResult matchAndRewrite(SubTensorInsertOp subTensorInsertOp,
                                 PatternRewriter &rewriter) const override {
-    if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+    if (llvm::any_of(subTensorInsertOp.getOperands(), [](Value operand) {
           return matchPattern(operand, m_ConstantIndex());
         }))
       return failure();
@@ -3860,21 +3863,25 @@ struct SubTensorInsertOpCastFolder final
         return llvm::None;
       return castOp.source();
     };
-    Optional<Value> sourceCastSource = getSourceOfCastOp(subTensorOp.source());
-    Optional<Value> destCastSource = getSourceOfCastOp(subTensorOp.dest());
-    if (!sourceCastSource && !destCastSource &&
-        subTensorOp.dest().getType() == subTensorOp.getResult().getType())
+    Optional<Value> sourceCastSource =
+        getSourceOfCastOp(subTensorInsertOp.source());
+    Optional<Value> destCastSource =
+        getSourceOfCastOp(subTensorInsertOp.dest());
+    if (!sourceCastSource && !destCastSource)
       return failure();
 
-    auto newOp = rewriter.create<SubTensorInsertOp>(
-        subTensorOp.getLoc(),
-        (sourceCastSource ? *sourceCastSource : subTensorOp.source()),
-        (destCastSource ? *destCastSource : subTensorOp.dest()),
-        subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
-        subTensorOp.getMixedStrides());
+    Value replacement = rewriter.create<SubTensorInsertOp>(
+        subTensorInsertOp.getLoc(),
+        (sourceCastSource ? *sourceCastSource : subTensorInsertOp.source()),
+        (destCastSource ? *destCastSource : subTensorInsertOp.dest()),
+        subTensorInsertOp.getMixedOffsets(), subTensorInsertOp.getMixedSizes(),
+        subTensorInsertOp.getMixedStrides());
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(subTensorOp,
-                                                subTensorOp.getType(), newOp);
+    if (replacement.getType() != subTensorInsertOp.getType()) {
+      replacement = rewriter.create<tensor::CastOp>(
+          subTensorInsertOp.getLoc(), subTensorInsertOp.getType(), replacement);
+    }
+    rewriter.replaceOp(subTensorInsertOp, replacement);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 2fb5eb3086e6..f2f3a44169e8 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -767,3 +767,25 @@ func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
 //      CHECK:   %[[D0:.+]] = dim %[[ARG0]], %[[C4]]
 //      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
 //      CHECK:   return %[[C5]], %[[D1]]
+
+// -----
+
+func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
+    %arg3 : index) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c21 = constant 21 : index
+  %c42 = constant 42 : index
+  %0 = linalg.init_tensor [%c21, %c42] : tensor<?x?xf32>
+  %1 = linalg.fill(%0, %arg1) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+  %2 = dim %arg0, %c0 : tensor<?x?xf32>
+  %3 = dim %arg0, %c1 : tensor<?x?xf32>
+  %4 = subtensor_insert %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+  return %4 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @propogate_casts
+//       CHECK:   %[[INIT:.+]] = linalg.init_tensor [21, 42]
+//       CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
+//       CHECK:   %[[INSERTED:.+]] = subtensor_insert %{{.+}} into %[[FILL]]
+//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[INSERTED]]
+//       CHECK:   return %[[RESULT]]


        


More information about the Mlir-commits mailing list