//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of Comprehensive Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of
// analysis and bufferization: Functions that are called are processed before
// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
//
// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
//   tensor return value (if any).
// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
// Bufferization currently fails if other tensors (in particular tensors that
// bufferize out-of-place and result in a new buffer allocation) are returned.
// In the future, such allocations could be hoisted to the caller.
//
// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
// ```
// func @foo() -> tensor<?xf32> {
//   %0 = linalg.init_tensor [...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
//   be written to in-place.
// * If a tensor operand of a CallOp is written by the callee and read after
//   the CallOp, the operand of the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...]  : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace tensor;
using namespace comprehensive_bufferize;

namespace {
/// The state of analysis of a FuncOp.
enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };

/// Extra bufferization state that is required for bufferization of function
/// boundaries.
struct ModuleBufferizationState : public DialectBufferizationState {
  /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
  /// indices.
  DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;

  /// A set of all read BlockArguments of FuncOps.
  // Note: BlockArgument knows about its owner, so we do not need to store
  // FuncOps here.
  DenseSet<BlockArgument> readBbArgs;

  /// A set of all written-to BlockArguments of FuncOps.
  DenseSet<BlockArgument> writtenBbArgs;

  /// Keep track of which FuncOps are fully analyzed or currently being
  /// analyzed.
  DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
};
} // namespace

/// Get ModuleBufferizationState.
static const ModuleBufferizationState &
getModuleBufferizationState(const BufferizationState &state) {
  Optional<const ModuleBufferizationState *> maybeState =
      state.getDialectState<ModuleBufferizationState>(
          StandardOpsDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "ModuleBufferizationState does not exist");
  return **maybeState;
}

/// Get or create ModuleBufferizationState.
static ModuleBufferizationState &
getModuleBufferizationState(BufferizationState &state) {
  return state.getOrCreateDialectState<ModuleBufferizationState>(
      StandardOpsDialect::getDialectNamespace());
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState
getFuncOpAnalysisState(const BufferizationState &state, FuncOp funcOp) {
  const ModuleBufferizationState &moduleState =
      getModuleBufferizationState(state);
  auto it = moduleState.analyzedFuncOps.find(funcOp);
  if (it == moduleState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  ReturnOp returnOp;
  for (Block &b : funcOp.body()) {
    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

namespace {
/// Store function BlockArguments that are equivalent to a returned value in
/// ModuleBufferizationState.
struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
  /// Annotate IR with the results of the analysis. For testing purposes only.
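  ///
  /// For example (sketch), if return operand #0 is equivalent to bbArg #1, the
  /// ReturnOp is annotated with one entry per operand (-1 when there is no
  /// equivalent bbArg), roughly:
  ///   return %0 {__equivalent_func_args__ = [1]} : tensor<?xf32>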
  static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
    const char *kEquivalentArgsAttr = "__equivalent_func_args__";
    Operation *op = returnVal.getOwner();

    SmallVector<int64_t> equivBbArgs;
    if (op->hasAttr(kEquivalentArgsAttr)) {
      auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
      equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
        return a.cast<IntegerAttr>().getValue().getSExtValue();
      }));
    } else {
      equivBbArgs.append(op->getNumOperands(), -1);
    }
    equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

    OpBuilder b(op->getContext());
    op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
  }

  LogicalResult run(Operation *op, BufferizationState &state,
                    BufferizationAliasInfo &aliasInfo,
                    SmallVector<Operation *> &newOps) override {
    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);

    // Support only a single return-terminated block in the function.
    auto funcOp = cast<FuncOp>(op);
    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");

    for (OpOperand &returnVal : returnOp->getOpOperands())
      if (returnVal.get().getType().isa<RankedTensorType>())
        for (BlockArgument bbArg : funcOp.getArguments())
          if (bbArg.getType().isa<RankedTensorType>())
            if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
                                                        bbArg)) {
              moduleState
                  .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
                  bbArg.getArgNumber();
              if (state.getOptions().testAnalysisOnly)
                annotateReturnOp(returnVal, bbArg);
            }

    return success();
  }
};

/// Return true if the buffer of the given tensor value is written to. Must not
/// be called for values inside functions that have not been analyzed yet.
/// (Post-analysis steps do not have to be run yet, i.e., "in progress" is also
/// OK.)
static bool isValueWritten(Value value, const BufferizationState &state,
                           const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  FuncOp funcOp;
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    Operation *owner = bbArg.getOwner()->getParentOp();
    funcOp = isa<FuncOp>(owner) ? cast<FuncOp>(owner)
                                : owner->getParentOfType<FuncOp>();
  } else {
    funcOp = value.getDefiningOp()->getParentOfType<FuncOp>();
  }
  assert(getFuncOpAnalysisState(state, funcOp) !=
             FuncOpAnalysisState::NotAnalyzed &&
         "FuncOp must be fully analyzed or analysis in progress");
#endif // NDEBUG

  bool isWritten = false;
  aliasInfo.applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

/// Determine which FuncOp bbArgs are read and which are written. If this
/// PostAnalysisStep is run on a function with unknown ops, it will
/// conservatively assume that such ops bufferize to a read + write.
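///
/// Example (sketch, mirroring the examples in the header comment): %t0 is
/// only read, so it ends up in `readBbArgs`; %t1 is written in-place, so it
/// ends up in `writtenBbArgs`.
/// ```
/// func @f(%t0 : tensor<?xf32>, %t1 : tensor<?xf32>) -> (f32, tensor<?xf32>) {
///   %f = ... : f32
///   %0 = tensor.extract %t0[...] : tensor<?xf32>
///   %1 = tensor.insert %f into %t1[...] : tensor<?xf32>
///   return %0, %1 : f32, tensor<?xf32>
/// }
/// ```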
struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
  LogicalResult run(Operation *op, BufferizationState &state,
                    BufferizationAliasInfo &aliasInfo,
                    SmallVector<Operation *> &newOps) override {
    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
    auto funcOp = cast<FuncOp>(op);

    // If the function has no body, conservatively assume that all args are
    // read + written.
    if (funcOp.getBody().empty()) {
      for (BlockArgument bbArg : funcOp.getArguments()) {
        moduleState.readBbArgs.insert(bbArg);
        moduleState.writtenBbArgs.insert(bbArg);
      }

      return success();
    }

    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (!bbArg.getType().isa<TensorType>())
        continue;
      if (state.isValueRead(bbArg))
        moduleState.readBbArgs.insert(bbArg);
      if (isValueWritten(bbArg, state, aliasInfo))
        moduleState.writtenBbArgs.insert(bbArg);
    }

    return success();
  }
};
} // namespace

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// If `value` is the result of (a chain of) memref::CastOps, return the
/// underlying source value. Otherwise, return `value` directly.
static Value getNonCastedValue(Value value) {
  while (auto castOp = value.getDefiningOp<memref::CastOp>())
    value = castOp.source();
  return value;
}

/// Remove the attributes that trigger inplace bufferization and buffer layout
/// assignment on a FuncOp argument `bbArg`.
static void removeBufferizationFuncArguments(BlockArgument bbArg) {
  auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizableOpInterface::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizableOpInterface::kInplaceableAttrName);
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
/// tensor is replaced by the corresponding buffer type.
/// In order for all the callers to agree, this *must* bufferize to the most
/// dynamic buffer type supported.
/// A later pass across all CallOps in the module can decide whether to simplify
/// the types according to some cost model.
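///
/// Example (sketch; `#map` stands for the fully dynamic strided layout
/// produced for ranked tensors):
///   (tensor<4xf32>, tensor<*xf32>, f32) -> tensor<?xf32>
/// becomes
///   (memref<4xf32, #map>, memref<*xf32>, f32) -> memref<?xf32, #map>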
static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
                                              TypeRange argumentTypes,
                                              TypeRange resultTypes) {
  auto rewrite = [](Type t) -> Type {
    // TODO: non-zero address space.
    // TODO: layout information if relevant.
    if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
      return getDynamicMemRefType(rankedTensorType);
    if (auto tensorType = t.dyn_cast<TensorType>())
      return getUnrankedMemRefType(tensorType.getElementType());
    return t;
  };
  auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
  auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
  return FunctionType::get(ctx, argTypes, retTypes);
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
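//
// Example: If the analysis determined that @callee's return value #0 is
// equivalent to its bbArg #1, then given `%r = call @callee(%a, %b)`, `%r`
// and `%b` are unioned into the same equivalence class.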
static void equivalenceAnalysis(FuncOp funcOp,
                                BufferizationAliasInfo &aliasInfo,
                                ModuleBufferizationState &moduleState) {
  funcOp->walk([&](CallOp callOp) {
    FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called FuncOp");

    // No equivalence info available for the called function.
    if (!moduleState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : moduleState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}

/// Rewrite the `funcOp` arguments, return values and terminator into
/// buffer form (using the canonical memref layout for now), according to the
/// inPlace-bufferizable information of the function arguments.
///
/// This relies on a buffer equivalence analysis of each return operand. When a
/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be
/// dropped from the return values and becomes inplaceable at all callers. This
/// assumes all CallOps perform the necessary work to clone operands so as to
/// make them inplaceable. Reliance on this logic will need to be relaxed in the
/// future.
///
/// Note: Returning a memref currently fails bufferization. If such memrefs
/// originate from an op with an Alloc effect, they could be hoisted in the
/// future.
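///
/// Example (sketch): The return value below is equivalent to bbArg #0, so it
/// is dropped. After the body has been bufferized, the boundary rewrite yields
/// roughly (`#map` denoting the fully dynamic strided layout):
/// ```
/// func @f(%t : tensor<?xf32>) -> tensor<?xf32> {
///   %f = ... : f32
///   %0 = tensor.insert %f into %t[...] : tensor<?xf32>
///   return %0 : tensor<?xf32>
/// }
///
/// // Becomes:
/// func @f(%m : memref<?xf32, #map>) {
///   %f = ... : f32
///   memref.store %f, %m[...] : memref<?xf32, #map>
///   return
/// }
/// ```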
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
                                             RewriterBase &rewriter,
                                             BufferizationState &state) {
  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);

  // If the function has no tensor arguments and no tensor results, there is
  // nothing to do.
  if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
      !llvm::any_of(funcOp.getType().getResults(), isaTensor))
    return success();

  // Get the bufferized FunctionType for funcOp or construct it if not yet
  // available.
  // TODO: Atm we have 3 cases:
  // 1. if a function is called from within the Module, it must have bufferized
  //    to inplaceable tensor results.
  // 2. if it is bodiless, it must have bufferized and is not allowed to have
  //    result tensors.
  // 3. if it is not called internally, it still must bufferize to inplaceable
  //    tensor results and we construct it now (e.g. top-level function called
  //    externally).
  // -> Figure out a better layering.

  // Corner case: Bodiless FuncOp
  // ============================
  // The body of such functions is assumed opaque and we can't know the
  // bufferization contract they want to enforce atm.
  // As a consequence, only support functions that don't return any tensor atm.
  if (funcOp.getBody().empty()) {
    if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
      return funcOp->emitError() << "cannot bufferize bodiless function that "
                                 << "returns a tensor";
    FunctionType bufferizedFuncType = getBufferizedFunctionType(
        funcOp.getContext(), funcOp.getType().getInputs(), TypeRange{});
    funcOp.setType(bufferizedFuncType);
    return success();
  }

  // Support only a single return-terminated block in the function.
  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  // 1. For each FuncOp result, keep track of which inplace argument it reuses.
  SmallVector<Value> returnValues;
  for (OpOperand &returnOperand : returnOp->getOpOperands()) {
    Value returnVal = returnOperand.get();

    // If the return operand is not a ranked tensor type, just forward it.
    if (!returnVal.getType().isa<RankedTensorType>()) {
      returnValues.push_back(returnVal);
      continue;
    }

    // If return operand is equivalent to some bbArg, no need to return it.
    if (moduleState.equivalentFuncArgs[funcOp].count(
            returnOperand.getOperandNumber()))
      continue;

    // Otherwise, return the underlying buffer of the return operand, with any
    // enclosing memref::CastOp stripped.
    returnValues.push_back(
        getNonCastedValue(*state.getBuffer(rewriter, returnOperand)));
  }

  // 2. Rewrite the terminator without the inPlace bufferizable values.
  ValueRange retValues{returnValues};
  FunctionType bufferizedFuncType = getBufferizedFunctionType(
      funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes());
  OpBuilder b(returnOp);
  b.create<ReturnOp>(returnOp.getLoc(), returnValues);
  returnOp->erase();

  // 3. Rewrite the bbArgs.
  // Iterate on the original `numArgs` and replace them in order.
  // This guarantees the argument order still matches after the rewrite.
  Block &frontBlock = funcOp.body().front();
  unsigned numArgs = frontBlock.getNumArguments();
  for (unsigned idx = 0; idx < numArgs; ++idx) {
    auto bbArg = frontBlock.getArgument(0);
    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
    // Non-tensor types are just forwarded.
    if (!tensorType) {
      frontBlock.addArgument(bbArg.getType());
      bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
      frontBlock.eraseArgument(0);
      continue;
    }

    // Get the buffer type from the bufferized function type.
    Type memrefType = bufferizedFuncType.getInput(idx);
    Value memref = frontBlock.addArgument(memrefType);
    OpBuilder b(funcOp->getContext());
    b.setInsertionPointToStart(&frontBlock);
    // Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
    for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
      if (auto toMemrefOp =
              dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
        assert(memref::CastOp::areCastCompatible(
                   memref.getType(), toMemrefOp.memref().getType()) &&
               "bufferizeFuncOpBoundary: cast incompatible");
        auto castOp = b.create<memref::CastOp>(
            funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
        toMemrefOp.memref().replaceAllUsesWith(castOp);
      }
    }
    // Replace all remaining uses by a to_tensor.
    if (!bbArg.use_empty()) {
      auto toTensorOp =
          b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
      bbArg.replaceAllUsesWith(toTensorOp);
    }
    frontBlock.eraseArgument(0);
    // TODO: add support to erase aliasInfo entries if deemed necessary.
  }

  // 4. Rewrite the FuncOp type to buffer form.
  funcOp.setType(bufferizedFuncType);

  return success();
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
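///
/// Example: If `@a` calls `@b` and `@b` calls `@c`, the computed order is
/// [@c, @b, @a]: `@c` is analyzed and bufferized first, `@a` last.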
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<FuncOp> &orderedFuncOps,
                         DenseMap<FuncOp, DenseSet<Operation *>> &callerMap) {
  // For each FuncOp, the set of functions that call it (i.e. the set of
  // FuncOps containing CallOpInterface ops that resolve to it).
  DenseMap<FuncOp, DenseSet<FuncOp>> calledBy;
  // For each FuncOp, the number of distinct functions that it calls.
  DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
    if (!funcOp.body().empty()) {
      ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
      // Only support CallOp for now.
      if (!isa<CallOp>(callOp.getOperation()))
        return callOp->emitError() << "expected a CallOp";
      FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called FuncOp");
      auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
      it.first->getSecond().insert(callOp);
      if (calledBy[calledFunction].count(funcOp) == 0) {
        calledBy[calledFunction].insert(funcOp);
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();
  // Iteratively remove function operations that do not call any of the
  // functions remaining in `numberCallOpsContainedInFuncOp` and append them
  // to `orderedFuncOps`.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies.");
    orderedFuncOps.push_back(it->getFirst());
    for (auto caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}

static void
foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
              FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
  auto itCallers = callerMap.find(callee);
  if (itCallers == callerMap.end())
    return;
  for (Operation *caller : itCallers->second)
    doit(caller);
}

/// Postprocess the linalg.buffer_layout annotation across function boundaries.
/// This is a purely mechanical process that may later become part of a
/// separate pass with its own layout assignment heuristic.
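///
/// Example (sketch): Given an argument annotated with an identity layout
/// ```
/// func @f(%m : memref<?xf32, #strided> {linalg.buffer_layout =
///                                         affine_map<(d0) -> (d0)>})
/// ```
/// the argument type is rewritten to `memref<?xf32>`, a memref.cast back to
/// the old type is inserted at the top of the body, and every caller casts its
/// operand to the new type before the call.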
static void layoutPostProcessing(ModuleOp moduleOp) {
  SmallVector<FuncOp> orderedFuncOps;
  DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
  (void)res;
  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");

  for (FuncOp funcOp : orderedFuncOps) {
    DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
      operandsPerCaller.try_emplace(caller, SmallVector<Value>());
    });

    SmallVector<Type> argumentTypes;
    // Iterate over each function argument and check if it was marked with a
    // desired layout.
    for (const auto &it : llvm::enumerate(funcOp.getType().getInputs())) {
      int argNumber = it.index();
      Type inputType = it.value();
      auto memrefType = inputType.dyn_cast<MemRefType>();
      auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
          argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
      AffineMap desiredLayoutMap =
          layoutAttr ? layoutAttr.getValue() : AffineMap();
      AffineMap currentLayoutMap =
          memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
      if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
        argumentTypes.push_back(inputType);
        foreachCaller(callerMap, funcOp, [&](Operation *caller) {
          operandsPerCaller.find(caller)->getSecond().push_back(
              caller->getOperand(argNumber));
        });
        continue;
      }

      // Compute the buffer type with desired layout and add to input argument
      // types.
      MemRefType desiredMemrefType = MemRefType::get(
          memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
      argumentTypes.push_back(desiredMemrefType);

      // If funcOp's body is not empty, change the bbArg type and propagate.
      if (!funcOp.body().empty()) {
        BlockArgument bbArg = funcOp.getArgument(argNumber);
        bbArg.setType(desiredMemrefType);
        OpBuilder b(bbArg.getContext());
        b.setInsertionPointToStart(bbArg.getOwner());
        assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
               "layoutPostProcessing: cast incompatible");
        // Cast back to the original memrefType and let it canonicalize.
        Value cast =
            b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
        bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
      }

      // Cast to desired buffer type on all callers to `funcOp`.
      // TODO: on the callee side, this may even have to trigger a copy to
      // change the layout. For now let the memref::CastOp fail to verify in
      // such cases.
      auto castArg = [&](Operation *caller) {
        OpBuilder b(caller);
        assert(
            memref::CastOp::areCastCompatible(
                caller->getOperand(argNumber).getType(), desiredMemrefType) &&
            "layoutPostProcessing.2: cast incompatible");
        Value newOperand = b.create<memref::CastOp>(
            funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
        operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
      };
      foreachCaller(callerMap, funcOp, castArg);
    }

    // Set operands with cast buffer on all callers to `funcOp`.
    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
      caller->setOperands(operandsPerCaller.lookup(caller));
    });

    // Finally set the funcOp type to update the arguments.
    auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
                                         funcOp.getType().getResults());
    funcOp.setType(newFuncType);
  }
}

namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t>
getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
                        int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    return moduleState.readBbArgs.contains(
        funcOp.getArgument(opOperand.getOperandNumber()));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return moduleState.writtenBbArgs.contains(
        funcOp.getArgument(opOperand.getOperandNumber()));
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);

    for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
         ++resultIdx)
      if (Optional<int64_t> maybeArgNumber =
              getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
        if (*maybeArgNumber == opOperand.getOperandNumber())
          return callOp->getOpResult(resultIdx);

    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
    // supported and will fail bufferization. (Even with allow-return-memref,
    // it will fail when the function is called.)
    return OpResult();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);

    // TODO: We should be looking for aliasing block arguments here. The
    // current condition is actually stronger than necessary. Once we check
    // for aliasing block arguments, there may be multiple.
    if (Optional<int64_t> maybeArgNumber = getEquivalentFuncArgIdx(
            funcOp, moduleState, opResult.getResultNumber()))
      return {&op->getOpOperand(*maybeArgNumber)};

    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
    // supported and will fail bufferization.
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationAliasInfo &aliasInfo,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  /// In a first approximation, all the function arguments of a FuncOp are
  /// marked inplaceable. For now, it is the responsibility of the `callOp`
  /// bufferization to allow FuncOps that are inplaceable to write inPlace.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
           "expected CallOp to a FuncOp");
    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the results
    // of the bufferized CallOp, unless a tensor result folds onto an operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensors return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.
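    //
    // Example (sketch): Assuming the callee's return value is equivalent to
    // its bbArg, a call such as
    //   %r = call @f(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
    // bufferizes to roughly
    //   %m = bufferization.to_memref %t : memref<?xf32, #map>
    //   call @f(%m) : (memref<?xf32, #map>) -> ()
    // and all uses of %r are replaced with %m (rewrapped in a
    // bufferization.to_tensor where tensor uses remain).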

    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!isaTensor(returnType)) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (Optional<int64_t> bbArgIdx =
              getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
        // Return operands that are equivalent to some bbArg are not
        // returned.
        FailureOr<Value> bufferOrFailure =
            state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
        if (failed(bufferOrFailure))
          return failure();
        replacementValues[returnValIdx] = *bufferOrFailure;
        newOperands[*bbArgIdx] = *bufferOrFailure;
        continue;
      }

      return callOp->emitError(
          "call to FuncOp that returns non-equivalent tensors not supported");
    }

    // 2. Compute bufferized FunctionType.
    SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
    // Get the bufferized FunctionType for funcOp or construct it if not yet
    // available.
    FunctionType bufferizedFuncType = getBufferizedFunctionType(
        funcOp.getContext(), argumentTypes, resultTypes);

    // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = bufferizedFuncType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into a more dynamic memref than necessary.
      // If the buffer type does not match the type expected by the callee,
      // introduce an extra memref.cast that will either canonicalize away or
      // fail compilation until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 4. Create the new CallOp.
    Operation *newCallOp = rewriter.create<CallOp>(
        callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 5. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG
    return failure();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    return failure();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "linalg.inplaceable" overrides other writability decisions. This is
    // currently used for testing only.
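    //
    // Example: The first argument below is not writable, the second one is:
    //   func @f(%A : tensor<?xf32> {linalg.inplaceable = false},
    //           %B : tensor<?xf32> {linalg.inplaceable = true})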
    if (BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(),
            BufferizableOpInterface::kInplaceableAttrName))
      return inplaceAttr.getValue();

    // All function arguments are writable by default.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
};

} // namespace std_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

void mlir::linalg::comprehensive_bufferize::std_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
  registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
  registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
}

/// Set the attribute that triggers inplace bufferization on a FuncOp argument
/// `bbArg`.
static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
  auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.setArgAttr(bbArg.getArgNumber(),
                    BufferizableOpInterface::kInplaceableAttrName,
                    BoolAttr::get(bbArg.getContext(), inPlace));
}

/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void
annotateOpsWithBufferizationMarkers(FuncOp funcOp,
                                    const BufferizationState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
  for (BlockArgument bbArg : funcOp.getArguments())
    if (bbArg.getType().isa<TensorType>())
      setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
}

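/// Note: A typical driver (a sketch; the pass class and its option are
/// hypothetical) invokes this entry point as follows:
/// ```
/// void MyModuleBufferizePass::runOnOperation() {
///   auto options = std::make_unique<BufferizationOptions>();
///   options->allowReturnMemref = this->allowReturnMemref;
///   if (failed(runComprehensiveBufferize(getOperation(), std::move(options))))
///     signalPassFailure();
/// }
/// ```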
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
    ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
  IRRewriter rewriter(moduleOp.getContext());
  BufferizationState state(moduleOp, *options);
  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();

  if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
                                      moduleState.callerMap)))
    return failure();

  // Collect bbArg/return value information after the analysis.
  options->postAnalysisSteps.emplace_back(
      std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
  options->postAnalysisSteps.emplace_back(
      std::make_unique<FuncOpBbArgReadWriteAnalysis>());

  // Analyze ops.
  for (FuncOp funcOp : moduleState.orderedFuncOps) {
    // No body => no analysis.
    if (funcOp.body().empty())
      continue;

    // Now analyzing function.
    moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state)))
      return failure();

    // Gather equivalence info for CallOps.
    // TODO: Make this a post-analysis step.
    equivalenceAnalysis(funcOp, aliasInfo, moduleState);

    // Mark op as fully analyzed.
    moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;

    // Add annotations to function arguments.
    if (options->testAnalysisOnly)
      annotateOpsWithBufferizationMarkers(funcOp, state);
  }

  if (options->testAnalysisOnly)
    return success();

  // Bufferize function bodies.
  for (FuncOp funcOp : moduleState.orderedFuncOps) {
    // No body => nothing to bufferize.
    if (funcOp.body().empty())
      continue;

    if (failed(bufferizeOp(funcOp, state)))
      return failure();
  }

  // Bufferize function boundaries.
  for (FuncOp funcOp : moduleState.orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.
    if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
      return failure();

    if (!options->allowReturnMemref &&
        llvm::any_of(funcOp.getType().getResults(), [](Type t) {
          return t.isa<MemRefType, UnrankedMemRefType>();
        })) {
      funcOp->emitError("memref return type is unsupported");
      return failure();
    }
  }

  // Perform a post-processing pass of layout modification at function boundary
  // according to the kBufferLayoutAttrName.
  layoutPostProcessing(moduleOp);

  // Post-pass cleanup of inplaceable and buffer_layout attributes.
  moduleOp.walk([&](FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationFuncArguments(bbArg);
  });

  return success();
}