[Mlir-commits] [mlir] 66f878c - [mlir][NFC] Remove Standard dialect dependency on MemRef dialect
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 21 01:55:51 PDT 2021
Author: Matthias Springer
Date: 2021-06-21T17:55:23+09:00
New Revision: 66f878cee91047162e7913cf9533e5313988175a
URL: https://github.com/llvm/llvm-project/commit/66f878cee91047162e7913cf9533e5313988175a
DIFF: https://github.com/llvm/llvm-project/commit/66f878cee91047162e7913cf9533e5313988175a.diff
LOG: [mlir][NFC] Remove Standard dialect dependency on MemRef dialect
* Remove dependency: Standard --> MemRef
* Add dependencies: GPUToNVVMTransforms --> MemRef, Linalg --> MemRef, MemRef --> Tensor
* Note: The `subtensor_insert_propagate_dest_cast` test case in MemRef/canonicalize.mlir will be moved to Tensor/canonicalize.mlir in a subsequent commit, which moves over the remaining Tensor ops from the Standard dialect to the Tensor dialect.
Differential Revision: https://reviews.llvm.org/D104506
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 07a378b39d066..9d1e3baad8ee7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -33,7 +33,8 @@ def Linalg_Dialect : Dialect {
}];
let cppNamespace = "::mlir::linalg";
let dependentDialects = [
- "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
+ "AffineDialect", "memref::MemRefDialect", "StandardOpsDialect",
+ "tensor::TensorDialect"
];
let hasCanonicalizer = 1;
let hasOperationAttrVerify = 1;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index e1167b86e54e7..3ada4fbfdfb18 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_LINALGOPS_H_
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index af09671e32518..c5cfdd15c00a8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_LINALGTYPES_H_
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8841af104a360..8a085d5573c64 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -10,6 +10,7 @@
#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Identifier.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 1f694fc4d22d8..20ff5df914250 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_
#define MLIR_DIALECT_MEMREF_IR_MEMREF_H_
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
index ada1e526ca42e..4e29acf99e31a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
@@ -19,6 +19,7 @@ def MemRef_Dialect : Dialect {
manipulation ops, which are not strongly associated with any particular
other dialect or domain abstraction.
}];
+ let dependentDialects = ["tensor::TensorDialect"];
let hasConstantMaterializer = 1;
}
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index ee1cc67dee457..a9aa9810d27f5 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -14,7 +14,6 @@
#ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 7216e3d2ed5c2..e71d497b7ab66 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -27,9 +27,6 @@ def StandardOps_Dialect : Dialect {
let name = "std";
let cppNamespace = "::mlir";
let hasConstantMaterializer = 1;
- // TODO: This dependency is needed to handle memref ops in the
- // canonicalize pass and should be resolved.
- let dependentDialects = ["memref::MemRefDialect"];
}
// Base class for Standard dialect ops.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index ea336dc68f0e2..3e102a0c9966c 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 227890bc1f661..1c0b0313c0dfc 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -17,6 +17,7 @@
#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cfd21dcadd0d7..8a86efd7ef1fe 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -17,6 +17,7 @@
#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 01ff9b1974d45..851ec5051a6be 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -23,6 +23,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2b345a42c9a1e..8323dbe23342d 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 126041a6e9f7f..9ee8075892426 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -8,6 +8,7 @@
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
index 95bc5457933ee..ea2f97d7d0569 100644
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -8,6 +8,7 @@
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 684a97580fda8..c3defdc4c5469 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -43,6 +43,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 140cd43ede147..ec5fd88e69cde 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -217,3 +217,177 @@ func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
return
}
+// -----
+
+// Test case: Folding of memref.load(memref.buffer_cast(%v, %idxs))
+// -> tensor.extract(%v, %idx)
+// CHECK-LABEL: func @load_from_buffer_cast(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
+// CHECK-NOT: memref.load
+// CHECK: return %[[RES]] : f32
+func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
+ %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
+ %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
+ return %1 : f32
+}
+
+// -----
+
+
+// Test case: Basic folding of memref.dim(memref.tensor_load(m)) -> memref.dim(m).
+// CHECK-LABEL: func @dim_of_tensor_load(
+// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
+// CHECK: %[[C0:.*]] = constant 0
+// CHECK: %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
+// CHECK: return %[[D]] : index
+func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
+ %c0 = constant 0 : index
+ %0 = memref.tensor_load %arg0 : memref<?xf32>
+ %1 = memref.dim %0, %c0 : tensor<?xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(tensor.generate %idx) -> %idx
+// CHECK-LABEL: func @dim_of_tensor.generate(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-NOT: memref.dim
+// CHECK: return %[[IDX1]] : index
+func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
+ %c3 = constant 3 : index
+ %0 = tensor.generate %arg0, %arg1 {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %c3 : index
+ } : tensor<2x?x4x?x5xindex>
+ %1 = memref.dim %0, %c3 : tensor<2x?x4x?x5xindex>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
+// CHECK-LABEL: func @dim_of_alloca(
+// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
+// CHECK-NEXT: return %[[SIZE]] : index
+func @dim_of_alloca(%size: index) -> index {
+ %0 = memref.alloca(%size) : memref<?xindex>
+ %c0 = constant 0 : index
+ %1 = memref.dim %0, %c0 : memref<?xindex>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
+// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
+// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
+// CHECK-NEXT: return %[[RANK]] : index
+func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
+ %0 = rank %arg0 : memref<*xf32>
+ %1 = memref.alloca(%0) : memref<?xindex>
+ %c0 = constant 0 : index
+ %2 = memref.dim %1, %c0 : memref<?xindex>
+ return %2 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: memref.store
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+ -> index {
+ %c3 = constant 3 : index
+ %0 = memref.reshape %arg0(%arg1)
+ : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+ memref.store %c3, %arg1[%c3] : memref<?xindex>
+ %1 = memref.dim %0, %c3 : memref<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
+// CHECK-NEXT: %[[IDX:.*]] = constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[CAST]] : index
+func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+ -> index {
+ %c3 = constant 3 : index
+ %0 = memref.reshape %arg0(%arg1)
+ : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+ %1 = memref.dim %0, %c3 : memref<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
+// CHECK-LABEL: func @fold_dim_of_tensor.cast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-NEXT: return %[[C4]], %[[T0]]
+func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+ %1 = memref.dim %0, %c0 : tensor<?x?xf32>
+ %2 = memref.dim %0, %c1 : tensor<?x?xf32>
+ return %1, %2: index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast_to_memref
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+ memref<?x?x16x32xi8> {
+ %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+ %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
+ return %1 : memref<?x?x16x32xi8>
+}
+
+// -----
+
+// TODO: Move this test to Tensor/canonicalize.mlir.
+func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+ %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c8 = constant 8 : index
+ %0 = memref.dim %arg0, %c1 : tensor<2x?xi32>
+ %1 = tensor.extract %arg1[] : tensor<i32>
+ %2 = tensor.generate %arg2, %c8 {
+ ^bb0(%arg4: index, %arg5: index):
+ tensor.yield %1 : i32
+ } : tensor<?x?xi32>
+ %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
+ return %3 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
+// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
+// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
+// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
+// CHECK: return %[[CAST]]
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 24db1d295ffc3..aed84cfd16c31 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -1,53 +1,5 @@
// RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
-// Test case: Basic folding of memref.dim(memref.tensor_load(m)) -> memref.dim(m).
-// CHECK-LABEL: func @dim_of_tensor_load(
-// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
-// CHECK: %[[C0:.*]] = constant 0
-// CHECK: %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
-// CHECK: return %[[D]] : index
-func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
- %c0 = constant 0 : index
- %0 = memref.tensor_load %arg0 : memref<?xf32>
- %1 = memref.dim %0, %c0 : tensor<?xf32>
- return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.load(memref.buffer_cast(%v, %idxs))
-// -> tensor.extract(%v, %idx)
-// CHECK-LABEL: func @load_from_buffer_cast(
-// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
-// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
-// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
-// CHECK-NOT: memref.load
-// CHECK: return %[[RES]] : f32
-func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
- %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
- %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
- return %1 : f32
-}
-
-// -----
-
-// Test case: Folding of memref.dim(tensor.generate %idx) -> %idx
-// CHECK-LABEL: func @dim_of_tensor.generate(
-// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
-// CHECK-NOT: memref.dim
-// CHECK: return %[[IDX1]] : index
-func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
- %c3 = constant 3 : index
- %0 = tensor.generate %arg0, %arg1 {
- ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
- tensor.yield %c3 : index
- } : tensor<2x?x4x?x5xindex>
- %1 = memref.dim %0, %c3 : tensor<2x?x4x?x5xindex>
- return %1 : index
-}
-
-// -----
-
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = constant true
@@ -72,108 +24,6 @@ func @cmpi_equal_operands(%arg0: i64)
// -----
-// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
-// CHECK-LABEL: func @dim_of_alloca(
-// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
-// CHECK-NEXT: return %[[SIZE]] : index
-func @dim_of_alloca(%size: index) -> index {
- %0 = memref.alloca(%size) : memref<?xindex>
- %c0 = constant 0 : index
- %1 = memref.dim %0, %c0 : memref<?xindex>
- return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
-// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
-// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
-// CHECK-NEXT: return %[[RANK]] : index
-func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
- %0 = rank %arg0 : memref<*xf32>
- %1 = memref.alloca(%0) : memref<?xindex>
- %c0 = constant 0 : index
- %2 = memref.dim %1, %c0 : memref<?xindex>
- return %2 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
-// CHECK-NEXT: %[[IDX:.*]] = constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: memref.store
-// CHECK-NOT: memref.dim
-// CHECK: return %[[DIM]] : index
-func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
- -> index {
- %c3 = constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
- // Update the shape to test that he load ends up in the right place.
- memref.store %c3, %arg1[%c3] : memref<?xindex>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
-// CHECK-NEXT: %[[IDX:.*]] = constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
-// CHECK-NOT: memref.dim
-// CHECK: return %[[CAST]] : index
-func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
- -> index {
- %c3 = constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
-// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
-// CHECK-LABEL: func @fold_dim_of_tensor.cast
-// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK-NEXT: return %[[C4]], %[[T0]]
-func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
- %1 = memref.dim %0, %c0 : tensor<?x?xf32>
- %2 = memref.dim %0, %c1 : tensor<?x?xf32>
- return %1, %2: index, index
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_cast_to_memref
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
-// CHECK: %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
-// CHECK: %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
-func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
- memref<?x?x16x32xi8> {
- %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
- %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
- return %1 : memref<?x?x16x32xi8>
-}
-
-// -----
-
func @subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> tensor<?x?x?xf32>
{
@@ -345,29 +195,6 @@ func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x
// -----
-func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
- %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %c8 = constant 8 : index
- %0 = memref.dim %arg0, %c1 : tensor<2x?xi32>
- %1 = tensor.extract %arg1[] : tensor<i32>
- %2 = tensor.generate %arg2, %c8 {
- ^bb0(%arg4: index, %arg5: index):
- tensor.yield %1 : i32
- } : tensor<?x?xi32>
- %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
- return %3 : tensor<?x?xi32>
-}
-// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
-// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
-// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
-// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
-// CHECK: return %[[CAST]]
-
-// -----
-
func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
More information about the Mlir-commits
mailing list