[Mlir-commits] [mlir] 68330ee - [mlir][vector] Relax transfer_read/transfer_write restriction on memref operand
Thomas Raoux
llvmlistbot at llvm.org
Mon Aug 10 08:58:08 PDT 2020
Author: Thomas Raoux
Date: 2020-08-10T08:57:48-07:00
New Revision: 68330ee0a977926d2f2857c62420b7729f4e45d3
URL: https://github.com/llvm/llvm-project/commit/68330ee0a977926d2f2857c62420b7729f4e45d3
DIFF: https://github.com/llvm/llvm-project/commit/68330ee0a977926d2f2857c62420b7729f4e45d3.diff
LOG: [mlir][vector] Relax transfer_read/transfer_write restriction on memref operand
Relax the verifier for transfer_read/transfer_write operations so that they can
take a memref with a different element type than the vector being read/written.
This is based on the discourse discussion:
https://llvm.discourse.group/t/memref-cast/1514
Differential Revision: https://reviews.llvm.org/D85244
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1e92b80d830f..522a45965722 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1147,6 +1147,26 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
Type i64Type = rewriter.getIntegerType(64);
MemRefType memRefType = xferOp.getMemRefType();
+ if (auto memrefVectorElementType =
+ memRefType.getElementType().dyn_cast<VectorType>()) {
+ // Memref has vector element type: only same-element-type lowering is supported.
+ if (memrefVectorElementType.getElementType() !=
+ xferOp.getVectorType().getElementType())
+ return failure();
+ // The verifier no longer requires the memref vector type to be a suffix of
+ // the result vector type, so bail out (rather than assert) when it is not.
+ unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+ unsigned resultVecRank = xferOp.getVectorType().getRank();
+ assert(memrefVecEltRank <= resultVecRank); // Still enforced by the verifier.
+ // TODO: Move this to isSuffix in Vector/Utils.h.
+ unsigned rankOffset = resultVecRank - memrefVecEltRank;
+ auto memrefVecEltShape = memrefVectorElementType.getShape();
+ auto resultVecShape = xferOp.getVectorType().getShape();
+ for (unsigned i = 0; i < memrefVecEltRank; ++i)
+ if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
+ return failure();
+ }
+
// 1. Get the source/dst address as an LLVM vector pointer.
// The vector pointer would always be on address space 0, therefore
// addrspacecast shall be used when source/dst memrefs are not on
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 7c715bfdb6d0..1c0a5ceb8d86 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1500,37 +1500,33 @@ static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.
- // Check that 'memrefVectorElementType' and vector element types match.
- if (memrefVectorElementType.getElementType() != vectorType.getElementType())
+ unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
+ memrefVectorElementType.getShape().back();
+ unsigned resultVecSize =
+ vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+ if (resultVecSize % memrefVecSize != 0)
return op->emitOpError(
- "requires memref and vector types of the same elemental type");
+ "requires the bitwidth of the minor 1-D vector to be an integral "
+ "multiple of the bitwidth of the minor 1-D vector of the memref");
- // Check that memref vector type is a suffix of 'vectorType.
unsigned memrefVecEltRank = memrefVectorElementType.getRank();
unsigned resultVecRank = vectorType.getRank();
if (memrefVecEltRank > resultVecRank)
return op->emitOpError(
"requires memref vector element and vector result ranks to match.");
- // TODO: Move this to isSuffix in Vector/Utils.h.
unsigned rankOffset = resultVecRank - memrefVecEltRank;
- auto memrefVecEltShape = memrefVectorElementType.getShape();
- auto resultVecShape = vectorType.getShape();
- for (unsigned i = 0; i < memrefVecEltRank; ++i)
- if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
- return op->emitOpError(
- "requires memref vector element shape to match suffix of "
- "vector result shape.");
// Check that permutation map results match 'rankOffset' of vector type.
if (permutationMap.getNumResults() != rankOffset)
return op->emitOpError("requires a permutation_map with result dims of "
"the same rank as the vector type");
} else {
// Memref has scalar element type.
-
- // Check that memref and vector element types match.
- if (memrefType.getElementType() != vectorType.getElementType())
+ unsigned resultVecSize =
+ vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+ if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
return op->emitOpError(
- "requires memref and vector types of the same elemental type");
+ "requires the bitwidth of the minor 1-D vector to be an integral "
+ "multiple of the bitwidth of the memref element type");
// Check that permutation map results match rank of vector type.
if (permutationMap.getNumResults() != vectorType.getRank())
@@ -1560,7 +1556,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vector, Value memref, ValueRange indices,
AffineMap permutationMap,
ArrayRef<bool> maybeMasked) {
- Type elemType = vector.cast<VectorType>().getElementType();
+ Type elemType = memref.getType().cast<MemRefType>().getElementType();
Value padding = builder.create<ConstantOp>(result.location, elemType,
builder.getZeroAttr(elemType));
if (maybeMasked.empty())
@@ -1673,9 +1669,9 @@ static LogicalResult verify(TransferReadOp op) {
return op.emitOpError("requires valid padding vector elemental type");
// Check that padding type and vector element types match.
- if (paddingType != vectorType.getElementType())
+ if (paddingType != memrefElementType)
return op.emitOpError(
- "requires formal padding and vector of the same elemental type");
+ "requires formal padding and memref of the same elemental type");
}
return verifyPermutationMap(permutationMap,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d91d4db06106..d35c7fa645b7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -906,6 +906,24 @@ func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17
// 2. Rewrite as a load.
// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vec<17 x float>>
+func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
+ %c0 = constant 0: i32
+ %v = vector.transfer_read %A[%base], %c0 {masked = [false]} :
+ memref<?xi32>, vector<12xi8>
+ return %v: vector<12xi8>
+}
+// CHECK-LABEL: func @transfer_read_1d_cast
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm.vec<12 x i8>
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<i32>, !llvm.i64) -> !llvm.ptr<i32>
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm.ptr<i32> to !llvm.ptr<vec<12 x i8>>
+//
+// 2. Rewrite as a load.
+// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vec<12 x i8>>
+
func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e1d03ca480f4..71d2989661ae 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -339,16 +339,6 @@ func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
// -----
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
- %c3 = constant 3 : index
- %f0 = constant 0.0 : f32
- %vf0 = splat %f0 : vector<4x3xf32>
- // expected-error at +1 {{requires memref and vector types of the same elemental type}}
- %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32>
-}
-
-// -----
-
func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
@@ -359,12 +349,12 @@ func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
// -----
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
- %vf0 = splat %f0 : vector<4x3xf32>
- // expected-error at +1 {{ requires memref vector element shape to match suffix of vector result shape}}
- %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32>
+ %vf0 = splat %f0 : vector<6xf32>
+ // expected-error at +1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 0381c88cc247..6f9990d5d97c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -4,12 +4,15 @@
// CHECK-LABEL: func @vector_transfer_ops(
func @vector_transfer_ops(%arg0: memref<?x?xf32>,
- %arg1 : memref<?x?xvector<4x3xf32>>) {
+ %arg1 : memref<?x?xvector<4x3xf32>>,
+ %arg2 : memref<?x?xvector<4x3xi32>>) {
// CHECK: %[[C3:.*]] = constant 3 : index
%c3 = constant 3 : index
%cst = constant 3.0 : f32
%f0 = constant 0.0 : f32
+ %c0 = constant 0 : i32
%vf0 = splat %f0 : vector<4x3xf32>
+ %v0 = splat %c0 : vector<4x3xi32>
//
// CHECK: vector.transfer_read
@@ -24,6 +27,9 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
// CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
%5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+ %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+
// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
@@ -33,6 +39,8 @@ func @vector_transfer_ops(%arg0: memref<?x?xf32>,
vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
+ vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
return
}
More information about the Mlir-commits
mailing list