Module Ocannl.Shape

Tensor shape types, shape inference, projection inference.

Labels specifications and einsum notation.

Definition and properties of the syntax of labels specifications and einsum notation:

If labels_spec contains neither "|" nor "->", every label is of the kind Output. If the spec contains "->" but not "|", labels to the left of "->" are Input and labels to the right are Output. If it contains "|", labels to the left of "|" are Batch, and labels between "|" and "->" are Input.
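For illustration, a few hypothetical single-character spec strings annotated with the kinds they assign (example labels, not from the source):

```ocaml
(* Hypothetical labels_spec strings (single-char labels), annotated with kinds: *)
let _outputs_only = "ab"      (* a, b : Output *)
let _with_input = "ab->c"     (* a, b : Input; c : Output *)
let _with_batch = "d|ab->c"   (* d : Batch; a, b : Input; c : Output *)
```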

The labels ".."ident".." and "..." (where ident does not contain any of the special characters) are allowed at most once per kind. They enable (in-the-middle) broadcasting for the given axis kind in einsum-related shape inference (like the ellipsis "..." in numpy.einsum), and are translated to row variables. The ellipsis "..." is context dependent: in the batch row it is the same as "..batch..", in the input row the same as "..input..", and in the output row the same as "..output..". When the same row variable is used in multiple rows, the corresponding broadcasted axes are matched pointwise in the resulting operation.

The label "_" is a placeholder: it is not output to the resulting map, but it aligns the axes of the other labels.

Conv expressions have the form stride*output+dilation*kernel, where stride and dilation are optional integer coefficients (defaulting to 1), and output/kernel are axis labels. This syntax enables convolution-style indexing where input_dimension = stride * output_iterator + dilation * kernel_iterator. Conv expressions automatically trigger multichar mode, and are only supported in that mode.
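The indexing relation denoted by a conv expression can be sketched as plain arithmetic (an illustration of the formula above, not OCANNL internals; the function name is made up):

```ocaml
(* For a conv expression such as "2*o+3*k" (stride 2, dilation 3), the input
   dimension index is derived from the output and kernel iterators: *)
let conv_input_index ~stride ~dilation ~output_it ~kernel_it =
  (stride * output_it) + (dilation * kernel_it)

(* e.g. stride 2, dilation 3, output iterator 4, kernel iterator 1: *)
let _ = conv_input_index ~stride:2 ~dilation:3 ~output_it:4 ~kernel_it:1 (* 11 *)
```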

Note: currently, OCANNL shapes always allow broadcasting. Row variables track the broadcasted axes -- if there is no row variable, broadcasted axes are not tracked. In the notation case `row_spec` = `axes_spec`, the axes are the rightmost axes (broadcasting to the left). In the past, we supported preventing broadcasting, but removed that to reduce complexity.

User-ish API.

type padding = Row.axis_padding Base.array Base.option
val sexp_of_padding : padding -> Sexplib0.Sexp.t
val padding_of_sexp : Sexplib0.Sexp.t -> padding
val equal_padding : padding -> padding -> Base.bool
type axis_spec =
  | Label of Base.string
    (* A variable axis label. *)
  | Fixed_index of Base.int
    (* A fixed index, used for projection. *)
  | Conv_spec of {
      stride : Base.int;
      output_label : Base.string;
      dilation : Base.int;
      kernel_label : Base.string;
    }
    (* Convolution-style axis specification: stride*output + dilation*kernel. *)

Specification for individual axes in the einsum notation.

val compare_axis_spec : axis_spec -> axis_spec -> Base.int
val sexp_of_axis_spec : axis_spec -> Sexplib0.Sexp.t
val axis_spec_of_sexp : Sexplib0.Sexp.t -> axis_spec
type t = {
  mutable batch : Row.t;
  mutable input : Row.t;
  mutable output : Row.t;
  mutable batch_padding : padding;
  mutable input_padding : padding;
  mutable output_padding : padding;
  id : Base.int;
    (* The id of a node that has this shape, or -1. *)
  debug_name : Base.string;
}
include Ppx_compare_lib.Equal.S with type t := t
val equal : t Base__Ppx_compare_lib.equal
include Sexplib0.Sexpable.S with type t := t
val t_of_sexp : Sexplib0__.Sexp.t -> t
val sexp_of_t : t -> Sexplib0__.Sexp.t
type deduce_within_shape =
  | Not_constrained
  | Input_equals_output
val compare_deduce_within_shape : deduce_within_shape -> deduce_within_shape -> Base.int
val sexp_of_deduce_within_shape : deduce_within_shape -> Sexplib0.Sexp.t
val deduce_within_shape_of_sexp : Sexplib0.Sexp.t -> deduce_within_shape
type delayed_var_ref = {
  var_ref : Ir.Indexing.variable_ref;
  mutable var : [ `Row of Row.row_var | `Dim of Row.dim_var | `Not_set_yet ];
}
val equal_delayed_var_ref : delayed_var_ref -> delayed_var_ref -> Base.bool
val sexp_of_delayed_var_ref : delayed_var_ref -> Sexplib0.Sexp.t
val get_variable_ref : Base.string -> delayed_var_ref

Returns a fully unset variable reference with the given label.

val set_dim : delayed_var_ref -> Base.int -> Base.unit

Sets the dimension (for a dim variable reference) or the total number of elements (for a row variable reference) to the given value. This will propagate through shape inference.

For row variables, this means the product of the dimensions, via the Total_elems constraint.

val set_equal : delayed_var_ref -> delayed_var_ref -> Base.unit

Sets the two variable references to be equal (in some sense). This will propagate through shape inference.

When both references are dimension variables or both are row variables, this means they are precisely equal. When one is a dimension variable and the other is a row variable, this means they have the same number of total elements.
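A sketch of combining these functions (the labels are made up, and how the refs would later appear in einsum specs via their ref_label is assumed):

```ocaml
(* Create two unset variable references, then constrain them: *)
let r1 = Ocannl.Shape.get_variable_ref "m"
let r2 = Ocannl.Shape.get_variable_ref "n"

let () =
  Ocannl.Shape.set_dim r1 128;  (* dim, or total elements for a row variable *)
  Ocannl.Shape.set_equal r1 r2  (* r2 is now constrained to match r1 *)
```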

type compose_type =
  | Pointwise_bin
    (* NumPy-style broadcast matching batch, input and output axes, e.g. as in s1 + s2. *)
  | Compose
    (* Composes the outputs of the second shape with the inputs of the first shape, i.e. the shape of fun x -> s1(s2(x)), or of s1 * s2 where * is the inner product (e.g. matrix multiplication). *)
  | Einsum of Base.string * delayed_var_ref Base.list
    (* The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHS1, RHS2 and LHS are labels specifications. Since OCANNL's extended einsum notation supports both axis variables and row variables, it makes the other compose types redundant. The axis labels are pseudo-labels local to the notation, used to line up the axes and row variables. The symmetric difference (disjunctive union) of the RHS1 and RHS2 pseudo-labels should equal the LHS pseudo-labels.

       The optional Ir.Indexing.variable_refs capture the solutions of the dimensions corresponding to the specification labels equal to the ref_label of a reference.

       Note: the "right-hand side" is on the left! I.e. the syntax is "rhs=>lhs", "rhs1;rhs2=>lhs". *)
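For instance, matrix multiplication could be written with the binary einsum syntax (a sketch; the spec string is an assumed example, and the empty list means no variable references are captured):

```ocaml
(* Right-hand sides come first: "rhs1;rhs2=>lhs". *)
let _matmul : Ocannl.Shape.compose_type = Ocannl.Shape.Einsum ("ij;jk=>ik", [])
```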
val sexp_of_compose_type : compose_type -> Sexplib0.Sexp.t
val equal_compose_type : compose_type -> compose_type -> Base.bool
type transpose_type =
  | Transpose
    (* Swaps the inputs and outputs of a shape; preserves the batch axes. *)
  | Pointwise_un
    (* Preserves the shape. *)
  | Permute of Base.string * delayed_var_ref Base.list
    (* The unary "einsum" syntax: RHS1=>LHS.

       The optional Ir.Indexing.variable_refs capture the solutions of the dimensions corresponding to the specification labels equal to the ref_label of a reference. *)
  | Batch_slice of Ir.Indexing.static_symbol
    (* Removes the leftmost batch axis. *)
  | Uint4x32_to_prec of Ir.Ops.prec Base.Lazy.t
    (* Converts precision in a bit-efficient way, with a corresponding conversion in the total number of elements. Currently assumes the incoming tensor (RHS) has just a single axis, to not force unnecessary minimum sizes on the output axes. *)
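Similarly, a plain swap of two output axes could be written with the unary einsum syntax (a sketch with an assumed spec string):

```ocaml
(* "rhs=>lhs": swap the two output axes. *)
let _swap : Ocannl.Shape.transpose_type = Ocannl.Shape.Permute ("ij=>ji", [])
```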
val equal_transpose_type : transpose_type -> transpose_type -> Base.bool
val sexp_of_transpose_type : transpose_type -> Sexplib0.Sexp.t
type ternary_type =
  | Pointwise_tern
    (* As in the operation Where. *)
  | Compose_accumulate
    (* As in the operation FMA. *)

If you miss expressivity here, leave a note on issue 305.

val equal_ternary_type : ternary_type -> ternary_type -> Base.bool
val sexp_of_ternary_type : ternary_type -> Sexplib0.Sexp.t
type terminal_type =
  | Data of Ir.Assignments.init_data
  | Fetch of Ir.Assignments.fetch_op

Extracts any available shape information from the initialization or fetch.

val equal_terminal_type : terminal_type -> terminal_type -> Base.bool
val sexp_of_terminal_type : terminal_type -> Sexplib0.Sexp.t
val make : ?batch_dims:Base.int Base.list -> ?input_dims:Base.int Base.list -> ?output_dims:Base.int Base.list -> ?batch_axes:(Base.string * Base.int) Base.list -> ?input_axes:(Base.string * Base.int) Base.list -> ?output_axes:(Base.string * Base.int) Base.list -> ?deduced:deduce_within_shape -> debug_name:Base.string -> id:Base.int -> Base.unit -> t

Creates a shape. id should be the id of the associated tensor (if any). At most one of each pair batch_dims/batch_axes etc. should be given: if neither is, the corresponding row will be inferred. batch_axes etc. provide labels for the dimensions of the corresponding axes. Note that these are dimension labels, not axis labels: they need not be unique within a row, are inferred when provided, and must match whenever the axis sizes must match.
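A sketch of a call (the ~debug_name and ~id values are made up; the input row is left to inference):

```ocaml
(* One batch axis of size 4, two output axes of size 8; no input row given. *)
let _s : Ocannl.Shape.t =
  Ocannl.Shape.make ~batch_dims:[ 4 ] ~output_dims:[ 8; 8 ]
    ~debug_name:"example" ~id:0 ()
```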

val to_string_hum : ?style:Row.print_style -> t -> Base.string
val unsafe_reinitialize : Base.unit -> Base.unit

Bring global state to its initialization values. This invalidates any unfinished inference.

Internal-ish API.

type logic =
  | Broadcast of compose_type * t * t
    (* Matches the shapes for a binary operation.

       For Broadcast (Einsum (spec, _), s1, s2) with spec "ls1;ls2=>ls3", the labels of s1 and s2 must match according to the ls1 and ls2 lineups, and the resulting shape inherits the labels according to the ls3 lineup. *)
  | Transpose of transpose_type * t
    (* Permutes the axes of a shape. One case of Transpose is to swap the inputs and outputs of s1, hence the name. *)
  | Broadcast_tern of ternary_type * t * t * t
    (* Matches the shapes for a ternary operation. *)
  | Terminal of terminal_type
    (* Extracts any available shape information from the initialization. *)

How to propagate shape updates and do the last update of Tensor.t.shape when finalizing the tensor. Axes are broadcast-expanded on a bottom-up update to fit the incoming shape.

val equal_logic : logic -> logic -> Base.bool
val sexp_of_logic : logic -> Sexplib0.Sexp.t
type update_id
val equal_update_id : update_id -> update_id -> Base.bool
val compare_update_id : update_id -> update_id -> Base.int
val hash_fold_update_id : Ppx_hash_lib.Std.Hash.state -> update_id -> Ppx_hash_lib.Std.Hash.state
val hash_update_id : update_id -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_update_id : update_id -> Sexplib0.Sexp.t
val update_id_of_sexp : Sexplib0.Sexp.t -> update_id
val get_update_id : Base.unit -> update_id
val logic_to_spec : logic -> Base.string

Converts a shape logic to its string specification for debugging/display purposes.

type update_step = {
  shape : t;
  logic : logic;
  id : update_id;
}

Data required for a shape inference update step. Ideally, an update should be performed at least twice, the second time after all the other relevant updates have been performed for the first time. In OCANNL, this is achieved by performing updates both as the tensors are constructed, and via lazy callbacks as the corresponding Ir.Indexing dimensions and projections are first accessed.

val sexp_of_update_step : update_step -> Sexplib0.Sexp.t
val to_dims : t -> Base.int Base.array

Uses the matrix convention of putting the input axes last.
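Assuming "input axes last" means the array is ordered batch axes, then output axes, then input axes, a shape with batch [4], input [3] and output [5] would yield [| 4; 5; 3 |] (a hedged illustration; the ~debug_name and ~id are made up):

```ocaml
let _dims =
  Ocannl.Shape.to_dims
    (Ocannl.Shape.make ~batch_dims:[ 4 ] ~input_dims:[ 3 ] ~output_dims:[ 5 ]
       ~debug_name:"dims_demo" ~id:1 ())
(* expected: [| 4; 5; 3 |] — batch, then output, then input *)
```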

val to_padding : t -> (Ir.Ops.axis_padding Base.array * Base.float) Base.option

Returns the padding of the shape, if any. Includes the padded value. Uses the matrix convention of putting the input axes last.

val propagate_shapes : update_step -> Base.unit
val derive_projections : update_step -> Ir.Indexing.projections

Computes the indexing into subtensors given the shape information of a tensor. derive_projections should only be invoked when the shapes are fully inferred already!

val of_spec : ?deduced:deduce_within_shape -> debug_name:Base.string -> id:Base.int -> Base.string -> t
val default_display_indices : t -> Base.int Base.array
val to_labels : t -> Base.string Base.array

Uses the matrix convention of putting the input axes last.

type 'a axis_map
type parsed_axis_labels
val sexp_of_parsed_axis_labels : parsed_axis_labels -> Sexplib0.Sexp.t
val parsed_axis_labels_of_sexp : Sexplib0.Sexp.t -> parsed_axis_labels
val axis_labels_of_spec : Base.string -> parsed_axis_labels
val axis_map_to_dims_index : ?default:'a -> 'a axis_map -> 'a Base.array