[Mlir-commits] [mlir] 6391da9 - [mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.transpose'

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 3 14:56:08 PDT 2020


Author: aartbik
Date: 2020-06-03T14:55:50-07:00
New Revision: 6391da98f43a995fe3dfb96a5376b2d9c652ed87

URL: https://github.com/llvm/llvm-project/commit/6391da98f43a995fe3dfb96a5376b2d9c652ed87
DIFF: https://github.com/llvm/llvm-project/commit/6391da98f43a995fe3dfb96a5376b2d9c652ed87.diff

LOG: [mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.transpose'

Summary:
Progressive lowering of vector.transpose into an operation that
is closer to an intrinsic, and thus the hardware ISA. Currently
under the common vector transform testing flag, as we prepare to
deploy this transformation in the LLVM lowering pipeline.

Reviewers: nicolasvasilache, reidtatge, andydavis1, ftynse

Reviewed By: nicolasvasilache, ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits

Tags: #llvm, #mlir

Differential Revision: https://reviews.llvm.org/D80772

Added: 
    mlir/test/Dialect/Vector/vector-flat-transforms.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 8c8424e8ef8f..def0d24adcf5 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -53,9 +53,19 @@ enum class VectorContractLowering {
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
 };
+/// Enum to control the lowering of `vector.transpose` operations.
+enum class VectorTransposeLowering {
+  /// Lower transpose into element-wise extracts and inserts.
+  EltWise = 0,
+  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+  /// intrinsics; other ranks/permutations still use the EltWise lowering.
+  Flat = 1,
+};
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
   VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+  VectorTransposeLowering vectorTransposeLowering =
+      VectorTransposeLowering::EltWise;
   VectorTransformsOptions &
   setVectorTransformsOptions(VectorContractLowering opt) {
     vectorContractLowering = opt;

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 4065d19b6c8a..365795fb9cab 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1206,6 +1206,7 @@ def Vector_ShapeCastOp :
     }
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+  let hasFolder = 1;
 }
 
 def Vector_TypeCastOp :

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 63891d1004d4..21b62ceaa689 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1667,6 +1667,19 @@ static LogicalResult verify(ShapeCastOp op) {
   return success();
 }
 
+OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
+  // Nop shape cast: identical source/result types, fold to the source.
+  if (source().getType() == result().getType())
+    return source();
+
+  // Canceling shape casts: A -> B -> A folds back to the original A value.
+  if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
+    if (result().getType() == otherOp.source().getType())
+      return otherOp.source();
+
+  return {}; // No fold applies.
+}
+
 //===----------------------------------------------------------------------===//
 // TypeCastOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 491ad62affcb..82c27387bd6e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1186,6 +1186,11 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
 
+  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+                      MLIRContext *context)
+      : OpRewritePattern<vector::TransposeOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions) {}
+
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
@@ -1197,6 +1202,22 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     for (auto attr : op.transp())
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
+    // Handle a true 2-D matrix transpose differently when requested.
+    if (vectorTransformsOptions.vectorTransposeLowering ==
+            vector::VectorTransposeLowering::Flat &&
+        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
+      Type flattenedType =
+          VectorType::get(resType.getNumElements(), resType.getElementType());
+      auto matrix =
+          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
+      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+      Value trans = rewriter.create<vector::FlatTransposeOp>(
+          loc, flattenedType, matrix, rows, columns);
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+      return success();
+    }
+
     // Generate fully unrolled extract/insert ops.
     Value result = rewriter.create<ConstantOp>(loc, resType,
                                                rewriter.getZeroAttr(resType));
@@ -1230,6 +1251,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     }
     return result;
   }
+
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
 };
 
 /// Progressive lowering of OuterProductOp.
@@ -1829,9 +1853,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
                   ConstantMaskOpLowering,
                   OuterProductOpLowering,
                   ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern,
-                  TransposeOpLowering>(context);
-  patterns.insert<ContractionOpLowering,
+                  ShapeCastOp2DUpCastRewritePattern>(context);
+  patterns.insert<TransposeOpLowering,
+                  ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,
                   ContractionOpToOuterProductOpLowering>(parameters, context);
   // clang-format on

diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 1dd2f377a29c..491f18fdf5c9 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -319,6 +319,26 @@ func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
   return %0 : vector<3x2xf32>
 }
 
+
+// CHECK-LABEL: func @nop_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK:      return %[[A]] : vector<16xf32>
+
+func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: func @cancel_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK:      return %[[A]] : vector<16xf32>
+
+func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
 // Shape up and downcasts for 2-D vectors, for supporting conversion to
 // llvm.matrix operations
 // CHECK-LABEL: func @shape_casts

diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
new file mode 100644
index 000000000000..e715755738de
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s --dump-input-on-failure
+
+// Tests for lowering 2-D vector.transpose into vector.flat_transpose.
+//
+// TODO(ajcbik,ntv): having ShapeCastOp2DDownCastRewritePattern and
+//                   ShapeCastOp2DUpCastRewritePattern too early in
+//                   the greedy rewriting patterns misses opportunities
+//                   to fold shape casts!
+
+// No shape cast folding expected.
+//
+// CHECK-LABEL: func @transpose44_44(
+// CHECK-SAME:  %[[A:.*]]: vector<4x4xf32>
+// CHECK:       %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
+// CHECK:       %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+// CHECK:       %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
+//
+func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// Folds preceding shape cast as expected,
+// no following shape cast folding expected.
+//
+// CHECK-LABEL: func @transpose16_44(
+// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
+// CHECK:       %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+// CHECK:       %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
+//
+func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// No preceding shape cast folding expected,
+// but FAILS to fold following cast.
+//
+// CHECK-LABEL: func @transpose44_16(
+// CHECK-SAME:  %[[A:.*]]: vector<4x4xf32>
+// CHECK:       %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
+// CHECK:       %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// Folds preceding shape cast as expected,
+// but FAILS to fold following cast.
+//
+// CHECK-LABEL: func @transpose16_16(
+// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
+// CHECK:       %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+//
+func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32>
+  return %2 : vector<16xf32>
+}

diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 65024dbe3acd..22585fde4ff7 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -47,10 +47,14 @@ struct TestVectorContractionConversion
   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
   }
 
-  Option<bool> lowerToLLVMMatrixIntrinsics{
+  Option<bool> lowerToFlatMatrix{
       *this, "vector-lower-matrix-intrinsics",
       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
       llvm::cl::init(false)};
+  Option<bool> lowerToFlatTranspose{
+      *this, "vector-flat-transpose",
+      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
+      llvm::cl::init(false)};
   Option<bool> lowerToOuterProduct{
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -67,10 +71,14 @@ struct TestVectorContractionConversion
       return;
     }
 
-    VectorContractLowering lowering = VectorContractLowering::FMA;
-    if (lowerToLLVMMatrixIntrinsics)
-      lowering = VectorContractLowering::Matmul;
-    VectorTransformsOptions options{lowering};
+    VectorContractLowering contractLowering = VectorContractLowering::FMA;
+    if (lowerToFlatMatrix)
+      contractLowering = VectorContractLowering::Matmul;
+    VectorTransposeLowering transposeLowering =
+        VectorTransposeLowering::EltWise;
+    if (lowerToFlatTranspose)
+      transposeLowering = VectorTransposeLowering::Flat;
+    VectorTransformsOptions options{contractLowering, transposeLowering};
     populateVectorContractLoweringPatterns(patterns, &getContext(), options);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }


        


More information about the Mlir-commits mailing list