[Mlir-commits] [mlir] 627733b - [mlir][vector] Extend vector distribution to all elementwise and contract

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 30 16:36:08 PDT 2021


Author: thomasraoux
Date: 2021-06-30T16:22:31-07:00
New Revision: 627733b5f045e870577e5abf70944d3ffac7a6fb

URL: https://github.com/llvm/llvm-project/commit/627733b5f045e870577e5abf70944d3ffac7a6fb
DIFF: https://github.com/llvm/llvm-project/commit/627733b5f045e870577e5abf70944d3ffac7a6fb.diff

LOG: [mlir][vector] Extend vector distribution to all elementwise and contract

Use the elementwise-mappable traits to generalize the extract_map
canonicalization pattern, and add a new pattern for the vector.contract case.

Differential Revision: https://reviews.llvm.org/D104343

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-distribution.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 4b5f5ce8e035f..1fc5f1b45ceef 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -105,6 +105,10 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
 // a sequence of vector.reduction ops.
 void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns that propagate vector.extract_map /
+/// vector.insert_map along the SSA use-def chain: an extract_map is pushed
+/// through elementwise-mappable ops and vector.contract onto their operands,
+/// and a vector.transfer_read feeding an extract_map is turned into a smaller
+/// read (the insert_map / transfer_write counterpart is presumably handled the
+/// same way — see TransferWriteInsertPattern).
+void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
+
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
 class CombiningKindAttr

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 173183ee9b7c2..8419f49c5385f 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -251,26 +251,6 @@ Optional<DistributeOps>
 distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
                            ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
                            const AffineMap &map);
-/// Canonicalize an extra element using the result of a pointwise operation.
-/// Transforms:
-/// %v = addf %a, %b : vector32xf32>
-/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
-/// to:
-/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
-/// %db = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
-/// %dv = addf %da, %db : vector<1xf32>
-struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
-  using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
-  PointwiseExtractPattern(
-      MLIRContext *context, FilterConstraintType constraint =
-                                [](ExtractMapOp op) { return success(); })
-      : OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
-  LogicalResult matchAndRewrite(ExtractMapOp extract,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  FilterConstraintType filter;
-};
 
 /// Implements transfer op write to read forwarding and dead transfer write
 /// optimizations.

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 3a10fb3de6416..33d57623f6be5 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2793,25 +2793,6 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
   return failure();
 }
 
-LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
-    ExtractMapOp extract, PatternRewriter &rewriter) const {
-  Operation *definedOp = extract.vector().getDefiningOp();
-  if (!definedOp || definedOp->getNumResults() != 1)
-    return failure();
-  // TODO: Create an interfaceOp for elementwise operations.
-  if (!isa<AddFOp>(definedOp))
-    return failure();
-  Location loc = extract.getLoc();
-  SmallVector<Value, 4> extractOperands;
-  for (OpOperand &operand : definedOp->getOpOperands())
-    extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
-        loc, extract.getResultType(), operand.get(), extract.ids()));
-  Operation *newOp = cloneOpWithOperandsAndTypes(
-      rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
-  rewriter.replaceOp(extract, newOp->getResult(0));
-  return success();
-}
-
 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
     ArrayRef<int64_t> multiplicity, const AffineMap &map) {
@@ -2843,6 +2824,91 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
   return ops;
 }
 
+/// Canonicalize an extract_map whose source is produced by an
+/// elementwise-mappable operation by distributing that operation instead.
+/// Transforms:
+/// %v = addf %a, %b : vector<32xf32>
+/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
+/// to:
+/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
+/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
+/// %dv = addf %da, %db : vector<1xf32>
+///
+/// Non-vector (scalar) operands are forwarded unchanged. Each vector operand
+/// is extracted with the distributed shape of the result but keeps its own
+/// element type, so ops whose operand and result element types differ are
+/// rewritten correctly. The original op is not erased by this pattern.
+struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
+  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    Operation *definedOp = extract.vector().getDefiningOp();
+    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
+        definedOp->getNumResults() != 1)
+      return failure();
+    Location loc = extract.getLoc();
+    // All vector operands of an elementwise op share the result's shape, so
+    // they are all distributed to the extract_map's result shape.
+    ArrayRef<int64_t> distributedShape = extract.getResultType().getShape();
+    SmallVector<Value, 4> extractOperands;
+    for (OpOperand &operand : definedOp->getOpOperands()) {
+      auto vecType = operand.get().getType().dyn_cast<VectorType>();
+      if (!vecType) {
+        extractOperands.push_back(operand.get());
+        continue;
+      }
+      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+          loc, VectorType::get(distributedShape, vecType.getElementType()),
+          operand.get(), extract.ids()));
+    }
+    Operation *newOp = cloneOpWithOperandsAndTypes(
+        rewriter, loc, definedOp, extractOperands, extract.getResultType());
+    rewriter.replaceOp(extract, newOp->getResult(0));
+    return success();
+  }
+};
+
+/// Canonicalize an extract_map whose source is produced by a vector.contract.
+/// This propagates the extract_map to the lhs/rhs/acc operands and re-creates
+/// the contraction on the distributed operands.
+///
+/// Only parallel dimensions (those appearing in the acc indexing map) are
+/// distributed; reduction dimensions keep their full size. vector.extract_map
+/// associates its ids, in order, with the source dimensions whose size differs
+/// from the result, so each id is attached to the iteration dimension it
+/// distributes and every operand receives exactly the ids of its own
+/// distributed dimensions, in operand-dimension order.
+struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
+  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    Operation *definedOp = extract.vector().getDefiningOp();
+    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
+    if (!contract)
+      return failure();
+    auto indexingMaps = contract.getIndexingMaps();
+    // Optional mask operands follow lhs/rhs/acc. They are not distributed
+    // here and re-cloning with only the distributed operands would drop them,
+    // so masked contractions are left alone.
+    if (contract->getNumOperands() != indexingMaps.size())
+      return failure();
+    Location loc = contract.getLoc();
+    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
+    AffineMap accMap = indexingMaps[accIndex];
+    auto sourceType = extract.vector().getType().cast<VectorType>();
+    VectorType resultType = extract.getResultType();
+    // Iteration dimension -> (distributed size, id distributing it). Only acc
+    // dimensions whose size actually changes are distributed and own an id.
+    DenseMap<int64_t, std::pair<int64_t, Value>> distributed;
+    auto ids = extract.ids();
+    unsigned nextId = 0;
+    for (unsigned i : llvm::seq(unsigned(0), accMap.getNumResults())) {
+      int64_t newSize = resultType.getDimSize(i);
+      if (sourceType.getDimSize(i) == newSize)
+        continue;
+      if (nextId >= ids.size())
+        return failure();
+      distributed[accMap.getDimPosition(i)] = {newSize, ids[nextId++]};
+    }
+    SmallVector<Value, 4> extractOperands;
+    for (auto it : llvm::enumerate(indexingMaps)) {
+      // For each operand calculate the new vector type after distribution and
+      // the ids matching its distributed dimensions.
+      Value operand = contract->getOperand(it.index());
+      auto vecType = operand.getType().cast<VectorType>();
+      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
+                                        vecType.getShape().end());
+      SmallVector<Value, 4> operandIds;
+      AffineMap operandMap = it.value();
+      for (unsigned i : llvm::seq(unsigned(0), operandMap.getNumResults())) {
+        auto distributedDim = distributed.find(operandMap.getDimPosition(i));
+        // Reduction (or undistributed) dimensions keep their full size.
+        if (distributedDim == distributed.end())
+          continue;
+        operandShape[i] = distributedDim->second.first;
+        operandIds.push_back(distributedDim->second.second);
+      }
+      // Nothing to distribute for this operand: use it as is.
+      if (operandIds.empty()) {
+        extractOperands.push_back(operand);
+        continue;
+      }
+      VectorType newVecType =
+          VectorType::get(operandShape, vecType.getElementType());
+      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+          loc, newVecType, operand, operandIds));
+    }
+    Operation *newOp =
+        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
+                                    extract.getResult().getType());
+    rewriter.replaceOp(extract, newOp->getResult(0));
+    return success();
+  }
+};
+
 /// Converts TransferRead op used by ExtractMap op into a smaller dimension
 /// TransferRead.
 /// Example:
@@ -4100,8 +4166,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
-               TransferReadExtractPattern, TransferWriteInsertPattern>(
+  // TransferReadExtractPattern / TransferWriteInsertPattern are registered by
+  // populatePropagateVectorDistributionPatterns instead of here.
+  patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp>(
       patterns.getContext());
 }
 
@@ -4112,6 +4177,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
                                                           ignoreFilter);
 }
 
+void mlir::vector::populatePropagateVectorDistributionPatterns(
+    RewritePatternSet &patterns) {
+  // Push extract_map through elementwise ops and vector.contract, then fold
+  // the resulting extract_map/insert_map into the transfer ops. The
+  // registration order is significant for equal-benefit patterns; keep it.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
+               TransferReadExtractPattern, TransferWriteInsertPattern>(context);
+}
+
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
   patterns.add<CastAwayExtractStridedSliceLeadingOneDim,

diff  --git a/mlir/test/Dialect/Vector/vector-distribution.mlir b/mlir/test/Dialect/Vector/vector-distribution.mlir
index 0ad46d1b204e1..603493e7456b6 100644
--- a/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ b/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,4 -split-input-file | FileCheck %s --check-prefix=CHECK2D
 
 // CHECK-LABEL: func @distribute_vector_add
 //  CHECK-SAME: (%[[ID:.*]]: index
@@ -15,6 +16,24 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
 
 // -----
 
+// CHECK-LABEL: func @distribute_vector_add_exp
+//  CHECK-SAME: (%[[ID:.*]]: index
+//  CHECK-NEXT:    %[[EXPV:.*]] = math.exp %{{.*}} : vector<32xf32>
+//  CHECK-NEXT:    %[[ADDV:.*]] = addf %[[EXPV]], %{{.*}} : vector<32xf32>
+//  CHECK-NEXT:    %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
+//  CHECK-NEXT:    %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
+//  CHECK-NEXT:    %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
+//  CHECK-NEXT:    %[[ADD:.*]] = addf %[[EXC]], %[[EXB]] : vector<1xf32>
+//  CHECK-NEXT:    %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
+//  CHECK-NEXT:    return %[[INS]] : vector<32xf32>
+// Two chained elementwise ops (math.exp then addf): the extract_map created
+// for the addf result is expected to be propagated through both ops, while the
+// full-size exp/addf remain as the insert_map destination.
+func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
+  %C = math.exp %A : vector<32xf32>
+  %0 = addf %C, %B : vector<32xf32>
+  return %0: vector<32xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_add_read_write
 //  CHECK-SAME: (%[[ID:.*]]: index
 //       CHECK:    %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
@@ -154,3 +173,32 @@ func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?
   vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK2D-LABEL: vector_add_contract
+//       CHECK2D:   %[[A:.+]] = vector.transfer_read %arg2[%0, %c0], %cst : memref<?x?xf32>, vector<2x4xf32>
+//       CHECK2D:   %[[B:.+]] = vector.transfer_read %arg3[%2, %c0], %cst : memref<?x?xf32>, vector<16x4xf32>
+//       CHECK2D:   %[[C:.+]] = vector.transfer_read %arg4[%4, %5], %cst : memref<?x?xf32>, vector<2x16xf32>
+//       CHECK2D:   %[[E:.+]] = vector.transfer_read %arg5[%7, %8], %cst : memref<?x?xf32>, vector<2x16xf32>
+//       CHECK2D:   %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
+//       CHECK2D:   %[[R:.+]] = addf %[[D]], %[[E]] : vector<2x16xf32>
+//       CHECK2D:   vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref<?x?xf32>
+// With distribution-multiplicity=32,4 the 64x64 addf result is split into
+// 2x16 pieces (64/32 x 64/4). The extract_map is propagated through the addf
+// and the vector.contract: parallel dims d0/d1 are distributed while the
+// reduction dim d2 (size 4) stays whole, as the expected 2x4 and 16x4 reads
+// above show.
+// NOTE(review): the expected lines do not pin which id (%id0 vs %id1) feeds
+// each transfer_read index; consider matching the index computations too.
+func @vector_add_contract(%id0 : index, %id1 : index, %A: memref<?x?xf32>,
+  %B: memref<?x?xf32>, %C: memref<?x?xf32>, %D: memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
+  %b = vector.transfer_read %B[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
+  %c = vector.transfer_read %C[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
+  %d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                         affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                         affine_map<(d0, d1, d2) -> (d0, d1)>],
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>}
+    %a, %b, %c : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
+  %e = vector.transfer_read %D[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
+  %r = addf %d, %e : vector<64x64xf32>
+  vector.transfer_write %r, %C[%c0, %c0] : vector<64x64xf32>, memref<?x?xf32>
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 95d292612c661..9b0a6aea21b66 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -275,8 +275,7 @@ struct TestVectorDistributePatterns
         }
       }
     });
-    patterns.add<PointwiseExtractPattern>(ctx);
-    populateVectorToVectorTransformationPatterns(patterns);
+    populatePropagateVectorDistributionPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
@@ -339,8 +338,7 @@ struct TestVectorToLoopPatterns
       }
       return mlir::WalkResult::interrupt();
     });
-    patterns.add<PointwiseExtractPattern>(ctx);
-    populateVectorToVectorTransformationPatterns(patterns);
+    populatePropagateVectorDistributionPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };


        


More information about the Mlir-commits mailing list