[Mlir-commits] [mlir] e5c8fc7 - [mlir][vector] canonicalize unmasked gather/scatter/compress/expand directly into l/s
Author: Aart Bik
Date: 2021-03-05T14:23:50-08:00
New Revision: e5c8fc776fbd2c93e25f5749049ee31cf73a0a41
URL: https://github.com/llvm/llvm-project/commit/e5c8fc776fbd2c93e25f5749049ee31cf73a0a41
DIFF: https://github.com/llvm/llvm-project/commit/e5c8fc776fbd2c93e25f5749049ee31cf73a0a41.diff
LOG: [mlir][vector] canonicalize unmasked gather/scatter/compress/expand directly into l/s
With the new vector.load/store operations, there is no need to go through
unmasked transfer operations (which would be canonicalized to loads/stores anyway).
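For illustration, a minimal sketch of the maskedload fold under an all-true
mask (value names like %base and %pass_thru are hypothetical):

  %c0 = constant 0 : index
  %mask = vector.constant_mask [16] : vector<16xi1>
  %l = vector.maskedload %base[%c0], %mask, %pass_thru
      : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

after canonicalization becomes a plain load:

  %l = vector.load %base[%c0] : memref<16xf32>, vector<16xf32>

whereas an all-false mask folds the operation away entirely and forwards
%pass_thru to the users of the load.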
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D98056
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Linalg/sparse_vector.mlir
mlir/test/Dialect/Vector/vector-mem-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index ebe3250a09af..8945253644d8 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2764,8 +2764,8 @@ class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
PatternRewriter &rewriter) const override {
switch (get1DMaskFormat(load.mask())) {
case MaskFormat::AllTrue:
- rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
- load, load.getType(), load.base(), load.indices(), false);
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(load, load.getType(),
+ load.base(), load.indices());
return success();
case MaskFormat::AllFalse:
rewriter.replaceOp(load, load.pass_thru());
@@ -2809,8 +2809,8 @@ class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
PatternRewriter &rewriter) const override {
switch (get1DMaskFormat(store.mask())) {
case MaskFormat::AllTrue:
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- store, store.valueToStore(), store.base(), store.indices(), false);
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ store, store.valueToStore(), store.base(), store.indices());
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(store);
@@ -2951,8 +2951,8 @@ class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
PatternRewriter &rewriter) const override {
switch (get1DMaskFormat(expand.mask())) {
case MaskFormat::AllTrue:
- rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
- expand, expand.getType(), expand.base(), expand.indices(), false);
+ rewriter.replaceOpWithNewOp<vector::LoadOp>(
+ expand, expand.getType(), expand.base(), expand.indices());
return success();
case MaskFormat::AllFalse:
rewriter.replaceOp(expand, expand.pass_thru());
@@ -2996,9 +2996,9 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
PatternRewriter &rewriter) const override {
switch (get1DMaskFormat(compress.mask())) {
case MaskFormat::AllTrue:
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
compress, compress.valueToStore(), compress.base(),
- compress.indices(), false);
+ compress.indices());
return success();
case MaskFormat::AllFalse:
rewriter.eraseOp(compress);
diff --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir
index beed6e8bc9aa..5409e4b80681 100644
--- a/mlir/test/Dialect/Linalg/sparse_vector.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir
@@ -35,10 +35,10 @@
// CHECK-VEC1-DAG: %[[c16:.*]] = constant 16 : index
// CHECK-VEC1-DAG: %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
-// CHECK-VEC1: %[[r:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC1: %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
-// CHECK-VEC1: vector.transfer_write %[[m]], %{{.*}}[%[[i]]] {masked = [false]} : vector<16xf32>, memref<1024xf32>
+// CHECK-VEC1: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1: }
// CHECK-VEC1: return
//
@@ -47,10 +47,10 @@
// CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index
// CHECK-VEC2-DAG: %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
-// CHECK-VEC2: %[[r:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC2: %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
-// CHECK-VEC2: vector.transfer_write %[[m]], %{{.*}}[%[[i]]] {masked = [false]} : vector<16xf32>, memref<1024xf32>
+// CHECK-VEC2: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2: }
// CHECK-VEC2: return
//
@@ -214,8 +214,8 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
// CHECK-VEC1-DAG: %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC1-DAG: %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC1: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
-// CHECK-VEC1: %[[la:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
-// CHECK-VEC1: %[[lb:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC1: %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC1: scf.yield %[[a]] : vector<16xf32>
@@ -229,8 +229,8 @@ func @mul_s_alt(%argA: !SparseTensor, %argB: !SparseTensor, %argx: tensor<1024xf
// CHECK-VEC2-DAG: %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC2-DAG: %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC2: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
-// CHECK-VEC2: %[[la:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
-// CHECK-VEC2: %[[lb:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2: %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC2: scf.yield %[[a]] : vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index ab46a7863a55..bf2c0770bc3e 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -4,8 +4,7 @@
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = constant 0 : index
-// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0 : index
@@ -19,8 +18,7 @@ func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = constant 0 : index
-// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0 : index
@@ -46,8 +44,7 @@ func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<1
// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = constant 8 : index
-// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c8 = constant 8 : index
@@ -61,7 +58,7 @@ func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = constant 0 : index
-// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
+// CHECK-NEXT: vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return
func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = constant 0 : index
@@ -144,8 +141,7 @@ func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
// CHECK-DAG: %[[C:.*]] = constant 0 : index
-// CHECK-DAG: %[[D:.*]] = constant 0.000000e+00 : f32
-// CHECK-NEXT: %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return %[[T]] : vector<16xf32>
func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = constant 0 : index
@@ -171,7 +167,7 @@ func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf3
// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
// CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
// CHECK-NEXT: %[[C:.*]] = constant 0 : index
-// CHECK-NEXT: vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
+// CHECK-NEXT: vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
// CHECK-NEXT: return
func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
%c0 = constant 0 : index
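These folders are registered as canonicalization patterns, so (assuming no
additional passes are required by the test setup) one way to exercise them is
to run the updated test input through the standard canonicalizer:

  mlir-opt -canonicalize mlir/test/Dialect/Vector/vector-mem-transforms.mlir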