[llvm-branch-commits] [mlir] de03121 - [mlir] Add canonicalization from `tensor_cast` to `dim` op.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 17 14:54:21 PST 2020


Author: MaheshRavishankar
Date: 2020-12-17T14:45:51-08:00
New Revision: de031216bf1755e61418a1515f2b0db0a9cfeddc

URL: https://github.com/llvm/llvm-project/commit/de031216bf1755e61418a1515f2b0db0a9cfeddc
DIFF: https://github.com/llvm/llvm-project/commit/de031216bf1755e61418a1515f2b0db0a9cfeddc.diff

LOG: [mlir] Add canonicalization from `tensor_cast` to `dim` op.

Fold a `tensor_cast` -> `dim` to take the `dim` of the original tensor.

Differential Revision: https://reviews.llvm.org/D93492

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index b19264ec4972..7ed9cffa8806 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1472,11 +1472,29 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold the dim of a tensor cast into the dim of the source of the tensor
+/// cast.
+template <typename CastOpTy>
+struct DimOfCastOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
+    if (!castOp)
+      return failure();
+    Value newSource = castOp.getOperand();
+    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
+    return success();
+  }
+};
+
 } // end anonymous namespace.
 
 void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert<DimOfMemRefReshape>(context);
+  results.insert<DimOfMemRefReshape, DimOfCastOp<TensorCastOp>>(context);
 }
 
 // ---------------------------------------------------------------------------

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index d3c781d1290f..af67453a1f3c 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -115,3 +115,19 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
   %1 = dim %0, %c3 : memref<*xf32>
   return %1 : index
 }
+
+// Test case: Folding dim(tensor_cast(%0), %idx) -> dim(%0, %idx)
+// CHECK-LABEL: func @fold_dim_of_tensor_cast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//       CHECK:   %[[T0:.+]] = dim %[[ARG0]], %[[C1]]
+//  CHECK-NEXT:   return %[[C4]], %[[T0]]
+func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2: index, index
+}


        


More information about the llvm-branch-commits mailing list