[Mlir-commits] [mlir] a57def3 - [mlir][vector] generalized masked l/s and compressed l/s with indices
Aart Bik
llvmlistbot at llvm.org
Fri Jan 8 13:59:51 PST 2021
Author: Aart Bik
Date: 2021-01-08T13:59:34-08:00
New Revision: a57def30f53990aafc3f64b9b7a0f60916cc7f61
URL: https://github.com/llvm/llvm-project/commit/a57def30f53990aafc3f64b9b7a0f60916cc7f61
DIFF: https://github.com/llvm/llvm-project/commit/a57def30f53990aafc3f64b9b7a0f60916cc7f61.diff
LOG: [mlir][vector] generalized masked l/s and compressed l/s with indices
Adding the ability to index the base address brings these operations closer
to the transfer read and write semantics (with lowering advantages), ensures
more consistent use in vector MLIR code (easier to read), and considerably
reduces the code duplication needed to lower memrefs into base addresses
(making codegen less error-prone).
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D94278
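For illustration, a minimal sketch of the new indexed assembly form, drawn from the op documentation and tests updated in the diff below (value names are placeholders):

```mlir
// Masked load/store now take an explicit start index into the base memref.
%c0 = constant 0 : index
%l = vector.maskedload %base[%c0], %mask, %pass_thru
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
vector.maskedstore %base[%c0], %mask, %l
    : memref<?xf32>, vector<16xi1>, vector<16xf32>

// The same indexed form applies to expandload/compressstore.
%e = vector.expandload %base[%c0], %mask, %pass_thru
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
vector.compressstore %base[%c0], %mask, %e
    : memref<?xf32>, vector<16xi1>, vector<16xf32>
```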
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir
mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-mem-transforms.mlir
mlir/test/Dialect/Vector/vector-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 0a98b9ffe996..0aa4950e0a9e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1317,6 +1317,7 @@ def Vector_TransferWriteOp :
def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Arguments<(ins AnyMemRef:$base,
+ Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1325,12 +1326,12 @@ def Vector_MaskedLoadOp :
let description = [{
The masked load reads elements from memory into a 1-D vector as defined
- by a base and a 1-D mask vector. When the mask is set, the element is read
- from memory. Otherwise, the corresponding element is taken from a 1-D
- pass-through vector. Informally the semantics are:
+ by a base with indices and a 1-D mask vector. When the mask is set, the
+ element is read from memory. Otherwise, the corresponding element is taken
+ from a 1-D pass-through vector. Informally the semantics are:
```
- result[0] := mask[0] ? MEM[base+0] : pass_thru[0]
- result[1] := mask[1] ? MEM[base+1] : pass_thru[1]
+ result[0] := mask[0] ? base[i+0] : pass_thru[0]
+ result[1] := mask[1] ? base[i+1] : pass_thru[1]
etc.
```
The masked load can be used directly where applicable, or can be used
@@ -1342,7 +1343,7 @@ def Vector_MaskedLoadOp :
Example:
```mlir
- %0 = vector.maskedload %base, %mask, %pass_thru
+ %0 = vector.maskedload %base[%i], %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```
}];
@@ -1360,7 +1361,7 @@ def Vector_MaskedLoadOp :
return result().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
}
@@ -1368,6 +1369,7 @@ def Vector_MaskedLoadOp :
def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Arguments<(ins AnyMemRef:$base,
+ Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
@@ -1375,12 +1377,12 @@ def Vector_MaskedStoreOp :
let description = [{
The masked store operation writes elements from a 1-D vector into memory
- as defined by a base and a 1-D mask vector. When the mask is set, the
- corresponding element from the vector is written to memory. Otherwise,
+ as defined by a base with indices and a 1-D mask vector. When the mask is
+ set, the corresponding element from the vector is written to memory. Otherwise,
no action is taken for the element. Informally the semantics are:
```
- if (mask[0]) MEM[base+0] = value[0]
- if (mask[1]) MEM[base+1] = value[1]
+ if (mask[0]) base[i+0] = value[0]
+ if (mask[1]) base[i+1] = value[1]
etc.
```
The masked store can be used directly where applicable, or can be used
@@ -1392,7 +1394,7 @@ def Vector_MaskedStoreOp :
Example:
```mlir
- vector.maskedstore %base, %mask, %value
+ vector.maskedstore %base[%i], %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
```
}];
@@ -1407,8 +1409,8 @@ def Vector_MaskedStoreOp :
return value().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
- "type($mask) `,` type($value) `into` type($base)";
+ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
+ "type($base) `,` type($mask) `,` type($value)";
let hasCanonicalizer = 1;
}
@@ -1430,8 +1432,8 @@ def Vector_GatherOp :
semantics are:
```
if (!defined(pass_thru)) pass_thru = [undef, .., undef]
- result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0]
- result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1]
+ result[0] := mask[0] ? base[index[0]] : pass_thru[0]
+ result[1] := mask[1] ? base[index[1]] : pass_thru[1]
etc.
```
The vector dialect leaves out-of-bounds behavior undefined.
@@ -1487,8 +1489,8 @@ def Vector_ScatterOp :
bit in a 1-D mask vector is set. Otherwise, no action is taken for that
element. Informally the semantics are:
```
- if (mask[0]) MEM[base + index[0]] = value[0]
- if (mask[1]) MEM[base + index[1]] = value[1]
+ if (mask[0]) base[index[0]] = value[0]
+ if (mask[1]) base[index[1]] = value[1]
etc.
```
The vector dialect leaves out-of-bounds and repeated index behavior
@@ -1531,6 +1533,7 @@ def Vector_ScatterOp :
def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Arguments<(ins AnyMemRef:$base,
+ Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1539,13 +1542,13 @@ def Vector_ExpandLoadOp :
let description = [{
The expand load reads elements from memory into a 1-D vector as defined
- by a base and a 1-D mask vector. When the mask is set, the next element
- is read from memory. Otherwise, the corresponding element is taken from
- a 1-D pass-through vector. Informally the semantics are:
+ by a base with indices and a 1-D mask vector. When the mask is set, the
+ next element is read from memory. Otherwise, the corresponding element
+ is taken from a 1-D pass-through vector. Informally the semantics are:
```
- index = base
- result[0] := mask[0] ? MEM[index++] : pass_thru[0]
- result[1] := mask[1] ? MEM[index++] : pass_thru[1]
+ index = i
+ result[0] := mask[0] ? base[index++] : pass_thru[0]
+ result[1] := mask[1] ? base[index++] : pass_thru[1]
etc.
```
Note that the index increment is done conditionally.
@@ -1559,7 +1562,7 @@ def Vector_ExpandLoadOp :
Example:
```mlir
- %0 = vector.expandload %base, %mask, %pass_thru
+ %0 = vector.expandload %base[%i], %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```
}];
@@ -1577,7 +1580,7 @@ def Vector_ExpandLoadOp :
return result().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
let hasCanonicalizer = 1;
}
@@ -1585,6 +1588,7 @@ def Vector_ExpandLoadOp :
def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Arguments<(ins AnyMemRef:$base,
+ Variadic<Index>:$indices,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$value)> {
@@ -1592,13 +1596,13 @@ def Vector_CompressStoreOp :
let description = [{
The compress store operation writes elements from a 1-D vector into memory
- as defined by a base and a 1-D mask vector. When the mask is set, the
- corresponding element from the vector is written next to memory. Otherwise,
- no action is taken for the element. Informally the semantics are:
+ as defined by a base with indices and a 1-D mask vector. When the mask is
+ set, the corresponding element from the vector is written next to memory.
+ Otherwise, no action is taken for the element. Informally the semantics are:
```
- index = base
- if (mask[0]) MEM[index++] = value[0]
- if (mask[1]) MEM[index++] = value[1]
+ index = i
+ if (mask[0]) base[index++] = value[0]
+ if (mask[1]) base[index++] = value[1]
etc.
```
Note that the index increment is done conditionally.
@@ -1612,7 +1616,7 @@ def Vector_CompressStoreOp :
Example:
```mlir
- vector.compressstore %base, %mask, %value
+ vector.compressstore %base[%i], %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
```
}];
@@ -1627,7 +1631,7 @@ def Vector_CompressStoreOp :
return value().getType().cast<VectorType>();
}
}];
- let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
+ let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
"type($base) `,` type($mask) `,` type($value)";
let hasCanonicalizer = 1;
}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
index 6310d6ee8790..76022203e113 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
@@ -5,7 +5,16 @@
func @compress16(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
- vector.compressstore %base, %mask, %value
+ %c0 = constant 0: index
+ vector.compressstore %base[%c0], %mask, %value
+ : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+func @compress16_at8(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c8 = constant 8: index
+ vector.compressstore %base[%c8], %mask, %value
: memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
@@ -86,5 +95,10 @@ func @entry() {
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+ call @compress16_at8(%A, %some1, %value)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+ call @printmem16(%A) : (memref<?xf32>) -> ()
+ // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 0, 1, 2, 3, 12, 13, 14, 15 )
+
return
}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
index 74118fc1125b..b63294f76e7a 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
@@ -5,8 +5,18 @@
func @expand16(%base: memref<?xf32>,
%mask: vector<16xi1>,
- %pass_thru: vector<16xf32>) -> vector<16xf32> {
- %e = vector.expandload %base, %mask, %pass_thru
+ %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = constant 0: index
+ %e = vector.expandload %base[%c0], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %e : vector<16xf32>
+}
+
+func @expand16_at8(%base: memref<?xf32>,
+ %mask: vector<16xi1>,
+ %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c8 = constant 8: index
+ %e = vector.expandload %base[%c8], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %e : vector<16xf32>
}
@@ -78,5 +88,10 @@ func @entry() {
vector.print %e6 : vector<16xf32>
// CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
+ %e7 = call @expand16_at8(%A, %some1, %pass)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+ vector.print %e7 : vector<16xf32>
+ // CHECK-NEXT: ( 8, 9, 10, 11, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
return
}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir
index 6c6f6ead005f..d5353affae45 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir
@@ -5,7 +5,16 @@
func @maskedload16(%base: memref<?xf32>, %mask: vector<16xi1>,
%pass_thru: vector<16xf32>) -> vector<16xf32> {
- %ld = vector.maskedload %base, %mask, %pass_thru
+ %c0 = constant 0: index
+ %ld = vector.maskedload %base[%c0], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %ld : vector<16xf32>
+}
+
+func @maskedload16_at8(%base: memref<?xf32>, %mask: vector<16xi1>,
+ %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c8 = constant 8: index
+ %ld = vector.maskedload %base[%c8], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
@@ -61,6 +70,11 @@ func @entry() {
vector.print %l4 : vector<16xf32>
// CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 )
+ %l5 = call @maskedload16_at8(%A, %some, %pass)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+ vector.print %l5 : vector<16xf32>
+ // CHECK: ( 8, 9, 10, 11, 12, 13, 14, 15, -7, -7, -7, -7, -7, -7, -7, -7 )
+
return
}
diff --git a/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir
index d0132f61c3e9..a98789847d48 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir
@@ -5,8 +5,17 @@
func @maskedstore16(%base: memref<?xf32>,
%mask: vector<16xi1>, %value: vector<16xf32>) {
- vector.maskedstore %base, %mask, %value
- : vector<16xi1>, vector<16xf32> into memref<?xf32>
+ %c0 = constant 0: index
+ vector.maskedstore %base[%c0], %mask, %value
+ : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+func @maskedstore16_at8(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c8 = constant 8: index
+ vector.maskedstore %base[%c8], %mask, %value
+ : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
@@ -85,5 +94,10 @@ func @entry() {
call @printmem16(%A) : (memref<?xf32>) -> ()
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
+ call @maskedstore16_at8(%A, %some, %val)
+ : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+ call @printmem16(%A) : (memref<?xf32>) -> ()
+ // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 )
+
return
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5ad266ce0bd7..5dd0b028767a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -173,33 +173,7 @@ static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
return success();
}
-// Helper that returns a pointer given a memref base.
-static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
- Location loc, Value memref,
- MemRefType memRefType, Value &ptr) {
- Value base;
- if (failed(getBase(rewriter, loc, memref, memRefType, base)))
- return failure();
- auto pType = MemRefDescriptor(memref).getElementPtrType();
- ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
- return success();
-}
-
-// Helper that returns a bit-casted pointer given a memref base.
-static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
- Location loc, Value memref,
- MemRefType memRefType, Type type, Value &ptr) {
- Value base;
- if (failed(getBase(rewriter, loc, memref, memRefType, base)))
- return failure();
- auto pType = LLVM::LLVMPointerType::get(type);
- base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
- ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
- return success();
-}
-
-// Helper that returns vector of pointers given a memref base and an index
-// vector.
+// Helper that returns vector of pointers given a memref base with index vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
Location loc, Value memref, Value indices,
MemRefType memRefType, VectorType vType,
@@ -213,6 +187,18 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
return success();
}
+// Casts a strided element pointer to a vector pointer. The vector pointer
+// would always be on address space 0, therefore addrspacecast shall be
+// used when source/dst memrefs are not on address space 0.
+static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
+ Value ptr, MemRefType memRefType, Type vt) {
+ auto pType =
+ LLVM::LLVMPointerType::get(vt.template cast<LLVM::LLVMFixedVectorType>());
+ if (memRefType.getMemorySpace() == 0)
+ return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
+ return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
+}
+
static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
@@ -343,18 +329,18 @@ class VectorMaskedLoadOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = load->getLoc();
auto adaptor = vector::MaskedLoadOpAdaptor(operands);
+ MemRefType memRefType = load.getMemRefType();
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
- align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
+ // Resolve address.
auto vtype = typeConverter->convertType(load.getResultVectorType());
- Value ptr;
- if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
- vtype, ptr)))
- return failure();
+ Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+ adaptor.indices(), rewriter);
+ Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
@@ -374,18 +360,18 @@ class VectorMaskedStoreOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = store->getLoc();
auto adaptor = vector::MaskedStoreOpAdaptor(operands);
+ MemRefType memRefType = store.getMemRefType();
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
- align)))
+ if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
+ // Resolve address.
auto vtype = typeConverter->convertType(store.getValueVectorType());
- Value ptr;
- if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
- vtype, ptr)))
- return failure();
+ Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+ adaptor.indices(), rewriter);
+ Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
store, adaptor.value(), ptr, adaptor.mask(),
@@ -473,16 +459,15 @@ class VectorExpandLoadOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = expand->getLoc();
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
+ MemRefType memRefType = expand.getMemRefType();
- Value ptr;
- if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
- ptr)))
- return failure();
+ // Resolve address.
+ auto vtype = typeConverter->convertType(expand.getResultVectorType());
+ Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+ adaptor.indices(), rewriter);
- auto vType = expand.getResultVectorType();
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
- expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
- adaptor.pass_thru());
+ expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
return success();
}
};
@@ -498,11 +483,11 @@ class VectorCompressStoreOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = compress->getLoc();
auto adaptor = vector::CompressStoreOpAdaptor(operands);
+ MemRefType memRefType = compress.getMemRefType();
- Value ptr;
- if (failed(getBasePtr(rewriter, loc, adaptor.base(),
- compress.getMemRefType(), ptr)))
- return failure();
+ // Resolve address.
+ Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+ adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
compress, adaptor.value(), ptr, adaptor.mask());
@@ -1223,21 +1208,11 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
}
// 1. Get the source/dst address as an LLVM vector pointer.
- // The vector pointer would always be on address space 0, therefore
- // addrspacecast shall be used when source/dst memrefs are not on
- // address space 0.
- // TODO: support alignment when possible.
+ VectorType vtp = xferOp.getVectorType();
Value dataPtr = this->getStridedElementPtr(
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
- auto vecTy = toLLVMTy(xferOp.getVectorType())
- .template cast<LLVM::LLVMFixedVectorType>();
- Value vectorDataPtr;
- if (memRefType.getMemorySpace() == 0)
- vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
- loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
- else
- vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
- loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
+ Value vectorDataPtr =
+ castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
if (!xferOp.isMaskedDim(0))
return replaceTransferOpWithLoadOrStore(rewriter,
@@ -1251,7 +1226,7 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
//
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
// dimensions here.
- unsigned vecWidth = vecTy.getNumElements();
+ unsigned vecWidth = vtp.getNumElements();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 5d9c2e63b094..318ca1e27c88 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -76,20 +76,6 @@ static MaskFormat get1DMaskFormat(Value mask) {
return MaskFormat::Unknown;
}
-/// Helper method to cast a 1-D memref<10xf32> "base" into a
-/// memref<vector<10xf32>> in the output parameter "newBase",
-/// using the 'element' vector type "vt". Returns true on success.
-static bool castedToMemRef(Location loc, Value base, MemRefType mt,
- VectorType vt, PatternRewriter &rewriter,
- Value &newBase) {
- // The vector.type_cast operation does not accept unknown memref<?xf32>.
- // TODO: generalize the cast and accept this case too
- if (!mt.hasStaticShape())
- return false;
- newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
- return true;
-}
-
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
@@ -2380,13 +2366,10 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedLoadOp load,
PatternRewriter &rewriter) const override {
- Value newBase;
switch (get1DMaskFormat(load.mask())) {
case MaskFormat::AllTrue:
- if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(),
- load.getResultVectorType(), rewriter, newBase))
- return failure();
- rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ load, load.getType(), load.base(), load.indices(), false);
return success();
case MaskFormat::AllFalse:
rewriter.replaceOp(load, load.pass_thru());
@@ -2426,13 +2409,10 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MaskedStoreOp store,
PatternRewriter &rewriter) const override {
- Value newBase;
switch (get1DMaskFormat(store.mask())) {
case MaskFormat::AllTrue:
- if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(),
- store.getValueVectorType(), rewriter, newBase))
- return failure();
- rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ store, store.value(), store.base(), store.indices(), false);
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(store);
@@ -2568,14 +2548,10 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandLoadOp expand,
PatternRewriter &rewriter) const override {
- Value newBase;
switch (get1DMaskFormat(expand.mask())) {
case MaskFormat::AllTrue:
- if (!castedToMemRef(expand.getLoc(), expand.base(),
- expand.getMemRefType(), expand.getResultVectorType(),
- rewriter, newBase))
- return failure();
- rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ expand, expand.getType(), expand.base(), expand.indices(), false);
return success();
case MaskFormat::AllFalse:
rewriter.replaceOp(expand, expand.pass_thru());
@@ -2615,14 +2591,11 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CompressStoreOp compress,
PatternRewriter &rewriter) const override {
- Value newBase;
switch (get1DMaskFormat(compress.mask())) {
case MaskFormat::AllTrue:
- if (!castedToMemRef(compress.getLoc(), compress.base(),
- compress.getMemRefType(),
- compress.getValueVectorType(), rewriter, newBase))
- return failure();
- rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ compress, compress.value(), compress.base(), compress.indices(),
+ false);
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(compress);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c1e4e088412d..5c0c9651133d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1070,23 +1070,29 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
// CHECK: llvm.return %[[T]] : !llvm.vec<16 x f32>
func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
- %0 = vector.maskedload %arg0, %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %c0 = constant 0: index
+ %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: func @masked_load_op
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x f32>>) -> !llvm.ptr<vec<16 x f32>>
-// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x f32>>, !llvm.vec<16 x i1>, !llvm.vec<16 x f32>) -> !llvm.vec<16 x f32>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr<f32> to !llvm.ptr<vec<16 x f32>>
+// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[B]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x f32>>, !llvm.vec<16 x i1>, !llvm.vec<16 x f32>) -> !llvm.vec<16 x f32>
// CHECK: llvm.return %[[L]] : !llvm.vec<16 x f32>
func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
- vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref<?xf32>
+ %c0 = constant 0: index
+ vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
// CHECK-LABEL: func @masked_store_op
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x f32>>) -> !llvm.ptr<vec<16 x f32>>
-// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x f32>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x f32>>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr<f32> to !llvm.ptr<vec<16 x f32>>
+// CHECK: llvm.intr.masked.store %{{.*}}, %[[B]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x f32>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x f32>>
// CHECK: llvm.return
func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
@@ -1110,21 +1116,25 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
// CHECK: llvm.return
func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
- %0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
+ %c0 = constant 0: index
+ %0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
return %0 : vector<11xf32>
}
// CHECK-LABEL: func @expand_load_op
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<f32>, !llvm.vec<11 x i1>, !llvm.vec<11 x f32>) -> !llvm.vec<11 x f32>
// CHECK: llvm.return %[[E]] : !llvm.vec<11 x f32>
func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
- vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
+ %c0 = constant 0: index
+ vector.compressstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
return
}
// CHECK-LABEL: func @compress_store_op
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x f32>, !llvm.ptr<f32>, !llvm.vec<11 x i1>) -> ()
// CHECK: llvm.return
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 62eaa4e3a14e..8cadafae1ec4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1199,36 +1199,41 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
// -----
func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.maskedload' op base and result element type should match}}
- %0 = vector.maskedload %base, %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.maskedload' op expected result dim to match mask dim}}
- %0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xi32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.maskedload' op expected pass_thru of same type as result type}}
- %0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
+ %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
}
// -----
func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.maskedstore' op base and value element type should match}}
- vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref<?xf64>
+ vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
}
// -----
func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
- vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref<?xf32>
+ vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
}
// -----
@@ -1297,36 +1302,41 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
// -----
func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.expandload' op base and result element type should match}}
- %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.expandload' op expected result dim to match mask dim}}
- %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
}
// -----
func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
- %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
}
// -----
func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.compressstore' op base and value element type should match}}
- vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
+ vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
}
// -----
func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
// expected-error @+1 {{'vector.compressstore' op expected value dim to match mask dim}}
- vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
+ vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 07e9d8de3f49..60890e58aef5 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -452,10 +452,11 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
// CHECK-LABEL: @masked_load_and_store
func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
- // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- %0 = vector.maskedload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- // CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref<?xf32>
- vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref<?xf32>
+ %c0 = constant 0 : index
+ // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ vector.maskedstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
@@ -472,10 +473,11 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
// CHECK-LABEL: @expand_and_compress
func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
- // CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- %0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- // CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
- vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ %c0 = constant 0 : index
+ // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.expandload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 7d79d8b1ed72..f9d7903a343b 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -1,82 +1,93 @@
// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
-//
-// TODO: optimize this one too!
-//
-// CHECK-LABEL: func @maskedload0(
-// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[M:.*]] = vector.constant_mask
-// CHECK-NEXT: %[[T:.*]] = vector.maskedload %[[A0]], %[[M]], %[[A1]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-NEXT: return %[[T]] : vector<16xf32>
-
+// CHECK-LABEL: func @maskedload0(
+// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[C:.*]] = constant 0 : index
+// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
- %ld = vector.maskedload %base, %mask, %pass_thru
+ %ld = vector.maskedload %base[%c0], %mask, %pass_thru
: memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @maskedload1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
-// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
-// CHECK-NEXT: return %[[T1]] : vector<16xf32>
-
+// CHECK-LABEL: func @maskedload1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[C:.*]] = constant 0 : index
+// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
- %ld = vector.maskedload %base, %mask, %pass_thru
+ %ld = vector.maskedload %base[%c0], %mask, %pass_thru
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @maskedload2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: return %[[A1]] : vector<16xf32>
-
+// CHECK-LABEL: func @maskedload2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[A1]] : vector<16xf32>
func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
- %ld = vector.maskedload %base, %mask, %pass_thru
+ %ld = vector.maskedload %base[%c0], %mask, %pass_thru
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @maskedstore1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
-// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
-// CHECK-NEXT: return
+// CHECK-LABEL: func @maskedload3(
+// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[C:.*]] = constant 8 : index
+// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[T]] : vector<16xf32>
+func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c8 = constant 8 : index
+ %mask = vector.constant_mask [16] : vector<16xi1>
+ %ld = vector.maskedload %base[%c8], %mask, %pass_thru
+ : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %ld : vector<16xf32>
+}
+// CHECK-LABEL: func @maskedstore1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+// CHECK-NEXT: %[[C:.*]] = constant 0 : index
+// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
+// CHECK-NEXT: return
func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
- vector.maskedstore %base, %mask, %value
- : vector<16xi1>, vector<16xf32> into memref<16xf32>
+ vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
-// CHECK-LABEL: func @maskedstore2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: return
-
+// CHECK-LABEL: func @maskedstore2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+// CHECK-NEXT: return
func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
- vector.maskedstore %base, %mask, %value
- : vector<16xi1>, vector<16xf32> into memref<16xf32>
+ vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
-// CHECK-LABEL: func @gather1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: %[[T1:.*]] = vector.gather %[[A0]], %[[A1]], %[[T0]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
-// CHECK-NEXT: return %1 : vector<16xf32>
-
+// CHECK-LABEL: func @gather1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// CHECK-NEXT: return %[[G]] : vector<16xf32>
func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%mask = vector.constant_mask [16] : vector<16xi1>
%ld = vector.gather %base, %indices, %mask, %pass_thru
@@ -84,12 +95,11 @@ func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @gather2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
-// CHECK-NEXT: return %[[A2]] : vector<16xf32>
-
+// CHECK-LABEL: func @gather2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[A2]] : vector<16xf32>
func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%mask = vector.constant_mask [0] : vector<16xi1>
%ld = vector.gather %base, %indices, %mask, %pass_thru
@@ -97,14 +107,13 @@ func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @scatter1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[T0]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
-// CHECK-NEXT: return
-
+// CHECK-LABEL: func @scatter1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
+// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+// CHECK-NEXT: return
func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%mask = vector.constant_mask [16] : vector<16xi1>
vector.scatter %base, %indices, %mask, %value
@@ -112,12 +121,11 @@ func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
return
}
-// CHECK-LABEL: func @scatter2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
-// CHECK-NEXT: return
-
+// CHECK-LABEL: func @scatter2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
+// CHECK-NEXT: return
func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
%0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
%mask = vector.constant_mask [0] : vector<16xi1>
@@ -126,52 +134,53 @@ func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
return
}
-// CHECK-LABEL: func @expand1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
-// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
-// CHECK-NEXT: return %[[T1]] : vector<16xf32>
-
+// CHECK-LABEL: func @expand1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG: %[[C:.*]] = constant 0 : index
+// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
- %ld = vector.expandload %base, %mask, %pass_thru
+ %ld = vector.expandload %base[%c0], %mask, %pass_thru
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @expand2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: return %[[A1]] : vector<16xf32>
-
+// CHECK-LABEL: func @expand2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT: return %[[A1]] : vector<16xf32>
func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
- %ld = vector.expandload %base, %mask, %pass_thru
+ %ld = vector.expandload %base[%c0], %mask, %pass_thru
: memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %ld : vector<16xf32>
}
-// CHECK-LABEL: func @compress1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
-// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
-// CHECK-NEXT: return
-
+// CHECK-LABEL: func @compress1(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+// CHECK-NEXT: %[[C:.*]] = constant 0 : index
+// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
+// CHECK-NEXT: return
func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [16] : vector<16xi1>
- vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
+ vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
-// CHECK-LABEL: func @compress2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: return
-
+// CHECK-LABEL: func @compress2(
+// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
+// CHECK-NEXT: return
func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) {
+ %c0 = constant 0 : index
%mask = vector.constant_mask [0] : vector<16xi1>
- vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
+ vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
return
}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 4a58261c9672..754c7cc04088 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 572cd1cd68f1..11e4ee24cb40 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -24,13 +24,25 @@ namespace {
struct TestVectorToVectorConversion
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
+ TestVectorToVectorConversion() = default;
+ TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect>();
+ }
+
+ Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
+ llvm::cl::init(false)};
+
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *ctx = &getContext();
- patterns.insert<UnrollVectorPattern>(
- ctx,
- UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
- filter));
+ if (unroll) {
+ patterns.insert<UnrollVectorPattern>(
+ ctx,
+ UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
+ filter));
+ }
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));