[Mlir-commits] [mlir] d12fa33 - [mlir] Add a TensorLoadToMemref canonicalization

Nicolas Vasilache llvmlistbot at llvm.org
Fri Feb 19 01:47:02 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-19T09:38:33Z
New Revision: d12fa33d736d60d419f86b4ec5f3e77e602d4b1e

URL: https://github.com/llvm/llvm-project/commit/d12fa33d736d60d419f86b4ec5f3e77e602d4b1e
DIFF: https://github.com/llvm/llvm-project/commit/d12fa33d736d60d419f86b4ec5f3e77e602d4b1e.diff

LOG: [mlir] Add a TensorLoadToMemref canonicalization

A folder of `tensor_load + tensor_to_memref` exists but it only applies when
source and destination memref types are the same.

This revision adds a canonicalization of `tensor_load + tensor_to_memref` to `memref_cast`
when type mismatches prevent folding from kicking in.

Differential Revision: https://reviews.llvm.org/D97038

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 084d3fdfb2bf..046033cc7f9d 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3838,11 +3838,34 @@ struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
     return success();
   }
 };
+
+/// Canonicalize tensor_load + tensor_to_memref to memref_cast when type
+/// mismatches prevent `TensorToMemrefOp::fold` from kicking in.
+///
+/// Rewrites `tensor_to_memref(tensor_load(%m))` into `memref_cast %m` when the
+/// type of `%m` differs from the tensor_to_memref result type but the two types
+/// are accepted by `MemRefCastOp::areCastCompatible`. The match fails when the
+/// operand is not produced by a tensor_load, when the types are identical, or
+/// when the types are not cast-compatible.
+struct TensorLoadToMemref : public OpRewritePattern<TensorToMemrefOp> {
+  using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
+                                PatternRewriter &rewriter) const final {
+    // Null unless the tensor operand is produced by a tensor_load.
+    auto tensorLoad = tensorToMemRef.tensor().getDefiningOp<TensorLoadOp>();
+    // Bail unless we have a tensor_load + tensor_to_memref with different
+    // types. `TensorToMemrefOp::fold` handles the same type case.
+    if (!tensorLoad ||
+        tensorLoad.memref().getType() == tensorToMemRef.getType())
+      return failure();
+    // If types are not cast-compatible (e.g. memrefs in different address
+    // spaces, as exercised by the tests), bail.
+    if (!MemRefCastOp::areCastCompatible(tensorLoad.memref().getType(),
+                                         tensorToMemRef.getType()))
+      return failure();
+    // Replace the tensor_to_memref with a cast of the loaded-from memref to the
+    // expected result type. The tensor_load itself is not erased here.
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(
+        tensorToMemRef, tensorToMemRef.getType(), tensorLoad.memref());
+    return success();
+  }
+};
 } // namespace
 
 void TensorToMemrefOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<TensorCastToMemref>(context);
+  // TensorLoadToMemref handles tensor_load + tensor_to_memref pairs whose
+  // memref types differ but are cast-compatible; the same-type case is left to
+  // TensorToMemrefOp::fold.
+  results.insert<TensorCastToMemref, TensorLoadToMemref>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 5c437ae3dda4..ff5ca24f7587 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
+
+// -----
 
 // Test case: Basic folding of tensor_load(tensor_to_memref(t)) -> t
 // CHECK-LABEL:   func @tensor_load_of_tensor_to_memref(
@@ -10,6 +12,8 @@ func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   return %1 : tensor<?xf32>
 }
 
+// -----
+
 // Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
 // CHECK-LABEL:   func @tensor_to_memref_of_tensor_load(
 // CHECK-SAME:                                          %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
@@ -20,7 +24,11 @@ func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
   return %1 : memref<?xf32>
 }
 
+// -----
+
 // Test case: If the memrefs are not the same type, don't fold them.
+// Test case: If the memrefs are not cast-compatible (e.g. different address space),
+// don't canonicalize them either.
 // CHECK-LABEL:   func @no_fold_tensor_to_memref_of_tensor_load(
 // CHECK-SAME:                                                  %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>) -> memref<?xf32, 7> {
 // CHECK:           %[[TENSOR:.*]] = tensor_load %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
@@ -32,6 +40,28 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref
   return %1 : memref<?xf32, 7>
 }
 
+// -----
+
+// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
+// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// Test case: If the memrefs are cast-compatible, canonicalize.
+// The argument layout has a static offset (3) while the tensor_to_memref
+// result layout has a dynamic offset (?). The types differ, so the same-type
+// fold does not apply; the pair is expected to become a single memref_cast.
+// CHECK-LABEL: func @canonicalize_tensor_to_memref_of_tensor_load(
+//  CHECK-SAME:   %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>) -> memref<?xf32, #[[$OFF_UNK]]> {
+//   CHECK-NOT:   tensor_load
+//   CHECK-NOT:   tensor_to_memref
+//       CHECK:   %[[R:.*]] = memref_cast %[[M]] : memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
+//       CHECK:   return %[[R]]
+func @canonicalize_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
+  -> memref<?xf32, offset: ?, strides: [1]>
+{
+  %0 = tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
+  %1 = tensor_to_memref %0 : memref<?xf32, offset: ?, strides: [1]>
+  return %1 : memref<?xf32, offset: ?, strides: [1]>
+}
+
+// -----
+
 // Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
 // CHECK-LABEL: func @dim_of_tensor_load(
 //  CHECK-SAME:     %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
@@ -45,6 +75,8 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
   return %1 : index
 }
 
+// -----
+
 // Test case: Folding of load(tensor_to_memref(%v, %idxs))
 //            -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_tensor_to_memref(
@@ -59,6 +91,8 @@ func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf
   return %1 : f32
 }
 
+// -----
+
 // Test case: Folding of dim(tensor.generate %idx) -> %idx
 // CHECK-LABEL: func @dim_of_tensor.generate(
 //  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
@@ -74,6 +108,8 @@ func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
   return %1 : index
 }
 
+// -----
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = constant true
@@ -96,6 +132,8 @@ func @cmpi_equal_operands(%arg0: i64)
       : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 }
 
+// -----
+
 // Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
 // CHECK-LABEL: func @dim_of_memref_reshape(
 //  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
@@ -116,6 +154,8 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
   return %1 : index
 }
 
+// -----
+
 // Test case: Folding dim(tensor.cast %0, %idx) -> dim %0, %idx
 // CHECK-LABEL: func @fold_dim_of_tensor.cast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
@@ -132,6 +172,8 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
   return %1, %2: index, index
 }
 
+// -----
+
 // CHECK-LABEL: func @tensor_cast_to_memref
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
 //       CHECK:   %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
@@ -144,6 +186,8 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
   return %1 : memref<?x?x16x32xi8>
 }
 
+// -----
+
 // CHECK-LABEL: func @subview_of_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
@@ -158,6 +202,8 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
 }
 
+// -----
+
 // CHECK-LABEL: func @trivial_subtensor
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 //   CHECK-NOT:   subtensor
@@ -167,6 +213,8 @@ func @trivial_subtensor(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
   return %0 : tensor<4x6x16x32xi8>
 }
 
+// -----
+
 // CHECK-LABEL: func @trivial_subtensor_insert
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 //   CHECK-NOT:   subtensor
@@ -176,6 +224,8 @@ func @trivial_subtensor_insert(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x
   return %0 : tensor<4x6x16x32xi8>
 }
 
+// -----
+
 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = subtensor %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>


        


More information about the Mlir-commits mailing list