[Mlir-commits] [mlir] [mlir][vector] Implement speculation for vector.transferx ops (PR #111533)

Kunwar Grover llvmlistbot at llvm.org
Tue Oct 8 08:27:06 PDT 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/111533

>From 164e5a7575b7a0fa610168e5aba6b6a6cb9576b6 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 8 Oct 2024 15:01:35 +0100
Subject: [PATCH 1/2] [mlir][vector] Implement speculation for vector.transferx
 ops

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  2 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 12 +++
 .../loop-invariant-code-motion.mlir           | 82 +++++++++++++++++++
 3 files changed, 96 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 45fd1c6e3f9384..b0de7c11b9d436 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1240,6 +1240,7 @@ def Vector_TransferReadOp :
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<ConditionallySpeculatable>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
     ]>,
@@ -1487,6 +1488,7 @@ def Vector_TransferWriteOp :
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<ConditionallySpeculatable>,
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index dc92bea09dc160..2dfb59edcf3d2f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4245,6 +4245,12 @@ void TransferReadOp::getEffects(
                          SideEffects::DefaultResource::get());
 }
 
+Speculation::Speculatability TransferReadOp::getSpeculatability() {
+  if (!hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
+  return Speculation::Speculatable;
+}
+
 namespace {
 /// Store to load forwarding for transfer operations with permuation maps.
 /// Even if the permutation maps are different we can still propagate the store
@@ -4627,6 +4633,12 @@ void TransferWriteOp::getEffects(
                          SideEffects::DefaultResource::get());
 }
 
+Speculation::Speculatability TransferWriteOp::getSpeculatability() {
+  if (!hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
+  return Speculation::Speculatable;
+}
+
 namespace {
 /// Remove dead transfer write from the SSA chain so that it an be eliminated by
 /// DCE
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 57f4ece9c9f2a4..be76cb8a6156f4 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -1209,3 +1209,85 @@ func.func @hoist_linalg_ops_div_by_zero(%a : tensor<128x128xi32>,
 
   func.return %final : tensor<?x128xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_ops
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK-NOT: vector.transfer_read
+// CHECK: arith.addf
+// CHECK: scf.yield
+func.func @hoist_vector_transfer_ops(
+                            %a : tensor<128x128xf32>, 
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %ida : index,
+                            %idb : index) -> vector<4x4xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  %cst = arith.constant dense<0.0> : vector<4x4xf32>
+  %final = 
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
+    %read = vector.transfer_read %a[%ida, %idb], %cst_0 : tensor<128x128xf32>, vector<4x4xf32>
+    %out = arith.addf %read, %acc : vector<4x4xf32>
+    scf.yield %out : vector<4x4xf32>
+  }
+  func.return %final : vector<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_ops
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: arith.addf
+// CHECK: scf.yield
+func.func @hoist_vector_transfer_ops(
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %ida : index,
+                            %idb : index) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.0 : f32
+  %cst = arith.constant dense<0.0> : vector<4x4xf32>
+  %empty = tensor.empty() : tensor<4x4xf32>
+  %final = 
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
+    %a = vector.transfer_write %cst, %empty[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
+    %read = vector.transfer_read %a[%c0, %c0], %cst_0 : tensor<4x4xf32>, vector<4x4xf32>
+    %out = arith.addf %read, %acc : vector<4x4xf32>
+    scf.yield %out : vector<4x4xf32>
+  }
+  func.return %final : vector<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_ops_memref
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: vector.transfer_read
+// CHECK: arith.addf
+// CHECK: scf.yield
+func.func @hoist_vector_transfer_ops_memref(
+                            %a : memref<128x128xf32>, 
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %ida : index,
+                            %idb : index) -> vector<4x4xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  %cst = arith.constant dense<0.0> : vector<4x4xf32>
+  %final = 
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
+    %read = vector.transfer_read %a[%ida, %idb], %cst_0 : memref<128x128xf32>, vector<4x4xf32>
+    %out = arith.addf %read, %acc : vector<4x4xf32>
+    scf.yield %out : vector<4x4xf32>
+  }
+  func.return %final : vector<4x4xf32>
+}

>From ddd1a9d1d5de939737709cddc2aaa5d40d88c53f Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Tue, 8 Oct 2024 16:26:42 +0100
Subject: [PATCH 2/2] remove negation

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2dfb59edcf3d2f..1718530b4aa167 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4246,9 +4246,9 @@ void TransferReadOp::getEffects(
 }
 
 Speculation::Speculatability TransferReadOp::getSpeculatability() {
-  if (!hasPureTensorSemantics())
-    return Speculation::NotSpeculatable;
-  return Speculation::Speculatable;
+  if (hasPureTensorSemantics())
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
 }
 
 namespace {
@@ -4634,9 +4634,9 @@ void TransferWriteOp::getEffects(
 }
 
 Speculation::Speculatability TransferWriteOp::getSpeculatability() {
-  if (!hasPureTensorSemantics())
-    return Speculation::NotSpeculatable;
-  return Speculation::Speculatable;
+  if (hasPureTensorSemantics())
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
 }
 
 namespace {



More information about the Mlir-commits mailing list