[Mlir-commits] [mlir] 5c58172 - [mlir][Vector] Add support for scalable vectors in multi_reduction

Andrzej Warzynski llvmlistbot at llvm.org
Tue Aug 8 10:02:46 PDT 2023


Author: Andrzej Warzynski
Date: 2023-08-08T17:01:59Z
New Revision: 5c581720b9bc28e933de1aed99c79557803f22ac

URL: https://github.com/llvm/llvm-project/commit/5c581720b9bc28e933de1aed99c79557803f22ac
DIFF: https://github.com/llvm/llvm-project/commit/5c581720b9bc28e933de1aed99c79557803f22ac.diff

LOG: [mlir][Vector] Add support for scalable vectors in multi_reduction

Support for scalable vectors in vector.multi_reduction is added by
simply updating MultiDimReductionOp::verify.

Also, the conversion pattern that reduces an n-D vector.multi_reduction to a
2-D vector.multi_reduction is updated to propagate scalable dimensions.

Differential Revision: https://reviews.llvm.org/D157092

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 094d555a0e8fc9..5b416e4a69996f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -313,18 +313,22 @@ MultiDimReductionOp::getShapeForUnroll() {
 
 LogicalResult MultiDimReductionOp::verify() {
   SmallVector<int64_t> targetShape;
+  SmallVector<bool> scalableDims;
   Type inferredReturnType;
+  auto sourceScalableDims = getSourceVectorType().getScalableDims();
   for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
     if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
           return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
-        }))
+        })) {
       targetShape.push_back(it.value());
+      scalableDims.push_back(sourceScalableDims[it.index()]);
+    }
   // TODO: update to also allow 0-d vectors when available.
   if (targetShape.empty())
     inferredReturnType = getSourceVectorType().getElementType();
   else
-    inferredReturnType =
-        VectorType::get(targetShape, getSourceVectorType().getElementType());
+    inferredReturnType = VectorType::get(
+        targetShape, getSourceVectorType().getElementType(), scalableDims);
   if (getType() != inferredReturnType)
     return emitOpError() << "destination type " << getType()
                          << " is incompatible with source type "

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 5d2abf7c03680f..bed2c2496719dd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -154,12 +154,19 @@ class ReduceMultiDimReductionRank
 
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
+    auto srcScalableDims =
+        multiReductionOp.getSourceVectorType().getScalableDims();
     auto loc = multiReductionOp.getLoc();
 
     // If rank less than 2, nothing to do.
     if (srcRank < 2)
       return failure();
 
+    // Allow at most one scalable dimension. Otherwise we could end up with e.g.
+    // `vscale * vscale`, which is currently not modelled.
+    if (llvm::count(srcScalableDims, true) > 1)
+      return failure();
+
     // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
     SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
     if (srcRank == 2 && reductionMask.front() != reductionMask.back())
@@ -167,16 +174,20 @@ class ReduceMultiDimReductionRank
 
     // 1. Separate reduction and parallel dims.
     SmallVector<int64_t, 4> parallelDims, parallelShapes;
+    SmallVector<bool, 4> parallelScalableDims;
     SmallVector<int64_t, 4> reductionDims, reductionShapes;
+    bool isReductionDimScalable = false;
     for (const auto &it : llvm::enumerate(reductionMask)) {
       int64_t i = it.index();
       bool isReduction = it.value();
       if (isReduction) {
         reductionDims.push_back(i);
         reductionShapes.push_back(srcShape[i]);
+        isReductionDimScalable |= srcScalableDims[i];
       } else {
         parallelDims.push_back(i);
         parallelShapes.push_back(srcShape[i]);
+        parallelScalableDims.push_back(srcScalableDims[i]);
       }
     }
 
@@ -212,18 +223,23 @@ class ReduceMultiDimReductionRank
     // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
     // a single parallel (resp. reduction) dim.
     SmallVector<bool, 2> mask;
+    SmallVector<bool, 2> scalableDims;
     SmallVector<int64_t, 2> vectorShape;
+    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
     if (flattenedParallelDim) {
       mask.push_back(false);
       vectorShape.push_back(flattenedParallelDim);
+      scalableDims.push_back(isParallelDimScalable);
     }
     if (flattenedReductionDim) {
       mask.push_back(true);
       vectorShape.push_back(flattenedReductionDim);
+      scalableDims.push_back(isReductionDimScalable);
     }
     if (!useInnerDimsForReduction && vectorShape.size() == 2) {
       std::swap(mask.front(), mask.back());
       std::swap(vectorShape.front(), vectorShape.back());
+      std::swap(scalableDims.front(), scalableDims.back());
     }
 
     Value newVectorMask;
@@ -237,7 +253,8 @@ class ReduceMultiDimReductionRank
     }
 
     auto castedType = VectorType::get(
-        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
+        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
+        scalableDims);
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
 
@@ -245,7 +262,8 @@ class ReduceMultiDimReductionRank
     if (flattenedParallelDim) {
       auto accType = VectorType::get(
           {flattenedParallelDim},
-          multiReductionOp.getSourceVectorType().getElementType());
+          multiReductionOp.getSourceVectorType().getElementType(),
+          /*scalableDims=*/{isParallelDimScalable});
       acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
     }
     // 6. Creates the flattened form of vector.multi_reduction with inner/outer
@@ -264,8 +282,8 @@ class ReduceMultiDimReductionRank
 
     // 8. Creates shape cast for the output n-D -> 2-D.
     VectorType outputCastedType = VectorType::get(
-        parallelShapes,
-        multiReductionOp.getSourceVectorType().getElementType());
+        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
+        parallelScalableDims);
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
     return success();

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index bbb461c56e5401..cf770670c57528 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -249,6 +249,38 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
 //       CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
 
+func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
+  %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
+  return %0 : vector<8x[4]xf32>
+}
+// CHECK-LABEL:   func.func private @scalable_dims(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
+// CHECK-SAME:                                     %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_34:.*]] = arith.constant 31 : index
+
+// CHECK:           %[[VAL_35:.*]] = vector.extract %[[VAL_0]][0, 0] : vector<8x[4]x2xf32>
+// CHECK:           %[[VAL_36:.*]] = vector.extract %[[VAL_1]][0, 0] : vector<8x[4]xf32>
+// CHECK:           %[[VAL_37:.*]] = vector.reduction <add>, %[[VAL_35]], %[[VAL_36]] : vector<2xf32> into f32
+// CHECK:           %[[VAL_38:.*]] = vector.insertelement %[[VAL_37]], %[[VAL_2]]{{\[}}%[[VAL_3]] : index] : vector<[32]xf32>
+
+// CHECK:           %[[VAL_39:.*]] = vector.extract %[[VAL_0]][0, 1] : vector<8x[4]x2xf32>
+// CHECK:           %[[VAL_40:.*]] = vector.extract %[[VAL_1]][0, 1] : vector<8x[4]xf32>
+// CHECK:           %[[VAL_41:.*]] = vector.reduction <add>, %[[VAL_39]], %[[VAL_40]] : vector<2xf32> into f32
+// CHECK:           %[[VAL_42:.*]] = vector.insertelement %[[VAL_41]], %[[VAL_38]]{{\[}}%[[VAL_4]] : index] : vector<[32]xf32>
+
+// (...)
+
+// CHECK:           %[[VAL_159:.*]] = vector.extract %[[VAL_0]][7, 3] : vector<8x[4]x2xf32>
+// CHECK:           %[[VAL_160:.*]] = vector.extract %[[VAL_1]][7, 3] : vector<8x[4]xf32>
+// CHECK:           %[[VAL_161:.*]] = vector.reduction <add>, %[[VAL_159]], %[[VAL_160]] : vector<2xf32> into f32
+// CHECK:           %[[VAL_162:.*]] = vector.insertelement %[[VAL_161]], %{{.*}}{{\[}}%[[VAL_34]] : index] : vector<[32]xf32>
+
+// CHECK:           %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
+// CHECK:           return %[[VAL_163]] : vector<8x[4]xf32>
+
 transform.sequence failures(propagate) {
 ^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {


        


More information about the Mlir-commits mailing list