[Mlir-commits] [mlir] c295a65 - [mlir] [VectorOps] Add 'vector.flat_transpose' operation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 27 11:09:58 PDT 2020


Author: aartbik
Date: 2020-05-27T11:09:48-07:00
New Revision: c295a65da496f5e982402e8f83e417659c7dd166

URL: https://github.com/llvm/llvm-project/commit/c295a65da496f5e982402e8f83e417659c7dd166
DIFF: https://github.com/llvm/llvm-project/commit/c295a65da496f5e982402e8f83e417659c7dd166.diff

LOG: [mlir] [VectorOps] Add 'vector.flat_transpose' operation

Summary:
Provides a representation of the linearized LLVM intrinsic.
With tests and lowering implementation to LLVM IR dialect.
Prepares better lowering for 2-D vector.transpose.

Reviewers: nicolasvasilache, ftynse, reidtatge, bkramer, dcaballe

Reviewed By: ftynse, dcaballe

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80419

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 1b978e44dd6a..4065d19b6c8a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1482,6 +1482,9 @@ def Vector_PrintOp :
 
 //===----------------------------------------------------------------------===//
 // Ops used for supporting progressive lowering and conversion type changes.
+// The Ops are typically not used directly by higher level dialects, but are
+// used by intra-dialect rewriting rules to bring vector operations closer
+// to the hardware ISA.
 //===----------------------------------------------------------------------===//
 
 /// Vector dialect matrix multiplication op that operates on flattened 1-D
@@ -1510,12 +1513,20 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
   let description = [{
     This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
     purposes of more progressive lowering and localized type conversion.
+    Higher levels typically lower matrix multiplications into 'vector.contract'
+    operations. Subsequent rewriting rules progressively lower these operations
+    into 'vector.matrix_multiply' operations to bring the operations closer
+    to the hardware ISA.
 
     The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
     and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
     <rhs_columns> and multiplies them. The result matrix is returned embedded in
     the result vector.
 
+    Also see:
+
+    http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
+
     Example:
 
     ```mlir
@@ -1541,4 +1552,60 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
     "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
 }
 
+/// Vector dialect matrix transposition op that operates on flattened 1-D
+/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
+/// This may seem redundant with vector.transpose but it serves the purposes of
+/// more progressive lowering and localized type conversion on the path:
+///   `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
+def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
+  PredOpTrait<"source operand and result have same element type",
+                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
+    Arguments<(
+      // TODO(ntv, fhahn, ajcbik): tighten vector element types that make sense.
+      ins VectorOfRankAndType<[1],
+            [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$matrix,
+          I32Attr:$rows, I32Attr:$columns)>,
+    Results<(
+      outs VectorOfRankAndType<[1],
+             [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> {
+  let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
+  let description = [{
+    This is the counterpart of llvm.matrix.transpose in MLIR. It serves
+    the purposes of more progressive lowering and localized type conversion.
+    Higher levels typically lower matrix transpositions into 'vector.transpose'
+    operations. Subsequent rewriting rules progressively lower these operations
+    into 'vector.flat_transpose' operations to bring the operations closer
+    to the hardware ISA.
+
+    The ‘vector.flat_transpose’ op treats the 1-D input `matrix` as
+    a 2-D matrix with <rows> rows and <columns> columns, and returns the
+    transposed matrix in flattened form in 'res'. Both `matrix` and `res`
+    must hold exactly <rows> * <columns> elements.
+
+    Also see:
+
+    http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic
+
+    Example:
+
+    ```mlir
+    %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
+       : vector<16xf32> -> vector<16xf32>
+    ```
+  }];
+  // llvm.matrix.transpose treats its operand as a rows x columns matrix and
+  // yields rows * columns elements; reject mismatched shapes before lowering.
+  let verifier = [{
+    int64_t numElements = rowsAttr().getInt() * columnsAttr().getInt();
+    if (matrix().getType().cast<VectorType>().getNumElements() != numElements)
+      return emitOpError("expected source vector to have rows * columns = ")
+             << numElements << " elements";
+    if (res().getType().cast<VectorType>().getNumElements() != numElements)
+      return emitOpError("expected result vector to have rows * columns = ")
+             << numElements << " elements";
+    return success();
+  }];
+  let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
+}
+
 #endif // VECTOR_OPS

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 975807ca8671..5b3a01c7512f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -148,6 +148,27 @@ class VectorMatmulOpConversion : public ConvertToLLVMPattern {
   }
 };
 
+/// Lowers vector.flat_transpose one-to-one onto llvm.intr.matrix.transpose,
+/// forwarding the converted operand and the rows/columns attributes.
+class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
+                                           LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
+                             context, typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto flatOp = cast<vector::FlatTransposeOp>(op);
+    Type llvmResTy = typeConverter.convertType(flatOp.res().getType());
+    vector::FlatTransposeOpOperandAdaptor converted(operands);
+    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
+        op, llvmResTy, converted.matrix(), flatOp.rows(), flatOp.columns());
+    return success();
+  }
+};
+
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
@@ -1157,6 +1178,7 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
+  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
 }
 
 namespace {

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 26e3e9dbe2b1..6150ac78fc2a 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -952,3 +952,16 @@ func @genbool_1d() -> vector<8xi1> {
 // CHECK: %[[T8:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
 // CHECK: %[[T9:.*]] = llvm.insertelement %[[T0]], %[[T7]][%[[T8]] : !llvm.i64] : !llvm<"<8 x i1>">
 // CHECK: llvm.return %9 : !llvm<"<8 x i1>">
+
+// A non-square shape ensures a rows/columns swap in the lowering is caught.
+// CHECK-LABEL: func @flat_transpose
+// CHECK-SAME:  %[[A:.*]]: !llvm<"<16 x float>">
+// CHECK:       %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
+// CHECK-SAME:      {columns = 8 : i32, rows = 2 : i32} :
+// CHECK-SAME:      !llvm<"<16 x float>"> into !llvm<"<16 x float>">
+// CHECK:       llvm.return %[[T]] : !llvm<"<16 x float>">
+func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.flat_transpose %arg0 { rows = 2: i32, columns = 8: i32 }
+     : vector<16xf32> -> vector<16xf32>
+  return %0 : vector<16xf32>
+}

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index cc72511a6e78..1f6da8190bae 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1145,6 +1145,13 @@ func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) {
 
 // -----
 
+func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) {
+  // expected-error@+1 {{'vector.flat_transpose' op failed to verify that source operand and result have same element type}}
+  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf64>
+}
+
+// -----
+
 func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
   // expected-error@+1 {{expects operand to be a memref with no layout}}
   %0 = vector.type_cast %arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>> to memref<vector<4x3xf32>>

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 57c03c903fe8..dbffe4206f12 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -140,7 +140,7 @@ func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
   indexing_maps = #contraction_to_scalar_accesses,
   iterator_types = ["reduction"]
 }
-// CHECK-LABEL: contraction_to_scalar
+// CHECK-LABEL: @contraction_to_scalar
 func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
   // CHECK:      %[[C0:.*]] = constant 0.000000e+00 : f32
   %f0 = constant 0.0: f32
@@ -172,7 +172,7 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
   iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction",
                     "reduction"]
 }
-// CHECK-LABEL: contraction
+// CHECK-LABEL: @contraction
 func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
                   %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
                   %arg4 : index) {
@@ -196,7 +196,7 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
   return
 }
 
-// CHECK-LABEL: create_vector_mask
+// CHECK-LABEL: @create_vector_mask
 func @create_vector_mask() {
   // CHECK:      %[[C2:.*]] = constant 2 : index
   %c2 = constant 2 : index
@@ -208,14 +208,14 @@ func @create_vector_mask() {
   return
 }
 
-// CHECK-LABEL: constant_vector_mask
+// CHECK-LABEL: @constant_vector_mask
 func @constant_vector_mask() {
   // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1>
   %0 = vector.constant_mask [3, 2] : vector<4x3xi1>
   return
 }
 
-// CHECK-LABEL: extract_slices
+// CHECK-LABEL: @extract_slices
 func @extract_slices(%arg0 : vector<4x2xf32>)
   -> (tuple<vector<2x2xf32>, vector<2x2xf32>>) {
   // CHECK: vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
@@ -227,7 +227,7 @@ func @extract_slices(%arg0 : vector<4x2xf32>)
   return %3 : tuple<vector<2x2xf32>, vector<2x2xf32>>
 }
 
-// CHECK-LABEL: insert_slices
+// CHECK-LABEL: @insert_slices
 func @insert_slices(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>)
   -> (vector<4x2xf32>) {
   // CHECK: vector.insert_slices %{{.*}}, [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
@@ -243,7 +243,7 @@ func @vector_print(%arg0: vector<8x4xf32>) {
   return
 }
 
-// CHECK-LABEL: reshape
+// CHECK-LABEL: @reshape
 func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
   // CHECK:      %[[C2:.*]] = constant 2 : index
   %c2 = constant 2 : index
@@ -260,7 +260,7 @@ func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
   return %1 : vector<2x3x4xf32>
 }
 
-// CHECK-LABEL: shape_cast
+// CHECK-LABEL: @shape_cast
 func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>)
   -> (vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>) {
@@ -284,7 +284,7 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
   return
 }
 
-// CHECK-LABEL: reduce_fp
+// CHECK-LABEL: @reduce_fp
 func @reduce_fp(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
   // CHECK:    vector.reduction "add", %{{.*}} : vector<16xf32> into f32
   vector.reduction "add", %arg0 : vector<16xf32> into f32
@@ -302,7 +302,7 @@ func @reduce_fp(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
   return %0 : f32
 }
 
-// CHECK-LABEL: reduce_int
+// CHECK-LABEL: @reduce_int
 func @reduce_int(%arg0: vector<16xi32>) -> i32 {
   // CHECK:    vector.reduction "add", %{{.*}} : vector<16xi32> into i32
   vector.reduction "add", %arg0 : vector<16xi32> into i32
@@ -322,14 +322,34 @@ func @reduce_int(%arg0: vector<16xi32>) -> i32 {
   return %0 : i32
 }
 
-// CHECK-LABEL: transpose_fp
+// CHECK-LABEL: @transpose_fp
 func @transpose_fp(%arg0: vector<3x7xf32>) -> vector<7x3xf32> {
+  // CHECK: %[[TRANS:.*]] = vector.transpose %{{.*}}, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
   %0 = vector.transpose %arg0, [1, 0] : vector<3x7xf32> to vector<7x3xf32>
+  // CHECK: return %[[TRANS]] : vector<7x3xf32>
   return %0 : vector<7x3xf32>
 }
 
-// CHECK-LABEL: transpose_int
+// CHECK-LABEL: @transpose_int
 func @transpose_int(%arg0: vector<11x7x3x2xi32>) -> vector<2x11x7x3xi32> {
+  // CHECK: %[[TRANS:.*]] = vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32>
   %0 = vector.transpose %arg0, [3, 0, 1, 2] : vector<11x7x3x2xi32> to vector<2x11x7x3xi32>
+  // CHECK: return %[[TRANS]] : vector<2x11x7x3xi32>
   return %0 : vector<2x11x7x3xi32>
 }
+
+// CHECK-LABEL: @flat_transpose_fp
+func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> {
+  // CHECK: %[[FLAT_T:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf32>
+  // CHECK: return %[[FLAT_T]] : vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: @flat_transpose_int
+func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
+  // CHECK: %[[FLAT_T:.*]] = vector.flat_transpose %{{.*}} {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
+  %0 = vector.flat_transpose %arg0 { rows = 2: i32, columns = 8: i32 } : vector<16xi32> -> vector<16xi32>
+  // CHECK: return %[[FLAT_T]] : vector<16xi32>
+  return %0 : vector<16xi32>
+}


        


More information about the Mlir-commits mailing list