[Mlir-commits] [mlir] 2d32ee0 - [mlir][Vector] Update lowering of vector ops to llvm intrinsics to use row-major.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Apr 9 13:40:56 PDT 2020


Author: Nicolas Vasilache
Date: 2020-04-09T16:37:28-04:00
New Revision: 2d32ee0d7a4c01209419408e73a4f075ec06c8a7

URL: https://github.com/llvm/llvm-project/commit/2d32ee0d7a4c01209419408e73a4f075ec06c8a7
DIFF: https://github.com/llvm/llvm-project/commit/2d32ee0d7a4c01209419408e73a4f075ec06c8a7.diff

LOG: [mlir][Vector] Update lowering of vector ops to llvm intrinsics to use row-major.

Summary:
LLVM matrix intrinsics recently introduced an option to support row-major mode.
Since this matches the MLIR vector model, this revision switches to row-major.

A corner case involving degenerate sizes (any matrix dimension equal to 1) has also been fixed upstream,
so this revision removes the guard against this corner case.

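This revision also fixes a bug uncovered in the construction of the output vector type: the result of vector.matrix_multiply was sized lhsRows * lhsColumns instead of lhsRows * rhsColumns.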

Lastly, this has been tested on small sizes and benchmarked independently: no visible performance regression was observed.

In the future, when matrix intrinsics support a per-op attribute, we can translate to them more aggressively and avoid inserting MLIR-level transposes.

This has been independently tested and works on small matrices.

Differential Revision: https://reviews.llvm.org/D77761

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index aac766f38a9c..0abfbdc9c0da 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1446,7 +1446,7 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
      result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
      result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
      result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns));
-     result.addTypes(VectorType::get(lhsRows * lhsColumns,
+     result.addTypes(VectorType::get(lhsRows * rhsColumns,
        lhs.getType().cast<VectorType>().getElementType()));
    }]>,
   ];

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 7a197ef14334..3cd0b7b4b733 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1125,43 +1125,34 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 
     // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
     // a new pattern.
-    // TODO(ntv, fhahn): once row-major mode is available in LLVM's matrix
-    // intrinsics, use that.
     if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics &&
-        isColumnMajorMatmul(op.indexing_maps())) {
+        isRowMajorMatmul(op.indexing_maps())) {
       VectorType lhsType = op.getLhsType();
       VectorType rhsType = op.getRhsType();
       unsigned lhsRows = op.getLhsType().getShape()[0];
       unsigned lhsColumns = op.getLhsType().getShape()[1];
       unsigned rhsColumns = op.getRhsType().getShape()[1];
 
-      // In cases where matrices are degenerate, scalarization issues occur in
-      // the backend. Avoid all LLVM scalarization issues for now.
-      // For more details, see: https://bugs.llvm.org/show_bug.cgi?id=45227 and
-      // https://bugs.llvm.org/show_bug.cgi?id=45229
-      // TODO(ntv, fhahn): Relax once above bugs are fixed.
-      if (lhsRows != 1 && lhsColumns != 1 && rhsColumns != 1) {
-        Type flattenedLHSType =
-            VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
-        Type flattenedRHSType =
-            VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
-        auto lhs = rewriter.create<vector::ShapeCastOp>(
-            op.getLoc(), flattenedLHSType, op.lhs());
-        auto rhs = rewriter.create<vector::ShapeCastOp>(
-            op.getLoc(), flattenedRHSType, op.rhs());
-
-        Value mul = rewriter.create<vector::MatmulOp>(
-            op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
-        mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
-                                                   op.acc().getType(), mul);
-        Type elementType = op.getLhsType().getElementType();
-        assert(elementType.isIntOrFloat());
-        if (elementType.isa<IntegerType>())
-          rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
-        else
-          rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
-        return success();
-      }
+      Type flattenedLHSType =
+          VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
+      Type flattenedRHSType =
+          VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
+      auto lhs = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), flattenedLHSType, op.lhs());
+      auto rhs = rewriter.create<vector::ShapeCastOp>(
+          op.getLoc(), flattenedRHSType, op.rhs());
+
+      Value mul = rewriter.create<vector::MatmulOp>(
+          op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+      mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(),
+                                                 op.acc().getType(), mul);
+      Type elementType = op.getLhsType().getElementType();
+      assert(elementType.isIntOrFloat());
+      if (elementType.isa<IntegerType>())
+        rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
+      else
+        rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
+      return success();
     }
 
     // Find first batch dimension in LHS/RHS, and lower when found.

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 051c42d32ed5..08140b4ae065 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -357,46 +357,35 @@ func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
   return %r0, %1 : vector<4xf32>, vector<2x2xf32>
 }
 
-// MATRIX-LABEL: func @column_major_matmul
-// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
-//      MATRIX:  %[[vcst:.*]] = constant dense<0.000000e+00> : vector<12xf32>
-//      MATRIX:  %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<8xf32>
-//      MATRIX:  %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32>
-//      MATRIX:  %[[a0:.*]] = vector.extract %[[A]][0] : vector<4x3xf32>
-//      MATRIX:  %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      MATRIX:  %[[a2:.*]] = vector.extract %[[A]][1] : vector<4x3xf32>
-//      MATRIX:  %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      MATRIX:  %[[a4:.*]] = vector.extract %[[A]][2] : vector<4x3xf32>
-//      MATRIX:  %[[a5:.*]] = vector.insert_strided_slice %[[a4]], %[[a3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      MATRIX:  %[[a6:.*]] = vector.extract %[[A]][3] : vector<4x3xf32>
-//      MATRIX:  %[[a7:.*]] = vector.insert_strided_slice %[[a6]], %[[a5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
-//      MATRIX:  %[[b8:.*]] = vector.extract %[[B]][0] : vector<2x4xf32>
-//      MATRIX:  %[[b9:.*]] = vector.insert_strided_slice %[[b8]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
-//      MATRIX:  %[[b10:.*]] = vector.extract %[[B]][1] : vector<2x4xf32>
-//      MATRIX:  %[[b11:.*]] = vector.insert_strided_slice %[[b10]], %[[b9]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
-//      MATRIX:  %[[mm12:.*]] = vector.matrix_multiply %[[a7]], %[[b11]] {lhs_columns = 3 : i32, lhs_rows = 4 : i32, rhs_columns = 4 : i32} : (vector<12xf32>, vector<8xf32>) -> vector<12xf32>
-//      MATRIX:  %[[mm13:.*]] = vector.strided_slice %[[mm12]] {offsets = [0], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-//      MATRIX:  %[[mm14:.*]] = vector.insert %[[mm13]], %[[vcst_1]] [0] : vector<2xf32> into vector<3x2xf32>
-//      MATRIX:  %[[mm15:.*]] = vector.strided_slice %[[mm12]] {offsets = [2], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-//      MATRIX:  %[[mm16:.*]] = vector.insert %[[mm15]], %[[mm14]] [1] : vector<2xf32> into vector<3x2xf32>
-//      MATRIX:  %[[mm17:.*]] = vector.strided_slice %[[mm12]] {offsets = [4], sizes = [2], strides = [1]} : vector<12xf32> to vector<2xf32>
-//      MATRIX:  %[[mm18:.*]] = vector.insert %[[mm17]], %[[mm16]] [2] : vector<2xf32> into vector<3x2xf32>
-//      MATRIX:  %[[mm19:.*]] = addf %[[C]], %[[mm18]] : vector<3x2xf32>
-#column_major_matmat_accesses = [
-  affine_map<(i, j, k) -> (k, j)>,
-  affine_map<(i, j, k) -> (i, k)>,
-  affine_map<(i, j, k) -> (j, i)>
-]
-#column_major_matmat_trait = {
-  indexing_maps = #column_major_matmat_accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-func @column_major_matmul(%arg0: vector<4x3xf32>,
-                          %arg1: vector<2x4xf32>,
-                          %arg2: vector<3x2xf32>) -> vector<3x2xf32> {
-  %0 = vector.contract #column_major_matmat_trait %arg0, %arg1, %arg2
-    : vector<4x3xf32>, vector<2x4xf32> into vector<3x2xf32>
-  return %0 : vector<3x2xf32>
+// MATRIX-LABEL: func @matmul
+// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
+// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      MATRIX:  %[[vcst:.*]] = constant dense<0.000000e+00> : vector<8xf32>
+//      MATRIX:  %[[vcst_0:.*]] = constant dense<0.000000e+00> : vector<12xf32>
+//      MATRIX:  %[[vcst_1:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+//      MATRIX:  %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32>
+//      MATRIX:  %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
+//      MATRIX:  %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32>
+//      MATRIX:  %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
+//      MATRIX:  %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32>
+//      MATRIX:  %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32>
+//      MATRIX:  %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32>
+//      MATRIX:  %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32>
+//      MATRIX:  %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32>
+//      MATRIX:  %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32>
+//      MATRIX:  %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32>
+//      MATRIX:  %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32>
+//      MATRIX:  %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
+//      MATRIX:  %[[mm2:.*]] = vector.strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+//      MATRIX:  %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32>
+//      MATRIX:  %[[mm4:.*]] = vector.strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+//      MATRIX:  %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32>
+//      MATRIX:  %[[mm6:.*]] = addf %[[C]], %[[mm5]] : vector<2x3xf32>
+func @matmul(%arg0: vector<2x4xf32>,
+                          %arg1: vector<4x3xf32>,
+                          %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
 }


        


More information about the Mlir-commits mailing list