[Mlir-commits] [mlir] c9aa55d - [mlir][Linalg] Add speculation for LinalgStructuredOps (#108032)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 11 01:30:09 PDT 2024


Author: Kunwar Grover
Date: 2024-09-11T09:30:05+01:00
New Revision: c9aa55da62b2a9e482c1877897152fb3c47719d2

URL: https://github.com/llvm/llvm-project/commit/c9aa55da62b2a9e482c1877897152fb3c47719d2
DIFF: https://github.com/llvm/llvm-project/commit/c9aa55da62b2a9e482c1877897152fb3c47719d2.diff

LOG: [mlir][Linalg] Add speculation for LinalgStructuredOps (#108032)

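This patch implements ConditionallySpeculatable for Linalg structured ops.
Ops with pure tensor (value) semantics report RecursivelySpeculatable, so they
are speculatable whenever every op in their payload region is. Ops with memory
(buffer) semantics report NotSpeculatable. This allows LICM to hoist
tensor-semantics Linalg ops out of loops.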

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Transforms/loop-invariant-code-motion.mlir
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ac61117c3d6e36..31f29139247267 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -29,6 +29,7 @@ class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
   : Op<Linalg_Dialect, mnemonic, !listconcat([
        SingleBlockImplicitTerminator<"YieldOp">,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+       DeclareOpInterfaceMethods<ConditionallySpeculatable>,
        DestinationStyleOpInterface,
        LinalgStructuredInterface,
        ReifyRankedShapedTypeOpInterface], props)> {

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76df3ecf2d2bd4..630985d76a0ebf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallSet.h"
@@ -1202,6 +1203,20 @@ void GenericOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+static Speculation::Speculatability
+getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
+  // Shared by every structured op's getSpeculatability(). Ops with pure tensor
+  // (value) semantics may be speculated; ops with memory semantics may not.
+  if (!linalgOp.hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
+  // The op is speculatable only if every op in its body region is, too.
+  return Speculation::RecursivelySpeculatable;
+}
+
+Speculation::Speculatability GenericOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
@@ -1553,6 +1568,10 @@ void MapOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+Speculation::Speculatability MapOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}  // Same speculation rule as GenericOp; see getGenericSpeculatabilityImpl.
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//
@@ -1621,6 +1640,10 @@ void ReduceOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+Speculation::Speculatability ReduceOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}  // Same speculation rule as GenericOp; see getGenericSpeculatabilityImpl.
+
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                           NamedAttrList &attributes,
                                           StringRef attributeName) {
@@ -1906,6 +1929,10 @@ void TransposeOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+Speculation::Speculatability TransposeOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}  // Same speculation rule as GenericOp; see getGenericSpeculatabilityImpl.
+
 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &result) {
   // Only the tensor type is supported.
@@ -2134,6 +2161,10 @@ void BroadcastOp::getEffects(
   getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
+Speculation::Speculatability BroadcastOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}  // Same speculation rule as GenericOp; see getGenericSpeculatabilityImpl.
+
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);

diff  --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 47a49465e8a7cd..57f4ece9c9f2a4 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -1118,3 +1118,94 @@ func.func @hoist_from_scf_while(%arg0: i32, %arg1: i32) -> i32 {
   }
   return %0 : i32
 }
+
+// -----
+
+#trait = {
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @hoist_linalg_ops
+// CHECK: linalg.generic
+// CHECK: scf.for
+// CHECK-NOT: linalg.generic
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+func.func @hoist_linalg_ops(%a : tensor<128x128xf32>,
+                            %b : tensor<128x128xf32>,
+                            %c : tensor<128x128xf32>,
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %output : tensor<?x128xf32>) -> tensor<?x128xf32> {
+  %final =
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %output)
+                                            -> tensor<?x128xf32> {
+    %compute = linalg.generic #trait
+               ins(%a, %b : tensor<128x128xf32>, tensor<128x128xf32>)
+               outs(%c : tensor<128x128xf32>) {
+    ^bb0(%in : f32, %in2 : f32, %in3 : f32):
+      %mul = arith.mulf %in, %in2 : f32
+      %add = arith.addf %mul, %in3 : f32
+      linalg.yield %add : f32
+    } -> tensor<128x128xf32>
+
+    %newacc = tensor.insert_slice %compute into
+                                  %acc[%i, 0][128, 128][1, 1]
+                                  : tensor<128x128xf32> into tensor<?x128xf32>
+    scf.yield %newacc : tensor<?x128xf32>
+  }
+
+  func.return %final : tensor<?x128xf32>
+}
+
+// -----
+
+#trait = {
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,
+    affine_map<(m, n, k) -> (k, n)>,
+    affine_map<(m, n, k) -> (m, n)>
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @hoist_linalg_ops_div_by_zero
+// CHECK-NOT: linalg.generic
+// CHECK: scf.for
+// CHECK: linalg.generic
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+func.func @hoist_linalg_ops_div_by_zero(%a : tensor<128x128xi32>,
+                            %b : tensor<128x128xi32>,
+                            %c : tensor<128x128xi32>,
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %output : tensor<?x128xi32>) -> tensor<?x128xi32> {
+  // The payload's arith.divui may divide by zero: the op must stay in the loop.
+  %final =
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %output)
+                                            -> tensor<?x128xi32> {
+    %compute = linalg.generic #trait
+               ins(%a, %b : tensor<128x128xi32>, tensor<128x128xi32>)
+               outs(%c : tensor<128x128xi32>) {
+    ^bb0(%in : i32, %in2 : i32, %in3 : i32):
+      %div = arith.divui %in, %in2 : i32
+      %add = arith.addi %div, %in3 : i32
+      linalg.yield %add : i32
+    } -> tensor<128x128xi32>
+
+    %newacc = tensor.insert_slice %compute into
+                                  %acc[%i, 0][128, 128][1, 1]
+                                  : tensor<128x128xi32> into tensor<?x128xi32>
+    scf.yield %newacc : tensor<?x128xi32>
+  }
+
+  func.return %final : tensor<?x128xi32>
+}

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index a00f12661f7120..7d42c03469dc98 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -656,7 +656,7 @@ ArrayAttr {0}::getIndexingMaps() {{
 }
 )FMT";
 
-// Implementations of fold and getEffects.
+// Implementations of fold, getEffects and getSpeculatability.
 // Parameters:
 // {0}: Class name
 const char structuredOpFoldersFormat[] = R"FMT(
@@ -669,6 +669,9 @@ void {0}::getEffects(SmallVectorImpl<
       if (hasPureTensorSemantics()) return;
       getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
+Speculation::Speculatability {0}::getSpeculatability() {{
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
 )FMT";
 
 // Implementation of parse/print.


        


More information about the Mlir-commits mailing list