[Mlir-commits] [mlir] a40a08e - [mlir][Linalg] Teach constant -> generic op fusion to handle scalar constants.

llvmlistbot at llvm.org
Wed Sep 22 13:41:56 PDT 2021


Author: MaheshRavishankar
Date: 2021-09-22T13:41:47-07:00
New Revision: a40a08ed988f4da0183622ff62bc151712bd9de0

URL: https://github.com/llvm/llvm-project/commit/a40a08ed988f4da0183622ff62bc151712bd9de0
DIFF: https://github.com/llvm/llvm-project/commit/a40a08ed988f4da0183622ff62bc151712bd9de0.diff

LOG: [mlir][Linalg] Teach constant -> generic op fusion to handle scalar constants.

The current pattern that folds a constant into a consuming generic op
only handles splat constants. The same logic holds for scalar
constants, so teach the pattern to handle those cases as well.

Differential Revision: https://reviews.llvm.org/D109982
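
For illustration, a minimal before/after IR sketch of the new behavior
(hypothetical names and shapes, mirroring the @fuse_scalar_constant test
added below; the sketch itself is not part of the commit):

    %cst = constant 4.0 : f32
    // Before: the scalar is an extra generic-op input with a () indexing map.
    %0 = linalg.generic {...}
        ins(%arg0, %cst : tensor<?x?xf32>, f32)
        outs(%init : tensor<?x?xf32>) {...}
    // After fusion: the scalar input is dropped and the region uses the
    // scalar constant directly, just as for a splat tensor constant.
    %0 = linalg.generic {...}
        ins(%arg0 : tensor<?x?xf32>)
        outs(%init : tensor<?x?xf32>) {...}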

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 2b29f58428dce..befff7b4795e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1162,11 +1162,12 @@ struct FoldReshapeWithGenericOpByExpansion
   ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
-/// Pattern to fold a generic op with a splat constant.
-class FoldSplatConstants : public OpRewritePattern<GenericOp> {
+/// Pattern to fold a generic op with a splat or scalar constant. Does not
+/// handle cases where the constant is not single-valued.
+class FoldConstants : public OpRewritePattern<GenericOp> {
 public:
-  FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
-                     PatternBenefit benefit = 1)
+  FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+                PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
@@ -1175,10 +1176,37 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
       return failure();
     for (OpOperand *opOperand : genericOp.getInputOperands()) {
       Operation *def = opOperand->get().getDefiningOp();
-      DenseElementsAttr constantAttr;
-      if (!def ||
-          !matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
-          !constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
+      Attribute constantAttr;
+      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
+        {
+          DenseElementsAttr splatAttr;
+          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
+              splatAttr.isSplat() &&
+              splatAttr.getType().getElementType().isIntOrFloat()) {
+            constantAttr = splatAttr.getSplatValue();
+            return true;
+          }
+        }
+        {
+          IntegerAttr intAttr;
+          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
+            constantAttr = intAttr;
+            return true;
+          }
+        }
+        {
+          FloatAttr floatAttr;
+          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
+            constantAttr = floatAttr;
+            return true;
+          }
+        }
+        return false;
+      };
+
+      auto resultValue = opOperand->get().dyn_cast<OpResult>();
+      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
+          !controlFn(resultValue, *opOperand))
         continue;
 
       // The operands and the indexing_maps of the fused operation are the same as
@@ -1205,8 +1233,7 @@ class FoldSplatConstants : public OpRewritePattern<GenericOp> {
 
       // Create a constant scalar value from the splat constant.
       Value scalarConstant = rewriter.create<ConstantOp>(
-          def->getLoc(), constantAttr.getSplatValue(),
-          constantAttr.getType().getElementType());
+          def->getLoc(), constantAttr, constantAttr.getType());
 
       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
       auto fusedOp = rewriter.create<GenericOp>(
@@ -1411,7 +1438,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
   auto *context = patterns.getContext();
-  patterns.add<FuseElementwiseOps, FoldSplatConstants>(
+  patterns.add<FuseElementwiseOps, FoldConstants>(
       context, options.controlElementwiseOpsFusionFn);
   patterns.add<RemoveOutsDependency>(context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 976d6eede80fe..985335a5f9523 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -740,3 +740,37 @@ func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
 //  CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
 //      CHECK:   %[[RESULT:.+]] = linalg.generic
 // CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xf32>)
+
+// -----
+
+func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi32>) {
+  %cst = constant 4.0 : f32
+  %c42 = constant 42 : i32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %1 = linalg.init_tensor [%d0, %d1] : tensor<?x?xi32>
+  %2:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> ()>,
+                       affine_map<(d0, d1) -> ()>,
+                       affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
+      outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
+      ^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
+        %3 = addf %arg1, %arg2 : f32
+        linalg.yield %3, %arg3 : f32, i32
+      } -> (tensor<?x?xf32>, tensor<?x?xi32>)
+  return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
+}
+// CHECK-LABEL: func @fuse_scalar_constant
+//   CHECK-DAG:   %[[CST:.+]] = constant 4.000000e+00 : f32
+//   CHECK-DAG:   %[[C42:.+]] = constant 42 : i32
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       ins(%{{.+}} : tensor<?x?xf32>)
+//       CHECK:     %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32
+//       CHECK:     linalg.yield %[[YIELD]], %[[C42]] : f32, i32
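
For reference, a hedged sketch of what the fused result of the test above
could look like, expanded from the CHECK lines (SSA names and the exact
printed form are hypothetical):

    %cst = constant 4.000000e+00 : f32
    %c42 = constant 42 : i32
    // The two constant inputs are gone; all remaining operands use the
    // identity indexing map, and the region reads the scalars directly.
    %fused:2 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0, d1)>,
                         affine_map<(d0, d1) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel"]}
        ins(%arg0 : tensor<?x?xf32>)
        outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
      ^bb0(%a : f32, %out0 : f32, %out1 : i32):
        %s = addf %a, %cst : f32
        linalg.yield %s, %c42 : f32, i32
      } -> (tensor<?x?xf32>, tensor<?x?xi32>)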