[Mlir-commits] [mlir] 6391da9 - [mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.transpose'
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 3 14:56:08 PDT 2020
Author: aartbik
Date: 2020-06-03T14:55:50-07:00
New Revision: 6391da98f43a995fe3dfb96a5376b2d9c652ed87
URL: https://github.com/llvm/llvm-project/commit/6391da98f43a995fe3dfb96a5376b2d9c652ed87
DIFF: https://github.com/llvm/llvm-project/commit/6391da98f43a995fe3dfb96a5376b2d9c652ed87.diff
LOG: [mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.transpose'
Summary:
Progressive lowering of vector.transpose into an operation that
is closer to an intrinsic, and thus to the hardware ISA. This is
currently guarded by the common vector transform testing flag while
we prepare to deploy this transformation in the LLVM lowering pipeline.
Reviewers: nicolasvasilache, reidtatge, andydavis1, ftynse
Reviewed By: nicolasvasilache, ftynse
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits
Tags: #llvm, #mlir
Differential Revision: https://reviews.llvm.org/D80772
Added:
mlir/test/Dialect/Vector/vector-flat-transforms.mlir
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 8c8424e8ef8f..def0d24adcf5 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -53,9 +53,19 @@ enum class VectorContractLowering {
/// Lower to `vector.outerproduct`.
OuterProduct = 2,
};
+/// Enum to control the lowering of `vector.transpose` operations.
+enum class VectorTransposeLowering {
+ /// Lower transpose into element-wise extracts and inserts.
+ EltWise = 0,
+ /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+ /// intrinsics.
+ Flat = 1,
+};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+ VectorTransposeLowering vectorTransposeLowering =
+ VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 4065d19b6c8a..365795fb9cab 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1206,6 +1206,7 @@ def Vector_ShapeCastOp :
}
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+ let hasFolder = 1;
}
def Vector_TypeCastOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 63891d1004d4..21b62ceaa689 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1667,6 +1667,19 @@ static LogicalResult verify(ShapeCastOp op) {
return success();
}
+OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
+ // Nop shape cast.
+ if (source().getType() == result().getType())
+ return source();
+
+ // Canceling shape casts.
+ if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
+ if (result().getType() == otherOp.source().getType())
+ return otherOp.source();
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 491ad62affcb..82c27387bd6e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1186,6 +1186,11 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+ TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+ MLIRContext *context)
+ : OpRewritePattern<vector::TransposeOp>(context),
+ vectorTransformsOptions(vectorTransformsOptions) {}
+
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
@@ -1197,6 +1202,22 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
for (auto attr : op.transp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
+ // Handle a true 2-D matrix transpose differently when requested.
+ if (vectorTransformsOptions.vectorTransposeLowering ==
+ vector::VectorTransposeLowering::Flat &&
+ resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
+ Type flattenedType =
+ VectorType::get(resType.getNumElements(), resType.getElementType());
+ auto matrix =
+ rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
+ auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+ auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+ Value trans = rewriter.create<vector::FlatTransposeOp>(
+ loc, flattenedType, matrix, rows, columns);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+ return success();
+ }
+
// Generate fully unrolled extract/insert ops.
Value result = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(resType));
@@ -1230,6 +1251,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
}
return result;
}
+
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformsOptions;
};
/// Progressive lowering of OuterProductOp.
@@ -1829,9 +1853,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
ConstantMaskOpLowering,
OuterProductOpLowering,
ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern,
- TransposeOpLowering>(context);
- patterns.insert<ContractionOpLowering,
+ ShapeCastOp2DUpCastRewritePattern>(context);
+ patterns.insert<TransposeOpLowering,
+ ContractionOpLowering,
ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(parameters, context);
// clang-format on
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 1dd2f377a29c..491f18fdf5c9 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -319,6 +319,26 @@ func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
return %0 : vector<3x2xf32>
}
+
+// CHECK-LABEL: func @nop_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: return %[[A]] : vector<16xf32>
+
+func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: func @cancel_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: return %[[A]] : vector<16xf32>
+
+func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+ %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
+ return %1 : vector<16xf32>
+}
+
// Shape up and downcasts for 2-D vectors, for supporting conversion to
// llvm.matrix operations
// CHECK-LABEL: func @shape_casts
diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
new file mode 100644
index 000000000000..e715755738de
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s --dump-input-on-failure
+
+// Tests for lowering 2-D vector.transpose into vector.flat_transpose.
+//
+// TODO(ajcbik,ntv): having ShapeCastOp2DDownCastRewritePattern and
+// ShapeCastOp2DUpCastRewritePattern too early in
+// the greedy rewriting patterns misses opportunities
+// to fold shape casts!
+
+// No shape cast folding expected.
+//
+// CHECK-LABEL: func @transpose44_44(
+// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
+// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
+//
+func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// Folds preceding shape cast as expected,
+// no following shape cast folding expected.
+//
+// CHECK-LABEL: func @transpose16_44(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
+//
+func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+ %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ return %1 : vector<4x4xf32>
+}
+
+// No preceding shape cast folding expected,
+// but FAILS to fold following cast.
+//
+// CHECK-LABEL: func @transpose44_16(
+// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
+// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> {
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
+ return %1 : vector<16xf32>
+}
+
+// Folds preceding shape cast as expected,
+// but FAILS to fold following cast.
+//
+// CHECK-LABEL: func @transpose16_16(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+//
+func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+ %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32>
+ return %2 : vector<16xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 65024dbe3acd..22585fde4ff7 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -47,10 +47,14 @@ struct TestVectorContractionConversion
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
}
- Option<bool> lowerToLLVMMatrixIntrinsics{
+ Option<bool> lowerToFlatMatrix{
*this, "vector-lower-matrix-intrinsics",
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
llvm::cl::init(false)};
+ Option<bool> lowerToFlatTranspose{
+ *this, "vector-flat-transpose",
+ llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
+ llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -67,10 +71,14 @@ struct TestVectorContractionConversion
return;
}
- VectorContractLowering lowering = VectorContractLowering::FMA;
- if (lowerToLLVMMatrixIntrinsics)
- lowering = VectorContractLowering::Matmul;
- VectorTransformsOptions options{lowering};
+ VectorContractLowering contractLowering = VectorContractLowering::FMA;
+ if (lowerToFlatMatrix)
+ contractLowering = VectorContractLowering::Matmul;
+ VectorTransposeLowering transposeLowering =
+ VectorTransposeLowering::EltWise;
+ if (lowerToFlatTranspose)
+ transposeLowering = VectorTransposeLowering::Flat;
+ VectorTransformsOptions options{contractLowering, transposeLowering};
populateVectorContractLoweringPatterns(patterns, &getContext(), options);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
More information about the Mlir-commits
mailing list