Ocannl_tensor.Shape

Definition and properties of the syntax of labels specifications and einsum notation:
If there is a ',' anywhere in the initial text, the multicharacter version is used.
Special characters: '>', '|', '-', ',', '=', ';', '+', '*', '_'.
If labels_spec contains neither "|" nor "->", each label is of the kind Output. If the spec does not contain "|", labels to the left of "->" are Input and labels to the right are Output. Labels to the left of "|" are Batch, and labels between "|" and "->" are Input.
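The kind rules above can be illustrated with a small self-contained sketch. This is not OCANNL's parser: `kinds_of_spec` and its helpers are hypothetical, handle only single-character labels, and ignore ellipses, "_" and conv expressions. The case of a "|" without a "->" is an assumption, because the rules above do not spell it out.

```ocaml
(* Hypothetical sketch: classify single-character labels of a labels
   specification into axis kinds using the "|" and "->" rules.
   Not OCANNL's actual parser. *)
type kind = Batch | Input | Output

(* Split [s] at the first occurrence of [sep], if any. *)
let split_on (sep : string) (s : string) : (string * string) option =
  let n = String.length s and m = String.length sep in
  let rec go i =
    if i + m > n then None
    else if String.sub s i m = sep then
      Some (String.sub s 0 i, String.sub s (i + m) (n - i - m))
    else go (i + 1)
  in
  go 0

(* Every non-space character becomes one label of the given kind. *)
let labels (kind : kind) (s : string) : (char * kind) list =
  String.to_seq s
  |> Seq.filter (fun c -> c <> ' ')
  |> Seq.map (fun c -> (c, kind))
  |> List.of_seq

let kinds_of_spec (spec : string) : (char * kind) list =
  (* Labels to the left of "|" are Batch. *)
  let batch, rest =
    match split_on "|" spec with
    | Some (b, r) -> (labels Batch b, r)
    | None -> ([], spec)
  in
  match split_on "->" rest with
  (* Left of "->" (after any "|"): Input; right of "->": Output. *)
  | Some (i, o) -> batch @ labels Input i @ labels Output o
  (* No "->": the remaining labels are Output (assumed when "|" is present). *)
  | None -> batch @ labels Output rest
```

For example, `kinds_of_spec "b|i->o"` classifies 'b' as Batch, 'i' as Input and 'o' as Output, while every label of "ij" is Output.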
The labels "..ident.." and "..." (where ident does not contain any of the special characters) may appear at most once per axis kind. They enable (in-the-middle) broadcasting for that axis kind in einsum-related shape inference (like the ellipsis "..." in numpy.einsum), and are translated to row variables. The ellipsis "..." is context dependent: in the batch row it is the same as "..batch..", in the input row the same as "..input..", and in the output row the same as "..output..". When the same row variable is used in multiple rows, the corresponding broadcast axes are matched pointwise in the resulting operation.
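A minimal sketch of the context-dependent ellipsis, treating a row as a list of label tokens. `expand_row` is a hypothetical helper, not OCANNL code: it rewrites "..." into the kind-specific row variable name and rejects a row that mentions more than one row variable.

```ocaml
(* Hypothetical sketch: within a row of a given kind, "..." stands for the
   row variable "..<kind>..", and a row may hold at most one row variable.
   Not OCANNL's implementation. *)
type kind = Batch | Input | Output

let kind_name = function Batch -> "batch" | Input -> "input" | Output -> "output"

(* A token is a row variable if it is "..." or has the form "..ident..". *)
let is_row_var (tok : string) : bool =
  let n = String.length tok in
  tok = "..." || (n > 4 && String.sub tok 0 2 = ".." && String.sub tok (n - 2) 2 = "..")

let expand_row (kind : kind) (row : string list) : (string list, string) result =
  (* Replace the context-dependent ellipsis with the named row variable. *)
  let expanded =
    List.map (fun tok -> if tok = "..." then ".." ^ kind_name kind ^ ".." else tok) row
  in
  match List.filter is_row_var expanded with
  | [] | [ _ ] -> Ok expanded
  | _ -> Error ("more than one row variable in the " ^ kind_name kind ^ " row")
```

So `expand_row Input ["..."; "i"]` yields `Ok ["..input.."; "i"]`, while a batch row containing both "..." and "..b.." is rejected.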
The label "_" is a placeholder: it does not appear in the resulting map, but it aligns the axes of the other labels.
Conv expressions have the form stride*output+dilation*kernel, where stride and dilation are optional integer coefficients (defaulting to 1) and output and kernel are axis labels. This syntax enables convolution-style indexing, where the index into the input axis is stride * output_iterator + dilation * kernel_iterator. Conv expressions automatically trigger multichar mode and are only supported in multichar mode.
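The indexing formula can be checked with a few lines of arithmetic. The helpers below are hypothetical, not OCANNL code. They assume 0-based iterators and no padding. Under those assumptions, the largest input index touched is stride * (out_size - 1) + dilation * (kernel_size - 1).

```ocaml
(* Hypothetical sketch of the "stride*output+dilation*kernel" index
   arithmetic (0-based iterators, no padding assumed). *)

(* Index into the input axis for given output and kernel iterators. *)
let input_index ~stride ~dilation ~output ~kernel = (stride * output) + (dilation * kernel)

(* Smallest input size that covers every output/kernel iterator pair:
   the largest index touched, plus one. *)
let min_input_size ~stride ~dilation ~out_size ~kernel_size =
  input_index ~stride ~dilation ~output:(out_size - 1) ~kernel:(kernel_size - 1) + 1
```

For example, with the expression 2*o+3*k, output iterator 4 and kernel iterator 1 read input index 11. With 5 outputs, a kernel of size 3, stride 1 and dilation 2, the input axis needs at least 9 elements.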
Note: currently, OCANNL shapes always allow broadcasting. Row variables track the broadcast axes; if there is no row variable, broadcast axes are not tracked. In the notation case `row_spec` = `axes_spec`, the axes are the rightmost axes (broadcasting to the left). In the past, we supported preventing broadcasting, but removed that to reduce complexity.
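Right-aligned ("broadcasting to the left") matching can be sketched on plain lists of dimensions. `broadcast_rows` is a hypothetical NumPy-style toy. It assumes that a dimension of 1 broadcasts against any size. It does not model row variables or OCANNL's inference.

```ocaml
(* Hypothetical sketch: NumPy-style broadcasting of two rows of dimensions,
   aligned at their rightmost axes. The shorter row is implicitly extended
   on the left, and a dimension of 1 broadcasts against any other size. *)
let broadcast_rows (r1 : int list) (r2 : int list) : (int list, string) result =
  (* Walk the reversed rows so the rightmost axes are paired first; consing
     onto [acc] restores left-to-right order. *)
  let rec go acc a b =
    match (a, b) with
    | [], [] -> Ok acc
    | d :: a', [] | [], d :: a' -> go (d :: acc) a' []
    | d1 :: a', d2 :: b' ->
        if d1 = d2 || d2 = 1 then go (d1 :: acc) a' b'
        else if d1 = 1 then go (d2 :: acc) a' b'
        else Error (Printf.sprintf "mismatch: %d vs %d" d1 d2)
  in
  go [] (List.rev r1) (List.rev r2)
```

For instance, [2; 3; 4] and [4] broadcast to [2; 3; 4], and [3; 1] and [5] broadcast to [3; 5].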
type padding = Row.axis_padding Base.array Base.option
val sexp_of_padding : padding -> Sexplib0.Sexp.t
val padding_of_sexp : Sexplib0.Sexp.t -> padding
val sexp_of_axis_spec : axis_spec -> Sexplib0.Sexp.t
val axis_spec_of_sexp : Sexplib0.Sexp.t -> axis_spec
include Ppx_compare_lib.Equal.S with type t := t
val equal : t Base__Ppx_compare_lib.equal
val compare_deduce_within_shape :
deduce_within_shape ->
deduce_within_shape ->
Base.int
val sexp_of_deduce_within_shape : deduce_within_shape -> Sexplib0.Sexp.t
val deduce_within_shape_of_sexp : Sexplib0.Sexp.t -> deduce_within_shape
type delayed_var_ref = {
  var_ref : Ir.Indexing.variable_ref;
  mutable var : [ `Row of Row.row_var | `Dim of Row.dim_var | `Not_set_yet ];
}
val equal_delayed_var_ref : delayed_var_ref -> delayed_var_ref -> Base.bool
val sexp_of_delayed_var_ref : delayed_var_ref -> Sexplib0.Sexp.t
val get_variable_ref : Base.string -> delayed_var_ref
Returns a fully unset variable reference with the given label.
val set_dim : delayed_var_ref -> Base.int -> Base.unit
Sets a dimension variable reference to the given dimension, or a row variable reference to the given total number of elements. This will propagate through shape inference.
For row variables, the total number of elements is the product of the dimensions, enforced via the Total_elems constraint.
val set_equal : delayed_var_ref -> delayed_var_ref -> Base.unit
Constrains the two variable references to be equal, in the sense described below. This will propagate through shape inference.
When both references are dimension variables, or both are row variables, they must be precisely equal. When one is a dimension variable and the other is a row variable, they must have the same total number of elements.
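The notions of equality used by set_dim and set_equal can be modelled with already-resolved values. In this hypothetical sketch, `resolved`, `total_elems` and the two predicates are illustrative stand-ins, not OCANNL's types. A dimension variable resolves to an int and a row variable to a list of dimensions. The checks mirror the descriptions above.

```ocaml
(* Hypothetical model of resolved variable references: a dimension variable
   resolves to one size, a row variable to a list of sizes. *)
type resolved = Dim of int | Row of int list

(* Total number of elements: a row contributes the product of its dims. *)
let total_elems = function Dim d -> d | Row dims -> List.fold_left ( * ) 1 dims

(* The sense of equality imposed by set_equal: same variable kind means
   precisely equal; mixed kinds mean the same total number of elements. *)
let equal_in_set_equal_sense (a : resolved) (b : resolved) : bool =
  match (a, b) with
  | Dim x, Dim y -> x = y
  | Row xs, Row ys -> xs = ys
  | Dim _, Row _ | Row _, Dim _ -> total_elems a = total_elems b

(* The constraint imposed by set_dim: a dim equals the value, while a
   row's dims multiply to it (as with Total_elems). *)
let satisfies_set_dim (v : resolved) (n : int) : bool = total_elems v = n
```

Under this model, `Dim 6` and `Row [2; 3]` are equal in the set_equal sense, but `Row [2; 3]` and `Row [3; 2]` are not.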
type compose_type =
| Pointwise_bin
NumPy-style broadcast matching batch, input and output axes, e.g. as in s1 + s2.
| Compose
Composes the outputs of the second shape with the inputs of the first shape, i.e. the shape of fun x -> s1(s2(x)), or s1 * s2 where * is the inner product (e.g. matrix multiplication).
| Einsum of Base.string * delayed_var_ref Base.list
The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHS1, RHS2 and LHS are labels specifications. Since OCANNL's extended einsum notation supports both axis variables and row variables, it makes the other compose types redundant. The axis_labels use pseudo-labels local to the notation to line up the axes and row variables. The symmetric difference (disjunctive union) of RHS1's and RHS2's pseudo-labels should be equal to the LHS pseudo-labels.
The optional Ir.Indexing.variable_refs will capture the solutions of the dimensions corresponding to the specification labels equal to ref_label of a reference.
Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs", "rhs1;rhs2=>lhs".
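To make the rhs-before-lhs ordering concrete, here is a hypothetical splitter for the binary syntax. `split_einsum` and `dropped_labels` are illustrative names, not OCANNL functions. They treat each spec as a flat string of single-character pseudo-labels, ignoring kinds, row variables and "_". The second function lists the operand labels that do not appear in the result. In standard einsum semantics, those labels are reduced over.

```ocaml
(* Hypothetical sketch: split "rhs1;rhs2=>lhs" (the result is on the RIGHT
   of "=>"). Single-character labels only; kinds, "...", "_" not handled. *)
let split_einsum (spec : string) : (string * string * string) option =
  match String.index_opt spec ';' with
  | None -> None
  | Some i -> (
      let rest = String.sub spec (i + 1) (String.length spec - i - 1) in
      let n = String.length rest in
      (* Locate "=>" in the part after ";". *)
      let rec find j =
        if j + 2 > n then None
        else if String.sub rest j 2 = "=>" then Some j
        else find (j + 1)
      in
      match find 0 with
      | None -> None
      | Some j -> Some (String.sub spec 0 i, String.sub rest 0 j, String.sub rest (j + 2) (n - j - 2)))

let chars (s : string) : char list = List.of_seq (String.to_seq s)

(* Labels of rhs1 and rhs2 absent from lhs, sorted and without duplicates. *)
let dropped_labels (spec : string) : char list option =
  match split_einsum spec with
  | None -> None
  | Some (rhs1, rhs2, lhs) ->
      let absent = List.filter (fun c -> not (String.contains lhs c)) (chars rhs1 @ chars rhs2) in
      Some (List.sort_uniq Char.compare absent)
```

For the matrix-multiply-like spec "ij;jk=>ik", the operands are "ij" and "jk", the result is "ik", and the dropped label is 'j'.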
val sexp_of_compose_type : compose_type -> Sexplib0.Sexp.t
val equal_compose_type : compose_type -> compose_type -> Base.bool
type transpose_type =
| Transpose
Swaps inputs and outputs of a shape, preserves batch axes.
| Pointwise_un
Preserves the shape.
| Permute of Base.string * delayed_var_ref Base.list
The unary "einsum" syntax: RHS1=>LHS.
The optional Ir.Indexing.variable_refs will capture the solutions of the dimensions corresponding to the specification labels equal to ref_label of a reference.
| Batch_slice of Ir.Indexing.static_symbol
Removes the leftmost batch axis.
| Uint4x32_to_prec of Ir.Ops.prec Base.Lazy.t
Converts precision in a bit-efficient way, with a corresponding conversion in the total number of elements. Currently assumes that the incoming tensor (RHS) has just a single axis, so as not to force unnecessary minimum sizes on output axes.
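The effect of Transpose, Pointwise_un and Batch_slice on the three rows can be sketched with a record of dimension lists. The `shape` record and the functions below are hypothetical stand-ins for Shape.t, and they assume that the shape is already fully inferred.

```ocaml
(* Hypothetical stand-in for a fully inferred shape: three rows of dims. *)
type shape = { batch : int list; input : int list; output : int list }

(* Transpose: swap inputs and outputs, keep batch axes. *)
let transpose (s : shape) : shape = { s with input = s.output; output = s.input }

(* Pointwise_un: the shape is preserved. *)
let pointwise_un (s : shape) : shape = s

(* Batch_slice: drop the leftmost batch axis (fails if there is none). *)
let batch_slice (s : shape) : shape =
  match s.batch with
  | _ :: rest -> { s with batch = rest }
  | [] -> invalid_arg "batch_slice: no batch axis"
```

For a shape with batch [8; 2], input [3] and output [4], transposing gives input [4] and output [3], and slicing leaves batch [2].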
val equal_transpose_type : transpose_type -> transpose_type -> Base.bool
val sexp_of_transpose_type : transpose_type -> Sexplib0.Sexp.t
If you miss expressivity here, leave a note on issue 305.
val equal_ternary_type : ternary_type -> ternary_type -> Base.bool
val sexp_of_ternary_type : ternary_type -> Sexplib0.Sexp.t
Extracts any available shape information from the initialization or fetch.
val equal_terminal_type : terminal_type -> terminal_type -> Base.bool
val sexp_of_terminal_type : terminal_type -> Sexplib0.Sexp.t
val make :
?batch_dims:Base.int Base.list ->
?input_dims:Base.int Base.list ->
?output_dims:Base.int Base.list ->
?batch_axes:(Base.string * Base.int) Base.list ->
?input_axes:(Base.string * Base.int) Base.list ->
?output_axes:(Base.string * Base.int) Base.list ->
?deduced:deduce_within_shape ->
debug_name:Base.string ->
id:Base.int ->
Base.unit ->
t
Creates a shape. id should be the id of the associated tensor (if any). At most one of each pair (batch_dims / batch_axes, input_dims / input_axes, output_dims / output_axes) should be given: if neither is given, the corresponding row will be inferred. batch_axes etc. provide labels for the dimensions of the corresponding axes. Note that these are dimension labels and not axis labels: they need not be unique within a row, are inferred when provided, and must match whenever the axis sizes must match.
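The "at most one of each pair" rule for make can be pictured as choosing one of three sources for each row. The `row_source` type and function below are a hypothetical illustration, not part of the API. Dims alone give sizes, axes give sizes together with dimension labels, and omitting both leaves the row to inference.

```ocaml
(* Hypothetical illustration of how one argument pair of make could be
   interpreted, e.g. ?batch_dims vs ?batch_axes. Not OCANNL code. *)
type row_source =
  | Dims of int list (* sizes only *)
  | Labeled of (string * int) list (* dimension labels with sizes *)
  | Inferred (* neither given: the row is left to inference *)

let row_source ?dims ?axes () : (row_source, string) result =
  match (dims, axes) with
  | Some _, Some _ -> Error "give at most one of the dims/axes pair"
  | Some d, None -> Ok (Dims d)
  | None, Some a -> Ok (Labeled a)
  | None, None -> Ok Inferred
```

So `row_source ~dims:[2; 3] ()` gives `Ok (Dims [2; 3])`, and `row_source ()` gives `Ok Inferred`.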
val to_string_hum : ?style:Row.print_style -> t -> Base.string
Bring global state to its initialization values. This invalidates any unfinished inference.
val set_terminal : is_param:Base.bool -> t -> Base.unit
Marks the shape as terminal, so that its rows can be closed to Least Upper Bounds (LUBs). This function is only intended for parameter shapes, which would otherwise not be terminal because of the initialization expressions of the parameters.
type logic =
| Broadcast of compose_type * t * t
Matches the shapes for a binary operation.
For Broadcast (Einsum (spec, _), s1, s2) with spec = "ls1;ls2=>ls3", the labels of s1 and s2 must match according to the ls1, ls2 lineup, and the resulting shape inherits the labels according to the ls3 lineup.
| Transpose of transpose_type * t
Permutes the axes of a shape. One case of Transpose is to swap inputs with outputs of s1, hence the name.
| Broadcast_tern of ternary_type * t * t * t
Matches the shapes for a ternary operation.
| Terminal of { is_param : Base.bool; logic : terminal_type; }
Extracts any available shape information from the initialization. The is_param field indicates whether this is a parameter tensor that requires gradients.
How to propagate shape updates and do the last update of Tensor.t.shape when finalizing the tensor. Axes are broadcast-expanded on a bottom-up update to fit the incoming shape.
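The shape produced by Broadcast (Compose, s1, s2) can be sketched on fully known rows. This is a hypothetical simplification, not OCANNL's inference. It requires the inputs of s1 to equal the outputs of s2 exactly, and it broadcasts only the batch rows, NumPy-style with right alignment. The result takes its inputs from s2 and its outputs from s1, matching fun x -> s1(s2(x)).

```ocaml
(* Hypothetical stand-in for a fully inferred shape. *)
type shape = { batch : int list; input : int list; output : int list }

(* Broadcast two rows given in reversed order (rightmost axis first);
   a size of 1 broadcasts. The result is also in reversed order. *)
let rec bcast_rev (a : int list) (b : int list) : int list option =
  match (a, b) with
  | [], r | r, [] -> Some r
  | x :: a', y :: b' ->
      let rest = bcast_rev a' b' in
      if x = y || y = 1 then Option.map (List.cons x) rest
      else if x = 1 then Option.map (List.cons y) rest
      else None

(* Shape of fun x -> s1(s2(x)): outputs of s2 feed the inputs of s1. *)
let compose (s1 : shape) (s2 : shape) : shape option =
  if s1.input <> s2.output then None
  else
    Option.map
      (fun rev_batch -> { batch = List.rev rev_batch; input = s2.input; output = s1.output })
      (bcast_rev (List.rev s1.batch) (List.rev s2.batch))
```

Composing a 4-by-3 map (input [3], output [4]) with s2 having batch [8], input [5] and output [3] gives batch [8], input [5] and output [4].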
val sexp_of_logic : logic -> Sexplib0.Sexp.t
val hash_fold_update_id :
Ppx_hash_lib.Std.Hash.state ->
update_id ->
Ppx_hash_lib.Std.Hash.state
val hash_update_id : update_id -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_update_id : update_id -> Sexplib0.Sexp.t
val update_id_of_sexp : Sexplib0.Sexp.t -> update_id
val get_update_id : Base.unit -> update_id
val logic_to_spec : logic -> Base.string
Converts a shape logic to its string specification for debugging/display purposes.
Data required for a shape inference update step. Ideally, an update should be performed at least twice, the second time after all the other relevant updates have been performed for the first time. In OCANNL, this is achieved by performing updates both as the tensors are constructed, and via lazy callbacks as the corresponding Ir.Indexing dimensions and projections are first accessed.
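Why a second round of updates helps can be seen on a toy constraint list. This is a hypothetical model, not OCANNL's solver. Equality constraints between named dimension variables are applied in a fixed order. A single pass can therefore miss a value that only becomes known later in the same pass.

```ocaml
(* Hypothetical toy: propagate known sizes through equality constraints,
   applying the constraints in a fixed order, one pass at a time. *)
module M = Map.Make (String)

(* One pass: if exactly one side of a constraint is known, copy it over. *)
let pass (constraints : (string * string) list) (env : int M.t) : int M.t =
  List.fold_left
    (fun env (a, b) ->
      match (M.find_opt a env, M.find_opt b env) with
      | Some x, None -> M.add b x env
      | None, Some y -> M.add a y env
      | _ -> env)
    env constraints

(* "a = b" is visited before "b = c", but only c is known at the start. *)
let constraints = [ ("a", "b"); ("b", "c") ]
let start = M.singleton "c" 5
let after_one = pass constraints start
let after_two = pass constraints after_one
```

After the first pass only b has learned its size from c; the second pass also resolves a.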
val sexp_of_update_step : update_step -> Sexplib0.Sexp.t
val to_dims : t -> Base.int Base.array
Uses the matrix convention of putting the input axes last.
val to_padding : t -> (Ir.Ops.axis_padding Base.array * Base.float) Base.option
Returns the padding of the shape, if any, including the padded value. Uses the matrix convention of putting the input axes last.
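The matrix convention can be illustrated on a plain record. The `shape` record and `to_dims` below are hypothetical stand-ins, not Shape.t or the real to_dims. The order batch, output, input is an assumption consistent with the leftmost batch axes and with the input axes coming last. Under that order, a map with output size 4 and input size 3 lays out like a 4-by-3 matrix.

```ocaml
(* Hypothetical record standing in for a fully inferred shape. *)
type shape = { batch : int list; input : int list; output : int list }

(* Matrix convention (assumed order): batch axes first, then output axes,
   with input axes last, like the rows-by-columns layout of a matrix. *)
let to_dims (s : shape) : int array = Array.of_list (s.batch @ s.output @ s.input)
```

For batch [8], output [4] and input [3], this yields [| 8; 4; 3 |].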
val propagate_shapes : update_step -> Base.unit
val derive_projections : update_step -> Ir.Indexing.projections
Computes the indexing into subtensors given the shape information of a tensor. derive_projections should only be invoked when the shapes are fully inferred already!
val of_spec :
?deduced:deduce_within_shape ->
debug_name:Base.string ->
id:Base.int ->
Base.string ->
t
val default_display_indices : t -> Base.int Base.array
val to_labels : t -> Base.string Base.array
Uses the matrix convention of putting the input axes last.
val sexp_of_parsed_axis_labels : parsed_axis_labels -> Sexplib0.Sexp.t
val parsed_axis_labels_of_sexp : Sexplib0.Sexp.t -> parsed_axis_labels
val axis_labels : parsed_axis_labels -> axis_spec axis_map
val axis_labels_of_spec : Base.string -> parsed_axis_labels
val axis_map_to_dims_index : ?default:'a -> 'a axis_map -> 'a Base.array