[Mlir-commits] [mlir] 0d6e419 - [mlir][vector] Order parallel indices before transposing the input in multireductions

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 28 18:47:28 PDT 2021


Author: harsh-nod
Date: 2021-06-28T18:47:16-07:00
New Revision: 0d6e4199e32a3a5942f920bf13c0a0ddf10d2579

URL: https://github.com/llvm/llvm-project/commit/0d6e4199e32a3a5942f920bf13c0a0ddf10d2579
DIFF: https://github.com/llvm/llvm-project/commit/0d6e4199e32a3a5942f920bf13c0a0ddf10d2579.diff

LOG: [mlir][vector] Order parallel indices before transposing the input in multireductions

The current code does not preserve the order of the parallel
dimensions when doing multi-reductions and thus we can end
up in scenarios where the result shape does not match the
desired shape after reduction.

This patch fixes that by ensuring that the parallel indices
are in order and then appending the reduction dimensions after
them, so that the reduction dimensions are innermost.

Differential Revision: https://reviews.llvm.org/D104884

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index e04d48d6ca840..3a10fb3de6416 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -37,6 +37,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -3915,42 +3916,33 @@ struct InnerDimReductionConversion
     auto loc = multiReductionOp.getLoc();
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
 
-    auto reductionDims = llvm::to_vector<4>(
-        llvm::map_range(multiReductionOp.reduction_dims().cast<ArrayAttr>(),
-                        [](Attribute attr) -> int64_t {
-                          return attr.cast<IntegerAttr>().getInt();
-                        }));
-    llvm::sort(reductionDims);
-
-    int64_t reductionSize = multiReductionOp.reduction_dims().size();
-
-    // Fails if already inner most reduction.
-    bool innerMostReduction = true;
-    for (int i = 0; i < reductionSize; ++i) {
-      if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) {
-        innerMostReduction = false;
-      }
+    // Separate reduction and parallel dims
+    auto reductionDimsRange =
+        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
+    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
+        reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
+    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
+                                                  reductionDims.end());
+    int64_t reductionSize = reductionDims.size();
+    SmallVector<int64_t, 4> parallelDims;
+    for (int64_t i = 0; i < srcRank; i++) {
+      if (!reductionDimsSet.contains(i))
+        parallelDims.push_back(i);
     }
-    if (innerMostReduction)
-      return failure();
 
-    // Permutes the indices so reduction dims are inner most dims.
-    SmallVector<int64_t> indices;
-    for (int i = 0; i < srcRank; ++i) {
-      indices.push_back(i);
-    }
-    int ir = reductionSize - 1;
-    int id = srcRank - 1;
-    while (ir >= 0) {
-      std::swap(indices[reductionDims[ir--]], indices[id--]);
-    }
+    // Add transpose only if inner-most dimensions are not reductions
+    if (parallelDims ==
+        llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))
+      return failure();
 
-    // Sets inner most dims as reduction.
+    SmallVector<int64_t, 4> indices;
+    indices.append(parallelDims.begin(), parallelDims.end());
+    indices.append(reductionDims.begin(), reductionDims.end());
+    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
     SmallVector<bool> reductionMask(srcRank, false);
     for (int i = 0; i < reductionSize; ++i) {
       reductionMask[srcRank - i - 1] = true;
     }
-    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
     rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
         multiReductionOp, transposeOp.result(), reductionMask,
         multiReductionOp.kind());

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 6cfc4e035719d..4121262722e34 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -61,6 +61,49 @@ func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x
 // CHECK-LABEL: func @vector_multi_reduction_transposed
 //  CHECK-SAME:    %[[INPUT:.+]]: vector<2x3x4x5xf32>
 //       CHECK:     %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
-//       CHEKC:     vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
+//       CHECK:     vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
 //       CHECK:     %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
 //       CHECK:       return %[[RESULT]]     
+
+func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
+    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
+    return %0 : vector<2x4xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction_ordering
+//  CHECK-SAME:   %[[INPUT:.+]]: vector<3x2x4xf32>
+//       CHECK:       %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<8xf32>
+//       CHECK:       %[[C0:.+]] = constant 0 : i32
+//       CHECK:       %[[C1:.+]] = constant 1 : i32
+//       CHECK:       %[[C2:.+]] = constant 2 : i32
+//       CHECK:       %[[C3:.+]] = constant 3 : i32
+//       CHECK:       %[[C4:.+]] = constant 4 : i32
+//       CHECK:       %[[C5:.+]] = constant 5 : i32
+//       CHECK:       %[[C6:.+]] = constant 6 : i32
+//       CHECK:       %[[C7:.+]] = constant 7 : i32
+//       CHECK:       %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
+//       CHECK:       %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
+//       CHECK:       %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<3xf32> into f32
+//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<8xf32>
+//       CHECK:       %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
+//       CHECK:       %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<3xf32> into f32
+//       CHECK:       %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<8xf32>
+//       CHECK:       %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
+//       CHECK:       %[[RV2:.+]] = vector.reduction "mul", %[[V2]] : vector<3xf32> into f32
+//       CHECK:       %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : i32] : vector<8xf32>
+//       CHECK:       %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
+//       CHECK:       %[[RV3:.+]] = vector.reduction "mul", %[[V3]] : vector<3xf32> into f32
+//       CHECK:       %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : i32] : vector<8xf32>
+//       CHECK:       %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
+//       CHECK:       %[[RV4:.+]] = vector.reduction "mul", %[[V4]] : vector<3xf32> into f32
+//       CHECK:       %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : i32] : vector<8xf32>
+//       CHECK:       %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
+//       CHECK:       %[[RV5:.+]] = vector.reduction "mul", %[[V5]] : vector<3xf32> into f32
+//       CHECK:       %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : i32] : vector<8xf32>
+//       CHECK:       %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
+//       CHECK:       %[[RV6:.+]] = vector.reduction "mul", %[[V6]] : vector<3xf32> into f32
+//       CHECK:       %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : i32] : vector<8xf32>
+//       CHECK:       %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
+//       CHECK:       %[[RV7:.+]] = vector.reduction "mul", %[[V7]] : vector<3xf32> into f32
+//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
+//       CHECK:       return %[[RESHAPED_VEC]]


        


More information about the Mlir-commits mailing list