[Mlir-commits] [mlir] bc408af - [mlir][vector] Fold splat constant transpose
Lei Zhang
llvmlistbot at llvm.org
Thu Apr 14 05:56:19 PDT 2022
Author: Lei Zhang
Date: 2022-04-14T08:51:25-04:00
New Revision: bc408afbfebef71460a7f8b4313021956633ef21
URL: https://github.com/llvm/llvm-project/commit/bc408afbfebef71460a7f8b4313021956633ef21
DIFF: https://github.com/llvm/llvm-project/commit/bc408afbfebef71460a7f8b4313021956633ef21.diff
LOG: [mlir][vector] Fold splat constant transpose
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D123595
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f326217af299c..940f9262f1472 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4278,10 +4278,14 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
}
-// Eliminates transpose operations, which produce values identical to their
-// input values. This happens when the dimensions of the input vector remain in
-// their original order after the transpose operation.
OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
+ // Eliminate splat constant transpose ops.
+ if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
+ if (attr.isSplat())
+ return attr.reshape(getResultType());
+
+ // Eliminate identity transpose ops. This happens when the dimensions of the
+ // input vector remain in their original order after the transpose operation.
SmallVector<int64_t, 4> transp;
getTransp(transp);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 336d22c5808cf..608cb026c43ea 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1463,6 +1463,17 @@ func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
// -----
+// CHECK-LABEL: func @transpose_splat_constant
+// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
+// CHECK: return %[[CST]]
+func @transpose_splat_constant() -> vector<8x4xf32> {
+ %cst = arith.constant dense<5.0> : vector<4x8xf32>
+ %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32> // splat operand: TransposeOp::fold returns the splat reshaped to the result type
+ return %0 : vector<8x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @insert_element_fold
// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
// CHECK: return %[[V]]
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 155d23b529ddf..aacd8ed5cd5b0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -273,14 +273,13 @@ func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index)
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
// CHECK-LABEL: func @transfer_read_permutations
-func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?xf32>)
+func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?xf32>, %m: i1)
-> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<8xf32>) {
// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
- %m = arith.constant 1 : i1
%mask0 = vector.splat %m : vector<7x14xi1>
%0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
@@ -331,10 +330,9 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>
func @transfer_write_permutations(
%arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
- %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> tensor<?x?x?x?xf32> {
+ %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>, %m: i1) -> tensor<?x?x?x?xf32> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
- %m = arith.constant 1 : i1
%mask0 = vector.splat %m : vector<7x14x8x16xi1>
%0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
More information about the Mlir-commits mailing list