[Mlir-commits] [mlir] bc408af - [mlir][vector] Fold splat constant transpose

Lei Zhang llvmlistbot at llvm.org
Thu Apr 14 05:56:19 PDT 2022


Author: Lei Zhang
Date: 2022-04-14T08:51:25-04:00
New Revision: bc408afbfebef71460a7f8b4313021956633ef21

URL: https://github.com/llvm/llvm-project/commit/bc408afbfebef71460a7f8b4313021956633ef21
DIFF: https://github.com/llvm/llvm-project/commit/bc408afbfebef71460a7f8b4313021956633ef21.diff

LOG: [mlir][vector] Fold splat constant transpose

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123595

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f326217af299c..940f9262f1472 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4278,10 +4278,14 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
   result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
 }
 
-// Eliminates transpose operations, which produce values identical to their
-// input values. This happens when the dimensions of the input vector remain in
-// their original order after the transpose operation.
 OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
+  // Eliminate splat constant transpose ops.
+  if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
+    if (attr.isSplat())
+      return attr.reshape(getResultType());
+
+  // Eliminate identity transpose ops. This happens when the dimensions of the
+  // input vector remain in their original order after the transpose operation.
   SmallVector<int64_t, 4> transp;
   getTransp(transp);
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 336d22c5808cf..608cb026c43ea 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1463,6 +1463,17 @@ func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
 
 // -----
 
+// CHECK-LABEL: func @transpose_splat_constant
+//       CHECK:   %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
+//       CHECK:   return %[[CST]]
+func @transpose_splat_constant() -> vector<8x4xf32> {
+  %cst = arith.constant dense<5.0> : vector<4x8xf32>
+  %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
+  return %0 : vector<8x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_element_fold
 //       CHECK:   %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
 //       CHECK:   return %[[V]]

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 155d23b529ddf..aacd8ed5cd5b0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -273,14 +273,13 @@ func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index)
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
 
 // CHECK-LABEL: func @transfer_read_permutations
-func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?xf32>)
+func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?xf32>, %m: i1)
     -> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
        vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<8xf32>) {
 // CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
-  %m = arith.constant 1 : i1
 
   %mask0 = vector.splat %m : vector<7x14xi1>
   %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
@@ -331,10 +330,9 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
 // CHECK-SAME:      %[[ARG1:.*]]: tensor<?x?x?x?xf32>
 func @transfer_write_permutations(
     %arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
-    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> tensor<?x?x?x?xf32> {
+    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>, %m: i1) -> tensor<?x?x?x?xf32> {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  %m = arith.constant 1 : i1
 
   %mask0 = vector.splat %m : vector<7x14x8x16xi1>
   %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>


        


More information about the Mlir-commits mailing list