[Mlir-commits] [mlir] 1e45b55 - [mlir] [VectorOps] Handle 'vector.shape_cast' lowering for all cases

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 9 16:08:59 PDT 2020


Author: aartbik
Date: 2020-06-09T16:08:45-07:00
New Revision: 1e45b55dcc8bc34a45e984ce2a5533a292775484

URL: https://github.com/llvm/llvm-project/commit/1e45b55dcc8bc34a45e984ce2a5533a292775484
DIFF: https://github.com/llvm/llvm-project/commit/1e45b55dcc8bc34a45e984ce2a5533a292775484.diff

LOG: [mlir] [VectorOps] Handle 'vector.shape_cast' lowering for all cases

Summary:
Even though this operation is currently intended for 1-D/2-D conversions,
leaving a semantic hole in the lowering prevents proper testing of the
operation. This CL adds a straightforward reference implementation for the
missing cases.

Reviewers: nicolasvasilache, mehdi_amini, ftynse, reidtatge

Reviewed By: reidtatge

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, msifontes

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D81503

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 99a3a951e9de..ebb0d1109bc9 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1466,6 +1466,61 @@ class ShapeCastOp2DUpCastRewritePattern
   }
 };
 
+// We typically should not lower general shape cast operations into data
+// movement instructions, since the assumption is that these casts are
+// optimized away during progressive lowering. For completeness, however,
+// we fall back to a reference implementation that moves all elements
+// into the right place if we get here.
+class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto sourceVectorType = op.getSourceVectorType();
+    auto resultVectorType = op.getResultVectorType();
+    // Intended 2D/1D lowerings with better implementations.
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+      return failure();
+    // Compute number of elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t r = 0; r < srcRank; r++)
+      numElts *= sourceVectorType.getDimSize(r);
+    // Replace with data movement operations:
+    //    x[0,0,0] = y[0,0]
+    //    x[0,0,1] = y[0,1]
+    //    x[0,1,0] = y[0,2]
+    // etc., incrementing the two index vectors "row-major"
+    // within the source and result shape.
+    SmallVector<int64_t, 4> srcIdx(srcRank);
+    SmallVector<int64_t, 4> resIdx(resRank);
+    Value result = rewriter.create<ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    for (int64_t i = 0; i < numElts; i++) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, srcRank - 1);
+        incIdx(resIdx, resultVectorType, resRank - 1);
+      }
+      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
+    assert(0 <= r && r < tp.getRank());
+    if (++idx[r] == tp.getDimSize(r)) {
+      idx[r] = 0;
+      incIdx(idx, tp, r - 1);
+    }
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1864,7 +1919,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
                   ConstantMaskOpLowering,
                   OuterProductOpLowering,
                   ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern>(context);
+                  ShapeCastOp2DUpCastRewritePattern,
+                  ShapeCastOpRewritePattern>(context);
   patterns.insert<TransposeOpLowering,
                   ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index da784205224a..a0f5e66fea4b 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -319,7 +319,6 @@ func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
   return %0 : vector<3x2xf32>
 }
 
-
 // CHECK-LABEL: func @nop_shape_cast
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>
 // CHECK:      return %[[A]] : vector<16xf32>
@@ -378,6 +377,72 @@ func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
   return %r0, %1 : vector<4xf32>, vector<2x2xf32>
 }
 
+// CHECK-LABEL: func @shape_cast_2d2d
+// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
+// CHECK: return %[[T11]] : vector<2x3xf32>
+
+func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_3d1d
+// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<6xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
+// CHECK: return %[[T11]] : vector<6xf32>
+
+func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
+  %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
+  return %s : vector<6xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_1d3d
+// CHECK-SAME: %[[A:.*]]: vector<6xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<2x1x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
+// CHECK: return %[[T11]] : vector<2x1x3xf32>
+
+func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
+  return %s : vector<2x1x3xf32>
+}
+
 // MATRIX-LABEL: func @matmul
 // MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,


        


More information about the Mlir-commits mailing list