# Shape Inference

Shape inference as discussed here is considered a specific instance of type
inference for [ShapedType][ShapedType]. Type constraints are along (at least)
three axes: 1) elemental type, 2) rank (including static or dynamic), 3)
dimensions. While some operations have no compile-time fixed shape (e.g., the
output shape is dictated by data), we could still have some knowledge of
constraints/bounds in the system for that operation (e.g., the output of a
`tf.where` is at most the size of the input data). That is, there are
additional valuable constraints that could be captured even without full
knowledge of the shape.
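
For instance, the three axes can be queried independently on a `ShapedType`.
The following is a small illustrative sketch (the helper function is
hypothetical; the `ShapedType` accessors are the existing ones):

```c++
// Illustrative helper (not part of MLIR): inspect the three constraint axes of
// a shaped type.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

void inspectShapedType(ShapedType type) {
  // Axis 1: elemental type.
  Type elementType = type.getElementType();
  (void)elementType;

  // Axis 2: rank, which may be unknown (e.g., tensor<*xf32>).
  if (!type.hasRank())
    return;

  // Axis 3: individual dimensions, each either static or dynamic.
  for (int64_t dim : type.getShape()) {
    if (ShapedType::isDynamic(dim)) {
      // The extent is not known at compile time (the '?' in tensor<16x?xf32>);
      // a bound or symbolic constraint may still be known for it.
    }
  }
}
```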

Type inference is currently modelled executionally for operation creation using
the [`InferTypeOpInterface`][InferTypeOpInterface], while
`InferShapedTypeOpInterface` is used to implement shape and element type
inference. The return type can often be derived from the inferred return shape
and elemental type (queryable from `InferShapedTypeOpInterface`), and so type
inference for tensor types can be implemented with
`InferShapedTypeOpInterface`.
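
As a hedged sketch (the op `MyAddOp` is hypothetical, and the exact hook
signature has varied across MLIR versions), an op implementing
`InferTypeOpInterface` might infer its result type directly from its operands:

```c++
// Hypothetical op implementing InferTypeOpInterface: the result type is simply
// the type of the first operand (an equality constraint between operand and
// result). The signature follows the common form of `inferReturnTypes`.
LogicalResult MyAddOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands.empty())
    return emitOptionalError(location, "expected at least one operand");
  inferredReturnTypes.push_back(operands.front().getType());
  return success();
}
```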

[TOC]

## Shape functions

The C++ interfaces are the base mechanism whereby shape inference is queried
and executed, but they are not the intended way to specify shape constraints in
general.

Initially the shape inference will be declaratively specified using:

*   Constraints on the operands of an operation directly. For example,
    constraining the input type to have tensor/vector elements, or requiring
    that the elemental type be of a specific type (e.g., the output of
    computing the size of a value has elemental type `i1`) or class (e.g.,
    float-like).
*   Constraints across operands and results of an operation.

    - For example, specifying equality constraints on type/constituents of a
      type (shape and elemental type) between operands and results (e.g., the
      output type of an add is the same as that of the input operands).

NOTE: The C++ shape functions are an intermediate step until the shape dialect
is more full-fledged, at which point the C++ functions should become the
exceptional case.

## Testing

Shape inference is currently tested alongside type inference by
`TestReturnTypeDriver` in the test dialect. This driver performs two checks:

1.  Verification that the specified return types match the inferred types. This
    explicit check will be removed and made part of Op verification instead.
2.  Test the creation of Ops without specifying the return type explicitly in
    function `testCreateFunctions` by creating new binary Ops (Op classes
    specified in `TestReturnTypeDriver`) using 1) all operands to
    `testCreateFunctions` as both operands, and 2) combinations of the
    function's input operands (see the sketch below).
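
For check 2, a hedged sketch of what such a creation looks like (the op, block
and values are hypothetical; this assumes `MyAddOp` implements
`InferTypeOpInterface` and that ODS has generated a builder that omits the
result types):

```c++
// Create a binary op without listing its result type; the builder is expected
// to infer it via InferTypeOpInterface. `context`, `entryBlock`, `loc`, `lhs`
// and `rhs` are assumed to exist in the surrounding test driver.
OpBuilder builder(context);
builder.setInsertionPointToEnd(entryBlock);
auto added = builder.create<MyAddOp>(loc, lhs, rhs);
// `added` now has a result whose type was inferred rather than specified.
```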

## Shape dialect

This section details the shape type inference dialect (`shape`). The initial
focus will be on shape functions that can be used by both the runtime and the
compiler (for construction of ops, refinement of shapes, and reification of
dynamic allocations) for dialects including TF, TFLite, XLA and the tensor
compute dialect under discussion.

This will focus on the shape functions (e.g., determining the rank and
dimensions of the output shape). As in the shaped container type, shape will be
one of 3 components, the others being elemental type and attribute (which is
currently left open with the intention of supporting extensions such as layouts
or bounded shapes at a later point). This allows for decoupling of these (a
small illustrative sketch follows the list below):

*   Not all the information is needed for all analyses;
*   Not all shape functions need to provide all the information (e.g., one
    could define a base class function that only populates element type but
    composes with the others);
*   It allows reusing the constraints between, say, Tensor and Memref
    representations of an operation.
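
As a purely illustrative sketch (this is not the MLIR API), the result of a
shape function can be thought of as three independently populated components:

```c++
// Illustrative only: a shape function may fill in any subset of these fields.
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

struct InferredShapeComponents {
  // Shape: absent means unranked; individual extents may still be dynamic.
  std::optional<std::vector<int64_t>> shape;
  // Elemental type (e.g., "f32"); a base shape function might populate only
  // this field and compose with others for the rest.
  std::optional<std::string> elementType;
  // Attribute component, left open for extensions such as layouts or bounded
  // shapes.
  std::optional<std::string> attribute;
};
```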
    
An argument could be made that these are metadata function instead of shape
functions, with some considering shape and elemental types different and some considering them both as
part of shape. But `shape function` is IMHO descriptive and metadata can span
too large a range of potential uses/values.

### Requirements

The requirements for the shape inference functions are determined by the
requirements of shape inference, but we believe the requirements below still
allow freedom to consider different shape inference approaches; we therefore do
not impose a particular shape inference approach here.

#### Shape inference functions

*   **Expressiveness**: shape functions need to support programs where tensors
    have shapes that are not known statically (for example, `tensor<16x?xf32>`
    or `tensor<*xf32>`);
*   **Shape error detection**: Many operations will have constraints on their
    operands. If the constraints are not satisfied, or it cannot be determined
    statically whether they are satisfied, then a runtime check/assertion could
    be generated.

    *   This also aligns with the requirement that the shape function
        description should be usable by both the compiler and runtime.
    *   Shape errors should be easy to understand: at a minimum they should
        convey which constraint of the operation was violated. This also
        requires that shape function error messages be configurable by the
        author of the shape function (e.g., the author should be able to state
        the semantic constraint that was invalidated rather than the low-level
        check that failed).
    *   Static analysis may be used to eliminate runtime checks that are
        guaranteed to pass.
        *   Ideally all such checks would eventually be elided (see section
            [Inline shape inference checks](#inline)).
    *   Only errors which are guaranteed to occur at runtime are reported
        statically. If an error is only possible (rather than guaranteed), then
        a runtime assertion is used to fail and produce an error message naming
        the violated invariant.

*   Shape functions usable by compiler and runtime.

    *   This does not mean the exact same C++ function, but rather the
        description should be consumable by either.
    *   The shape function description should not be constrained by either the
        runtime's or the compiler's type system for types only used during
        analysis. That is, these two type systems differ and both should be
        supported, but the intersection of the two should not be required. As a
        particular example, if a compiler only wants to differentiate exact
        shapes vs dynamic shapes, then it need not consider a more generic
        shape lattice even though the shape description supports it.

*   Declarative (e.g., analyzable at compile time, possible to generate
    different versions for different use cases)

    *   This may not strictly be a requirement but rather a way to satisfy the
        previous one: a declarative specification could be reused by both the
        compiler and the runtime while avoiding the need to map to or from a
        third representation, given that these two systems have (and will
        continue to have) different types.

*   Shape inference functions are expressible at runtime

    *   A user can define a shape function for a new operation dynamically at
        runtime; this allows vendors to describe an operation and its shape
        function dynamically.

        This requirement is on the wishlist.

*   Doesn't require graph-wide shape information (i.e., requires only local
    information)

    *   Shape functions should be cheap to invoke on each kernel launch.
    *   A shape function should be determined by the op's arguments (operands,
        attributes and regions) only (e.g., it could be constructed and invoked
        with the same operands as the corresponding operation).
    *   Shape information that needs higher-level/graph information should use
        richer types (e.g., `TensorList<F32>`).
    *   The function should be invocable before/while constructing an op (e.g.,
        it can't rely on the op already being constructed).

*   Shape functions should be pure functions.

*   Should support functions whose type is only known dynamically (e.g.,
    `read_from_file` op)

    *   Without needing to invoke the op (e.g., reading a file once to
        determine the shape and then later actually consuming the output of
        the file).

*   The shape function operation dialect should be interoperable with
    non-shape-function dialect operations.

    *   There may be a common set of operations that satisfy most uses (e.g.,
        merge, equal_type, arithmetic expressions, slice, concat, pattern
        matching on attributes such as padding, etc.) that will be discovered
        and could cover a large percentage of the use cases. Among these there
        will be some which carry extra semantic info that could be used for
        symbolic constraints (e.g., checking equality of two dimensions
        resulting in setting an equality constraint) and higher-order
        interpretation for constraint solving.

        It is therefore beneficial (but not required) to reuse operations,
        especially since, for statically known shapes, arbitrary arithmetic
        computations could still be performed. This means that the computations
        performed statically may or may not be supported by an arbitrary
        solver, but would still be allowed.

*   The shape function should be expandable such that symbolic equality and
    upper bound constraints (say) could be represented and may be propagated by
    shape inference.

    *   E.g., the shape functions may contain more information that is only
        useful when used by shape inference.

*   Shape functions are allowed to fail and report an error. The error should
    report the location of the failing operation and, where possible, a
    user-actionable error message (see the sketch after this list).

    *   These failures could be inlined and become runtime failures with
        runtime values and error messages.
    *   Reporting errors should be optional. E.g., the same function may be
        used to query validity without reporting an error.
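
A hedged sketch of optional, user-actionable error reporting inside a
hypothetical shape function helper; `emitOptionalError` is an existing MLIR
utility that emits a diagnostic only when a location is provided:

```c++
// Hypothetical constraint check used by a shape function: when `location` is
// empty the function still signals failure but emits no diagnostic, so the
// same code can be used to silently query validity.
LogicalResult verifyOperandShapesMatch(Optional<Location> location,
                                       ShapedType lhs, ShapedType rhs) {
  if (lhs.hasRank() && rhs.hasRank() && lhs.getShape() != rhs.getShape()) {
    // Report the semantic constraint, not the low-level check that failed.
    return emitOptionalError(location, "operands must have matching shapes");
  }
  return success();
}
```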

#### Non-goals

1.  The shape dialect is an IR representation and not a programming language;
    *   While the functions should be readable, the dialect doesn't carry the
        conveniences of a programming language. Deciding how people write these
        things (e.g., a mini DSL, a C++ API that generates them, extracting
        them programmatically from `SetShapeFn` calls, etc.) is still TBD.
1.  Describe the shape inference approach that will use the shape functions;
    *   The goal is that the shape functions and the constraints one could
        obtain from them are general enough to be useful for various analyses.
        But whether we follow a very simple approach (e.g., only fully static
        information is used for the output shape, unranked for everything else)
        or a very advanced one (e.g., expression trees of symbolic constants)
        can be evaluated independently of this proposal and with concrete
        benefit analysis.
1.  Describe the approach whereby error messages will be generated;
    *   While the shape functions will be able to emit errors optionally, it
        will be possible to dictate when they emit an error. This enables
        deciding whether, and which, error to emit: there have been proposals
        in the literature that the iteration order of shape inference affects
        the quality of the error messages produced, and the shape functions do
        not mandate a particular order.
1.  Flow-sensitive shape functions;
    *   To enable scalable/cheap shape inference, the shape functions do not
        intend to provide flow-sensitive information. This facility could
        potentially be built as part of some higher-order analysis that reuses
        the shape functions and the constraints they produce.
1.  Making all functions that operate on static shapes usable for
    dynamic/unknown shapes;
    *   More involved computations can be performed with statically known
        shapes than what can be sensibly analyzed with unknown/symbolic
        variables.

### Discussion

#### Inline shape inference checks {#inline}

Shape functions should be lowerable to runtime checks for validity. E.g.,
verify as much as possible statically, but enable generating instructions to
compute the shape dynamically and/or fall back to runtime checks for attributes
not verifiable at compile time. The inserted checks should ideally only verify
what could not have been verified statically.
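
A self-contained sketch (plain C++, not an MLIR API) of this split for a single
dimension-equality constraint:

```c++
// Decide at compile time whether a dimension-equality constraint is provable,
// disprovable, or must be deferred to a runtime check. `kDynamic` is an
// illustrative sentinel for an unknown extent.
#include <cstdint>

enum class DimCheck { StaticallyTrue, StaticallyFalse, NeedsRuntimeCheck };

constexpr int64_t kDynamic = -1;

DimCheck checkDimsEqual(int64_t a, int64_t b) {
  if (a == kDynamic || b == kDynamic)
    return DimCheck::NeedsRuntimeCheck;      // emit an inlined runtime assertion
  return a == b ? DimCheck::StaticallyTrue   // elide the runtime check entirely
                : DimCheck::StaticallyFalse; // report the error at compile time
}
```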

These inlined checks could interfere with optimization patterns/passes (e.g.,
shape inference should not insert constructs that block existing optimization
patterns) and so their insertion could be delayed until later, followed by
another round of optimizations (constant folding, CSE, etc.) that should remove
redundant runtime operations.

### Possibly Asked Questions

#### What about ODS specifications of operations?

In ODS we have been recording the constraints for the operands & attributes of
an operation. Where these are sufficient to constrain the output shape (e.g.,
`SameOperandsAndResultType` or broadcastable) we should generate the shape
function from those. Where not, an explicit shape function should be specified
(spelling TBD but currently considering using the MLIR textual form as the
serialization approach).

#### Why not extract the shape function from reference implementation?

This could be done in the future! The extracted shape function would use the
shape inference dialect, so we are starting there. Especially for operations
described in a structured way, one could autogenerate the shape function.

#### How/in what language will the shape functions be authored?

TBD. We are open to many approaches and suggestions; the priority of this
proposal is the IR produced by whatever authoring language is chosen.

#### What shape inference approach is being suggested here?

None. There are multiple different shape inference approaches that we could
layer on top of these, from the most basic (always return unranked), to the
more useful (return a fixed shape for constant inputs/arguments), to the more
advanced (create logical conjunctions of algebraic statements between symbolic
named values).

### Open points

1.  Should shape functions that produce dynamic outputs given all statically
    shaped inputs be marked specially? E.g., read from file.

TODO: Add examples here.

## WIP/Future considerations

Shape functions are determined by attributes and could be arbitrarily
complicated, with a wide range of specification possibilities. Equality
relationships are common (e.g., the elemental type of the output matches the
primitive type of the inputs, both inputs have exactly the same type [primitive
type and shape]) and so these should be easy to specify. Algebraic
relationships would also be common (e.g., a concat of `[n,m]` and `[n,m]`
matrices along axis 0 is an `[n+n, m]` matrix), while some ops only have
defined shapes under certain cases (e.g., matrix multiplication of `[a,b]` and
`[c,d]` is only defined if `b == c`).
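
A self-contained sketch (plain C++, illustrative only) of these two kinds of
relationships, using `kDynamic` as a sentinel for an unknown extent:

```c++
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = -1;

// Concat of [n,m] and [p,m] along axis 0 is [n+p, m]; the leading extent is
// unknown if either input extent is unknown.
std::vector<int64_t> concatAxis0(const std::vector<int64_t> &lhs,
                                 const std::vector<int64_t> &rhs) {
  std::vector<int64_t> out = lhs;
  out[0] = (lhs[0] == kDynamic || rhs[0] == kDynamic) ? kDynamic
                                                      : lhs[0] + rhs[0];
  return out;
}

// Matmul of [a,b] and [c,d] is [a,d] and is only defined if b == c; when
// either extent is unknown, the check must be deferred to a runtime assertion.
std::optional<std::vector<int64_t>> matmulShape(
    const std::vector<int64_t> &lhs, const std::vector<int64_t> &rhs) {
  if (lhs[1] != kDynamic && rhs[0] != kDynamic && lhs[1] != rhs[0])
    return std::nullopt; // shape error: b != c
  return std::vector<int64_t>{lhs[0], rhs[1]};
}
```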

Instead of introducing an additional mechanism to specify a shape transfer
function, the reference implementation of the operation will be used to derive
the shape function. The reference implementation is general and can support the
arbitrary computations needed to specify output shapes.

[InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
[ShapedType]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/IR/BuiltinTypes.h