[Mlir-commits] [mlir] 9032240 - [mlir][Vector] Allow lowering of vector.shape_cast 2D <-> 1D

Nicolas Vasilache llvmlistbot at llvm.org
Mon Mar 9 10:18:27 PDT 2020


Author: Nicolas Vasilache
Date: 2020-03-09T13:14:39-04:00
New Revision: 90322403c203ac180b31930f148555d13e03b121

URL: https://github.com/llvm/llvm-project/commit/90322403c203ac180b31930f148555d13e03b121
DIFF: https://github.com/llvm/llvm-project/commit/90322403c203ac180b31930f148555d13e03b121.diff

LOG: [mlir][Vector] Allow lowering of vector.shape_cast 2D <-> 1D

Summary:
This will support the progressive lowering of:
```
vector.contract ->
  downcast + vector.matrix_multiply + upcast ->
    llvm.intr.matrix
```

Differential Revision: https://reviews.llvm.org/D75776

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/VectorOps/VectorOps.td
    mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
    mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index aee269555bd3..67a880d2e5d6 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1091,7 +1091,12 @@ def Vector_ShapeCastOp :
     described above is applied to each source/result tuple element pair.
 
     It is currently assumed that this operation does not require moving data,
-    and that it will be canonicalized away before lowering vector operations.
+    and that it will be folded away before lowering vector operations.
+
+    There is an exception to the folding expectation when targeting
+    llvm.intr.matrix operations, which require a type conversion back and
+    forth between a 2-D MLIR vector and a 1-D flattened LLVM vector. For now,
+    lowering vector.shape_cast is supported only for that 2-D <-> 1-D case.
 
     Examples:
 
@@ -1108,6 +1113,14 @@ def Vector_ShapeCastOp :
                                 tuple<vector<12x2xf32>, vector<9x2xf32>>
     ```
   }];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() { // Requires a vector (not tuple) source.
+      return source().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() { // Requires a vector (not tuple) result.
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
 }
 

diff  --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 8764d487dfb9..00089ebefd12 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -1171,6 +1171,75 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
   }
 };
 
+/// ShapeCastOp 2-D -> 1-D downcast: flattens a 2-D vector into a 1-D vector on
+/// the way to llvm.intr.matrix intrinsics. Each row `i` is extracted with
+/// vector.extract and inserted with vector.insert_strided_slice at offset
+/// `i * mostMinorVectorSize`. Tuple-typed shape_casts are not matched.
+class ShapeCastOp2DDownCastRewritePattern
+    : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
+                                     PatternRewriter &rewriter) const override {
+    // dyn_cast: getSourceVectorType() would assert on tuple shape_casts.
+    auto sourceVectorType = op.source().getType().dyn_cast<VectorType>();
+    auto resultVectorType = op.getResult().getType().dyn_cast<VectorType>();
+    if (!sourceVectorType || !resultVectorType ||
+        sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    auto elemType = sourceVectorType.getElementType();
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
+    int64_t mostMinorVectorSize = sourceVectorType.getShape()[1];
+    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
+      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
+      desc = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, vec, desc, /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+    }
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
+/// ShapeCastOp 1-D -> 2-D upcast: unflattens a 1-D vector back into a 2-D
+/// vector on the way back from llvm.intr.matrix intrinsics. Row `i` of the
+/// result is a vector.strided_slice at offset `i * mostMinorVectorSize`,
+/// written with vector.insert at `i`. Tuple-typed shape_casts don't match.
+class ShapeCastOp2DUpCastRewritePattern
+    : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
+                                     PatternRewriter &rewriter) const override {
+    // dyn_cast: getSourceVectorType() would assert on tuple shape_casts.
+    auto sourceVectorType = op.source().getType().dyn_cast<VectorType>();
+    auto resultVectorType = op.getResult().getType().dyn_cast<VectorType>();
+    if (!sourceVectorType || !resultVectorType ||
+        sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    auto elemType = sourceVectorType.getElementType();
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value desc = rewriter.create<SplatOp>(loc, resultVectorType, zero);
+    int64_t mostMinorVectorSize = resultVectorType.getShape()[1];
+    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
+      Value vec = rewriter.create<vector::StridedSliceOp>(
+          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
+          /*sizes=*/mostMinorVectorSize, /*strides=*/1);
+      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+    }
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -1188,5 +1257,9 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
 
 void mlir::vector::populateVectorContractLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<ContractionOpLowering>(context);
+  patterns.insert<ContractionOpLowering,
+                  // 2-D <-> 1-D vector.shape_cast lowerings, for the
+                  // contract -> downcast + matrix_multiply + upcast path.
+                  ShapeCastOp2DDownCastRewritePattern,
+                  ShapeCastOp2DUpCastRewritePattern>(context);
 }

diff  --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index 275fd0841a60..c5e40a7c18ca 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -250,3 +250,42 @@ func @full_contract2(%arg0: vector<2x3xf32>,
     : vector<2x3xf32>, vector<3x2xf32> into f32
   return %0 : f32
 }
+
+// 2-D <-> 1-D vector.shape_cast up/downcasts, which support the conversion
+// to llvm.intr.matrix operations.
+// CHECK-LABEL: func @shape_casts
+func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
+  // CHECK: %[[cst:.*]] = constant dense<0.000000e+00> : vector<4xf32>
+  // CHECK: %[[cst22:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
+  // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32>
+  // Downcast: row i of %a is inserted at offset 2 * i of the 1-D result.
+  // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
+  // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+  //
+  // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32>
+  //
+  // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
+  // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+  //
+  %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
+  // CHECK: %[[add:.*]] = addf %[[in2]], %[[in2]] : vector<4xf32>
+  %r0 = addf %0, %0: vector<4xf32>
+  // Upcast: row i of the result is the size-2 slice at offset 2 * i.
+  // CHECK: %[[ss0:.*]] = vector.strided_slice %[[add]]
+  // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
+  // CHECK-SAME: vector<4xf32> to vector<2xf32>
+  //
+  // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] :
+  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+  //
+  // CHECK: %[[s2:.*]] = vector.strided_slice %[[add]]
+  // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
+  // CHECK-SAME: vector<4xf32> to vector<2xf32>
+  //
+  // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
+  // CHECK-SAME: vector<2xf32> into vector<2x2xf32>
+  //
+  %1 = vector.shape_cast %r0  : vector<4xf32> to vector<2x2xf32>
+  // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
+  return %r0, %1 : vector<4xf32>, vector<2x2xf32>
+}


        


More information about the Mlir-commits mailing list