[Mlir-commits] [mlir] 66f878c - [mlir][NFC] Remove Standard dialect dependency on MemRef dialect

Matthias Springer llvmlistbot at llvm.org
Mon Jun 21 01:55:51 PDT 2021


Author: Matthias Springer
Date: 2021-06-21T17:55:23+09:00
New Revision: 66f878cee91047162e7913cf9533e5313988175a

URL: https://github.com/llvm/llvm-project/commit/66f878cee91047162e7913cf9533e5313988175a
DIFF: https://github.com/llvm/llvm-project/commit/66f878cee91047162e7913cf9533e5313988175a.diff

LOG: [mlir][NFC] Remove Standard dialect dependency on MemRef dialect

* Remove dependency: Standard --> MemRef
* Add dependencies: GPUToNVVMTransforms --> MemRef, Linalg --> MemRef, MemRef --> Tensor
* Note: The `subtensor_insert_propagate_dest_cast` test case in MemRef/canonicalize.mlir will be moved to Tensor/canonicalize.mlir in a subsequent commit, which moves over the remaining Tensor ops from the Standard dialect to the Tensor dialect.

Differential Revision: https://reviews.llvm.org/D104506

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 07a378b39d066..9d1e3baad8ee7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -33,7 +33,8 @@ def Linalg_Dialect : Dialect {
   }];
   let cppNamespace = "::mlir::linalg";
   let dependentDialects = [
-    "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
+    "AffineDialect", "memref::MemRefDialect", "StandardOpsDialect",
+    "tensor::TensorDialect"
   ];
   let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index e1167b86e54e7..3ada4fbfdfb18 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_LINALGOPS_H_
 
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index af09671e32518..c5cfdd15c00a8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8841af104a360..8a085d5573c64 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -10,6 +10,7 @@
 #define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 1f694fc4d22d8..20ff5df914250 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_
 #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_
 
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
index ada1e526ca42e..4e29acf99e31a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
@@ -19,6 +19,7 @@ def MemRef_Dialect : Dialect {
     manipulation ops, which are not strongly associated with any particular
     other dialect or domain abstraction.
   }];
+  let dependentDialects = ["tensor::TensorDialect"];
   let hasConstantMaterializer = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index ee1cc67dee457..a9aa9810d27f5 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -14,7 +14,6 @@
 #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
 #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
 
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 7216e3d2ed5c2..e71d497b7ab66 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -27,9 +27,6 @@ def StandardOps_Dialect : Dialect {
   let name = "std";
   let cppNamespace = "::mlir";
   let hasConstantMaterializer = 1;
-  // TODO: This dependency is needed to handle memref ops in the
-  // canonicalize pass and should be resolved.
-  let dependentDialects = ["memref::MemRefDialect"];
 }
 
 // Base class for Standard dialect ops.

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index ea336dc68f0e2..3e102a0c9966c 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 227890bc1f661..1c0b0313c0dfc 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -17,6 +17,7 @@
 #include "../PassDetail.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cfd21dcadd0d7..8a86efd7ef1fe 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -17,6 +17,7 @@
 #include "../PassDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 01ff9b1974d45..851ec5051a6be 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2b345a42c9a1e..8323dbe23342d 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"

diff  --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 126041a6e9f7f..9ee8075892426 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Passes.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SCF/Transforms.h"

diff  --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
index 95bc5457933ee..ea2f97d7d0569 100644
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 684a97580fda8..c3defdc4c5469 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -43,6 +43,7 @@
 
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 140cd43ede147..ec5fd88e69cde 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -217,3 +217,177 @@ func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
   return
 }
 
+// -----
+
+// Test case: Folding of memref.load(memref.buffer_cast(%v), %idxs)
+//            -> tensor.extract(%v, %idxs)
+// CHECK-LABEL: func @load_from_buffer_cast(
+//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+//  CHECK-SAME:     %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+//       CHECK:   %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
+//   CHECK-NOT:   memref.load
+//       CHECK:   return %[[RES]] : f32
+func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
+  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+
+// -----
+
+
+// Test case: Basic folding of memref.dim(memref.tensor_load(m)) -> memref.dim(m).
+// CHECK-LABEL: func @dim_of_tensor_load(
+//  CHECK-SAME:     %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
+//       CHECK:   %[[C0:.*]] = constant 0
+//       CHECK:   %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
+//       CHECK:   return %[[D]] : index
+func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
+  %c0 = constant 0 : index
+  %0 = memref.tensor_load %arg0 : memref<?xf32>
+  %1 = memref.dim %0, %c0 : tensor<?xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(tensor.generate %idx) -> %idx
+// CHECK-LABEL: func @dim_of_tensor.generate(
+//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+//   CHECK-NOT:   memref.dim
+//       CHECK:   return %[[IDX1]] : index
+func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
+  %c3 = constant 3 : index
+  %0 = tensor.generate %arg0, %arg1 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %c3 : index
+  } : tensor<2x?x4x?x5xindex>
+  %1 = memref.dim %0, %c3 : tensor<2x?x4x?x5xindex>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
+// CHECK-LABEL: func @dim_of_alloca(
+//  CHECK-SAME:     %[[SIZE:[0-9a-z]+]]: index
+//  CHECK-NEXT:   return %[[SIZE]] : index
+func @dim_of_alloca(%size: index) -> index {
+  %0 = memref.alloca(%size) : memref<?xindex>
+  %c0 = constant 0 : index
+  %1 = memref.dim %0, %c0 : memref<?xindex>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
+// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>
+//  CHECK-NEXT:   %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
+//  CHECK-NEXT:   return %[[RANK]] : index
+func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
+  %0 = rank %arg0 : memref<*xf32>
+  %1 = memref.alloca(%0) : memref<?xindex>
+  %c0 = constant 0 : index
+  %2 = memref.dim %1, %c0 : memref<?xindex>
+  return %2 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+//  CHECK-NEXT:   memref.store
+//   CHECK-NOT:   memref.dim
+//       CHECK:   return %[[DIM]] : index
+func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+    -> index {
+  %c3 = constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  memref.store %c3, %arg1[%c3] : memref<?xindex>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) with an i32 shape operand -> index_cast(memref.load %shp[%idx])
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+//  CHECK-NEXT:   %[[CAST:.*]] = index_cast %[[DIM]]
+//   CHECK-NOT:   memref.dim
+//       CHECK:   return %[[CAST]] : index
+func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+    -> index {
+  %c3 = constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
+// CHECK-LABEL: func @fold_dim_of_tensor.cast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//       CHECK:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//  CHECK-NEXT:   return %[[C4]], %[[T0]]
+func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
+  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2: index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast_to_memref
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+//       CHECK:   %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
+//       CHECK:   %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+//       CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
+  return %1 : memref<?x?x16x32xi8>
+}
+
+// -----
+
+// TODO: Move this test to Tensor/canonicalize.mlir.
+func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c8 = constant 8 : index
+  %0 = memref.dim %arg0, %c1 : tensor<2x?xi32>
+  %1 = tensor.extract %arg1[] : tensor<i32>
+  %2 = tensor.generate %arg2, %c8 {
+  ^bb0(%arg4: index, %arg5: index):
+    tensor.yield %1 : i32
+  } : tensor<?x?xi32>
+  %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
+  return %3 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
+//       CHECK:   %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
+//  CHECK-SAME:     tensor<2x?xi32> into tensor<?x8xi32>
+//       CHECK:   %[[CAST:.+]] = tensor.cast %[[UPDATED]]
+//       CHECK:   return %[[CAST]]

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 24db1d295ffc3..aed84cfd16c31 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -1,53 +1,5 @@
 // RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
 
-// Test case: Basic folding of memref.dim(memref.tensor_load(m)) -> memref.dim(m).
-// CHECK-LABEL: func @dim_of_tensor_load(
-//  CHECK-SAME:     %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
-//       CHECK:   %[[C0:.*]] = constant 0
-//       CHECK:   %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
-//       CHECK:   return %[[D]] : index
-func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
-  %c0 = constant 0 : index
-  %0 = memref.tensor_load %arg0 : memref<?xf32>
-  %1 = memref.dim %0, %c0 : tensor<?xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.load(memref.buffer_cast(%v, %idxs))
-//            -> tensor.extract(%v, %idx)
-// CHECK-LABEL: func @load_from_buffer_cast(
-//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
-//  CHECK-SAME:     %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
-//       CHECK:   %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
-//   CHECK-NOT:   memref.load
-//       CHECK:   return %[[RES]] : f32
-func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
-  %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
-  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
-  return %1 : f32
-}
-
-// -----
-
-// Test case: Folding of memref.dim(tensor.generate %idx) -> %idx
-// CHECK-LABEL: func @dim_of_tensor.generate(
-//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[IDX1]] : index
-func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
-  %c3 = constant 3 : index
-  %0 = tensor.generate %arg0, %arg1 {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
-    tensor.yield %c3 : index
-  } : tensor<2x?x4x?x5xindex>
-  %1 = memref.dim %0, %c3 : tensor<2x?x4x?x5xindex>
-  return %1 : index
-}
-
-// -----
-
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = constant true
@@ -72,108 +24,6 @@ func @cmpi_equal_operands(%arg0: i64)
 
 // -----
 
-// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
-// CHECK-LABEL: func @dim_of_alloca(
-//  CHECK-SAME:     %[[SIZE:[0-9a-z]+]]: index
-//  CHECK-NEXT:   return %[[SIZE]] : index
-func @dim_of_alloca(%size: index) -> index {
-  %0 = memref.alloca(%size) : memref<?xindex>
-  %c0 = constant 0 : index
-  %1 = memref.dim %0, %c0 : memref<?xindex>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
-// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>
-//  CHECK-NEXT:   %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
-//  CHECK-NEXT:   return %[[RANK]] : index
-func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
-  %0 = rank %arg0 : memref<*xf32>
-  %1 = memref.alloca(%0) : memref<?xindex>
-  %c0 = constant 0 : index
-  %2 = memref.dim %1, %c0 : memref<?xindex>
-  return %2 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
-//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   memref.store
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[DIM]] : index
-func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that he load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
-//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
-//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-//  CHECK-NEXT:   %[[CAST:.*]] = index_cast %[[DIM]]
-//   CHECK-NOT:   memref.dim
-//       CHECK:   return %[[CAST]] : index
-func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
-// CHECK-LABEL: func @fold_dim_of_tensor.cast
-//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[C4:.+]] = constant 4 : index
-//       CHECK:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C1]]
-//  CHECK-NEXT:   return %[[C4]], %[[T0]]
-func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
-  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
-  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
-  return %1, %2: index, index
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_cast_to_memref
-//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
-//       CHECK:   %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
-//       CHECK:   %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-//       CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
-func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
-  memref<?x?x16x32xi8> {
-  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
-  return %1 : memref<?x?x16x32xi8>
-}
-
-// -----
-
 func @subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
     %arg2 : index) -> tensor<?x?x?xf32>
 {
@@ -345,29 +195,6 @@ func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x
 
 // -----
 
-func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
-    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %c8 = constant 8 : index
-  %0 = memref.dim %arg0, %c1 : tensor<2x?xi32>
-  %1 = tensor.extract %arg1[] : tensor<i32>
-  %2 = tensor.generate %arg2, %c8 {
-  ^bb0(%arg4: index, %arg5: index):
-    tensor.yield %1 : i32
-  } : tensor<?x?xi32>
-  %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
-  return %3 : tensor<?x?xi32>
-}
-// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
-//       CHECK:   %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
-//  CHECK-SAME:     tensor<2x?xi32> into tensor<?x8xi32>
-//       CHECK:   %[[CAST:.+]] = tensor.cast %[[UPDATED]]
-//       CHECK:   return %[[CAST]]
-
-// -----
-
 func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index


        


More information about the Mlir-commits mailing list