Module Ocannl_tensor.Row

The row type, shape-inference-related types, and constraint solving.

type axis_padding = Ir.Ops.axis_padding
val equal_axis_padding : axis_padding -> axis_padding -> Base.bool
val sexp_of_axis_padding : axis_padding -> Sexplib0.Sexp.t
val axis_padding_of_sexp : Sexplib0.Sexp.t -> axis_padding
type kind = [
  1. | `Batch
  2. | `Input
  3. | `Output
]
val equal_kind : kind -> kind -> Base.bool
val compare_kind : kind -> kind -> Base.int
val sexp_of_kind : kind -> Sexplib0.Sexp.t
val kind_of_sexp : Sexplib0.Sexp.t -> kind
val __kind_of_sexp__ : Sexplib0.Sexp.t -> kind
val hash_fold_kind : Ppx_hash_lib.Std.Hash.state -> kind -> Ppx_hash_lib.Std.Hash.state
val hash_kind : kind -> Ppx_hash_lib.Std.Hash.hash_value
val batch : kind
val input : kind
val output : kind
val is_batch : kind -> Base.bool
val is_input : kind -> Base.bool
val is_output : kind -> Base.bool
val batch_val : kind -> Base.unit Base.option
val input_val : kind -> Base.unit Base.option
val output_val : kind -> Base.unit Base.option
module Variants_of_kind : sig ... end
type dim_var
val equal_dim_var : dim_var -> dim_var -> Base.bool
val hash_fold_dim_var : Ppx_hash_lib.Std.Hash.state -> dim_var -> Ppx_hash_lib.Std.Hash.state
val hash_dim_var : dim_var -> Ppx_hash_lib.Std.Hash.hash_value
val compare_dim_var : dim_var -> dim_var -> Base.int
val sexp_of_dim_var : dim_var -> Sexplib0.Sexp.t
val dim_var_of_sexp : Sexplib0.Sexp.t -> dim_var
type proj_id
val equal_proj_id : proj_id -> proj_id -> Base.bool
val hash_fold_proj_id : Ppx_hash_lib.Std.Hash.state -> proj_id -> Ppx_hash_lib.Std.Hash.state
val hash_proj_id : proj_id -> Ppx_hash_lib.Std.Hash.hash_value
val compare_proj_id : proj_id -> proj_id -> Base.int
val sexp_of_proj_id : proj_id -> Sexplib0.Sexp.t
val proj_id_of_sexp : Sexplib0.Sexp.t -> proj_id
type dim_cmp
type dim_var_set = (dim_var, dim_cmp) Base.Set.t
val equal_dim_var_set : dim_var_set -> dim_var_set -> Base.bool
val sexp_of_dim_var_set : dim_var_set -> Sexplib0.Sexp.t
val dim_var_set_of_sexp : Sexplib0.Sexp.t -> dim_var_set
type 'a dim_map = (dim_var, 'a, dim_cmp) Base.Map.t
val equal_dim_map : ('a -> 'a -> Base.bool) -> 'a dim_map -> 'a dim_map -> Base.bool
val sexp_of_dim_map : ('a -> Sexplib0.Sexp.t) -> 'a dim_map -> Sexplib0.Sexp.t
val dim_map_of_sexp : (Sexplib0.Sexp.t -> 'a) -> Sexplib0.Sexp.t -> 'a dim_map
type proj_cmp
type proj_var_set = (proj_id, proj_cmp) Base.Set.t
val equal_proj_var_set : proj_var_set -> proj_var_set -> Base.bool
val sexp_of_proj_var_set : proj_var_set -> Sexplib0.Sexp.t
val proj_var_set_of_sexp : Sexplib0.Sexp.t -> proj_var_set
type 'a proj_map = (proj_id, 'a, proj_cmp) Base.Map.t
val equal_proj_map : ('a -> 'a -> Base.bool) -> 'a proj_map -> 'a proj_map -> Base.bool
val sexp_of_proj_map : ('a -> Sexplib0.Sexp.t) -> 'a proj_map -> Sexplib0.Sexp.t
val proj_map_of_sexp : (Sexplib0.Sexp.t -> 'a) -> Sexplib0.Sexp.t -> 'a proj_map
val get_var : ?label:Base.string -> Base.unit -> dim_var
val dim_var_set_empty : dim_var_set
val dim_map_empty : 'a dim_map
val proj_var_set_empty : proj_var_set
val proj_map_empty : 'a proj_map
val use_padding : Base.bool Base.ref
type solved_dim = {
  1. d : Base.int;
  2. label : Base.string Base.option;
  3. proj_id : proj_id Base.option;
}

A single axis in a shape. proj_id is used for projection inference, and abused for provenance tracking during shape inference.

val equal_solved_dim : solved_dim -> solved_dim -> Base.bool
val hash_fold_solved_dim : Ppx_hash_lib.Std.Hash.state -> solved_dim -> Ppx_hash_lib.Std.Hash.state
val hash_solved_dim : solved_dim -> Ppx_hash_lib.Std.Hash.hash_value
val compare_solved_dim : solved_dim -> solved_dim -> Base.int
val sexp_of_solved_dim : solved_dim -> Sexplib0.Sexp.t
val solved_dim_of_sexp : Sexplib0.Sexp.t -> solved_dim
type dim =
  1. | Var of dim_var
  2. | Dim of solved_dim
  3. | Conv_input of {
    1. stride : Base.int;
    2. output : dim;
    3. dilation : Base.int;
    4. kernel : dim;
    }
    (*

    The offset is implicit, automatically derived. If !use_padding is true, the offset is the left part of the dimensionality-preserving symmetric padding, otherwise it is 0.

    *)
val equal_dim : dim -> dim -> Base.bool
val hash_fold_dim : Ppx_hash_lib.Std.Hash.state -> dim -> Ppx_hash_lib.Std.Hash.state
val hash_dim : dim -> Ppx_hash_lib.Std.Hash.hash_value
val compare_dim : dim -> dim -> Base.int
val sexp_of_dim : dim -> Sexplib0.Sexp.t
val dim_of_sexp : Sexplib0.Sexp.t -> dim
val var : dim_var -> dim
val dim : solved_dim -> dim
val conv_input : stride:Base.int -> output:dim -> dilation:Base.int -> kernel:dim -> dim
val is_var : dim -> Base.bool
val is_dim : dim -> Base.bool
val is_conv_input : dim -> Base.bool
val var_val : dim -> dim_var Base.option
val dim_val : dim -> solved_dim Base.option
val conv_input_val : dim -> ([ `stride of Base.int ] * [ `output of dim ] * [ `dilation of Base.int ] * [ `kernel of dim ]) Base.option
module Variants_of_dim : sig ... end
val get_dim : d:Base.int -> ?label:Base.string -> ?proj_id:Base.int -> Base.unit -> dim
val dim_to_int_exn : dim -> Base.int
type print_style =
  1. | Only_labels
  2. | Axis_size
  3. | Axis_number_and_size
  4. | Projection_and_size
val equal_print_style : print_style -> print_style -> Base.bool
val compare_print_style : print_style -> print_style -> Base.int
val sexp_of_print_style : print_style -> Sexplib0.Sexp.t
val print_style_of_sexp : Sexplib0.Sexp.t -> print_style
val solved_dim_to_string : print_style -> solved_dim -> Base.string
val dim_to_string : print_style -> dim -> Base.string
type provenance
val sexp_of_provenance : provenance -> Sexplib0.Sexp.t
val provenance_of_sexp : Sexplib0.Sexp.t -> provenance
val compare_provenance : provenance -> provenance -> Base.int
val equal_provenance : provenance -> provenance -> Base.bool
val hash_fold_provenance : Ppx_hash_lib.Std.Hash.state -> provenance -> Ppx_hash_lib.Std.Hash.state
val hash_provenance : provenance -> Ppx_hash_lib.Std.Hash.hash_value
val provenance : sh_id:Base.int -> kind:kind -> provenance
val empty_provenance : provenance
val merge_provenance : provenance -> provenance -> provenance
type row_var
val sexp_of_row_var : row_var -> Sexplib0.Sexp.t
val row_var_of_sexp : Sexplib0.Sexp.t -> row_var
val compare_row_var : row_var -> row_var -> Base.int
val equal_row_var : row_var -> row_var -> Base.bool
val hash_fold_row_var : Ppx_hash_lib.Std.Hash.state -> row_var -> Ppx_hash_lib.Std.Hash.state
val hash_row_var : row_var -> Ppx_hash_lib.Std.Hash.hash_value
val get_row_var : Base.unit -> row_var
type bcast =
  1. | Row_var of {
    1. v : row_var;
    2. beg_dims : dim Base.list;
    }
    (*

    The row can be inferred to have more axes.

    *)
  2. | Broadcastable
    (*

    The shape does not have more axes of this kind, but is "polymorphic".

    *)

A bcast specifies how axes of a single kind in a shape (i.e. the row) can adapt to other shapes.

val equal_bcast : bcast -> bcast -> Base.bool
val hash_fold_bcast : Ppx_hash_lib.Std.Hash.state -> bcast -> Ppx_hash_lib.Std.Hash.state
val hash_bcast : bcast -> Ppx_hash_lib.Std.Hash.hash_value
val compare_bcast : bcast -> bcast -> Base.int
val sexp_of_bcast : bcast -> Sexplib0.Sexp.t
val bcast_of_sexp : Sexplib0.Sexp.t -> bcast
val row_var : v:row_var -> beg_dims:dim Base.list -> bcast
val broadcastable : bcast
val is_row_var : bcast -> Base.bool
val is_broadcastable : bcast -> Base.bool
val row_var_val : bcast -> ([ `v of row_var ] * [ `beg_dims of dim Base.list ]) Base.option
val broadcastable_val : bcast -> Base.unit Base.option
module Variants_of_bcast : sig ... end
type t = {
  1. dims : dim Base.list;
  2. bcast : bcast;
  3. prov : provenance;
}
include Ppx_compare_lib.Equal.S with type t := t
val equal : t Base__Ppx_compare_lib.equal
include Ppx_hash_lib.Hashable.S with type t := t
val hash_fold_t : t Base__Ppx_hash_lib.hash_fold
val hash : t -> Base__Ppx_hash_lib.Std.Hash.hash_value
include Ppx_compare_lib.Comparable.S with type t := t
val compare : t Base__Ppx_compare_lib.compare
include Sexplib0.Sexpable.S with type t := t
val t_of_sexp : Sexplib0__.Sexp.t -> t
val sexp_of_t : t -> Sexplib0__.Sexp.t
val dims_label_assoc : t -> (Base.string * dim) Base.list
val get_row_for_var : provenance -> row_var -> t
val row_shapes : t -> Base.int Base.list
type environment
val sexp_of_environment : environment -> Sexplib0.Sexp.t
type dim_constraint =
  1. | Unconstrained_dim
  2. | At_least_dim of Base.int
val equal_dim_constraint : dim_constraint -> dim_constraint -> Base.bool
val hash_fold_dim_constraint : Ppx_hash_lib.Std.Hash.state -> dim_constraint -> Ppx_hash_lib.Std.Hash.state
val hash_dim_constraint : dim_constraint -> Ppx_hash_lib.Std.Hash.hash_value
val compare_dim_constraint : dim_constraint -> dim_constraint -> Base.int
val sexp_of_dim_constraint : dim_constraint -> Sexplib0.Sexp.t
val unconstrained_dim : dim_constraint
val at_least_dim : Base.int -> dim_constraint
val is_unconstrained_dim : dim_constraint -> Base.bool
val is_at_least_dim : dim_constraint -> Base.bool
val unconstrained_dim_val : dim_constraint -> Base.unit Base.option
val at_least_dim_val : dim_constraint -> Base.int Base.option
module Variants_of_dim_constraint : sig ... end
type total_elems =
  1. | Num_elems of Base.int
  2. | Strided_var of {
    1. coeff : Base.int Utils.safe_lazy;
    2. var : dim_var;
    3. denom : Base.int;
    }
    (*

    The total number of elements is (coeff * var) / denom.

    *)
val equal_total_elems : total_elems -> total_elems -> Base.bool
val hash_fold_total_elems : Ppx_hash_lib.Std.Hash.state -> total_elems -> Ppx_hash_lib.Std.Hash.state
val hash_total_elems : total_elems -> Ppx_hash_lib.Std.Hash.hash_value
val compare_total_elems : total_elems -> total_elems -> Base.int
val sexp_of_total_elems : total_elems -> Sexplib0.Sexp.t
type row_constraint =
  1. | Unconstrained
  2. | Total_elems of {
    1. numerator : total_elems;
    2. divided_by : dim_var Base.list;
    }
    (*

    The rows, inclusive of the further row spec, have this many elements. The total is numerator / (product of divided_by variables). divided_by has multiset semantics - the same variable can appear multiple times.

    *)
  3. | Exact of dim Base.list
    (*

    The concatenated rows have these axes.

    *)
val equal_row_constraint : row_constraint -> row_constraint -> Base.bool
val hash_fold_row_constraint : Ppx_hash_lib.Std.Hash.state -> row_constraint -> Ppx_hash_lib.Std.Hash.state
val hash_row_constraint : row_constraint -> Ppx_hash_lib.Std.Hash.hash_value
val compare_row_constraint : row_constraint -> row_constraint -> Base.int
val sexp_of_row_constraint : row_constraint -> Sexplib0.Sexp.t
type constraint_origin = {
  1. lhs_name : Base.string;
  2. lhs_kind : kind;
  3. rhs_name : Base.string;
  4. rhs_kind : kind;
  5. operation : Base.string Base.option;
}
val sexp_of_constraint_origin : constraint_origin -> Sexplib0.Sexp.t
type dim_entry =
  1. | Solved_dim of dim
  2. | Bounds_dim of {
    1. is_in_param : Base.bool;
    2. cur : dim_var Base.list;
    3. subr : dim_var Base.list;
    4. lub : dim Base.option;
    5. constr : dim_constraint;
    6. origin : constraint_origin Base.list;
    }

An entry implements the inequalities cur >= v >= subr and/or an equality v = solved. The cur and subr lists must be kept sorted according to the comparison generated by [@@deriving compare].

val sexp_of_dim_entry : dim_entry -> Sexplib0.Sexp.t
type row_entry =
  1. | Solved_row of t
  2. | Bounds_row of {
    1. is_in_param : Base.bool;
    2. cur : row_var Base.list;
    3. subr : row_var Base.list;
    4. lub : t Base.option;
    5. constr : row_constraint;
    6. origin : constraint_origin Base.list;
    }
val sexp_of_row_entry : row_entry -> Sexplib0.Sexp.t
type constraint_ =
  1. | Dim_eq of {
    1. d1 : dim;
    2. d2 : dim;
    3. origin : constraint_origin Base.list;
    }
  2. | Row_eq of {
    1. r1 : t;
    2. r2 : t;
    3. origin : constraint_origin Base.list;
    }
  3. | Dim_ineq of {
    1. cur : dim;
    2. subr : dim;
    3. origin : constraint_origin Base.list;
    }
  4. | Row_ineq of {
    1. cur : t;
    2. subr : t;
    3. origin : constraint_origin Base.list;
    }
  5. | Dim_constr of {
    1. d : dim;
    2. constr : dim_constraint;
    3. origin : constraint_origin Base.list;
    }
  6. | Rows_constr of {
    1. r : t Base.list;
    2. constr : row_constraint;
    3. origin : constraint_origin Base.list;
    }
    (*

    The constraint applies to the concatenation of the rows. Note: broadcasting does not affect the constraint (i.e. there is no "subtyping", it resembles Row_eq).

    *)
  7. | Terminal_dim of Base.bool * dim * constraint_origin Base.list
    (*

    A terminal dimension. The bool flag (is_param) indicates whether it is a parameter requiring gradient.

    *)
  8. | Terminal_row of Base.bool * t * constraint_origin Base.list
    (*

    A row of the shape of a terminal tensor (i.e. a tensor that does not have sub-tensors). The bool flag indicates if it's a parameter requiring gradient.

    *)
  9. | Shape_row of t * constraint_origin Base.list
    (*

    A row of a shape of interest. Unlike Terminal_row, this constructor carries no bool parameter flag.

    *)
val compare_constraint_ : constraint_ -> constraint_ -> Base.int
val equal_constraint_ : constraint_ -> constraint_ -> Base.bool
val sexp_of_constraint_ : constraint_ -> Sexplib0.Sexp.t
val dim_eq : d1:dim -> d2:dim -> origin:constraint_origin Base.list -> constraint_
val row_eq : r1:t -> r2:t -> origin:constraint_origin Base.list -> constraint_
val dim_ineq : cur:dim -> subr:dim -> origin:constraint_origin Base.list -> constraint_
val row_ineq : cur:t -> subr:t -> origin:constraint_origin Base.list -> constraint_
val dim_constr : d:dim -> constr:dim_constraint -> origin:constraint_origin Base.list -> constraint_
val rows_constr : r:t Base.list -> constr:row_constraint -> origin:constraint_origin Base.list -> constraint_
val terminal_dim : Base.bool -> dim -> constraint_origin Base.list -> constraint_
val terminal_row : Base.bool -> t -> constraint_origin Base.list -> constraint_
val shape_row : t -> constraint_origin Base.list -> constraint_
val is_dim_eq : constraint_ -> Base.bool
val is_row_eq : constraint_ -> Base.bool
val is_dim_ineq : constraint_ -> Base.bool
val is_row_ineq : constraint_ -> Base.bool
val is_dim_constr : constraint_ -> Base.bool
val is_rows_constr : constraint_ -> Base.bool
val is_terminal_dim : constraint_ -> Base.bool
val is_terminal_row : constraint_ -> Base.bool
val is_shape_row : constraint_ -> Base.bool
val dim_eq_val : constraint_ -> ([ `d1 of dim ] * [ `d2 of dim ] * [ `origin of constraint_origin Base.list ]) Base.option
val row_eq_val : constraint_ -> ([ `r1 of t ] * [ `r2 of t ] * [ `origin of constraint_origin Base.list ]) Base.option
val dim_ineq_val : constraint_ -> ([ `cur of dim ] * [ `subr of dim ] * [ `origin of constraint_origin Base.list ]) Base.option
val row_ineq_val : constraint_ -> ([ `cur of t ] * [ `subr of t ] * [ `origin of constraint_origin Base.list ]) Base.option
val dim_constr_val : constraint_ -> ([ `d of dim ] * [ `constr of dim_constraint ] * [ `origin of constraint_origin Base.list ]) Base.option
val rows_constr_val : constraint_ -> ([ `r of t Base.list ] * [ `constr of row_constraint ] * [ `origin of constraint_origin Base.list ]) Base.option
val terminal_dim_val : constraint_ -> (Base.bool * dim * constraint_origin Base.list) Base.option
val terminal_row_val : constraint_ -> (Base.bool * t * constraint_origin Base.list) Base.option
val shape_row_val : constraint_ -> (t * constraint_origin Base.list) Base.option
module Variants_of_constraint_ : sig ... end
type error_trace = ..
type error_trace +=
  1. | Row_mismatch of t Base.list
  2. | Dim_mismatch of dim Base.list
  3. | Index_mismatch of Ir.Indexing.axis_index Base.list
  4. | Constraint_failed of constraint_
val sexp_of_error_trace : error_trace -> Base.Sexp.t
exception Shape_error of Base.string * error_trace Base.list
type stage =
  1. | Stage1
  2. | Stage2
  3. | Stage3
  4. | Stage4
  5. | Stage5
  6. | Stage6
  7. | Stage7

The solver stage, passed as ~stage to unify_row and solve_inequalities.

val sexp_of_stage : stage -> Sexplib0.Sexp.t
val stage_of_sexp : Sexplib0.Sexp.t -> stage
val equal_stage : stage -> stage -> Base.bool
val compare_stage : stage -> stage -> Base.int
val add_safe_to_guess : row_var -> Base.unit

Mark a row variable as allowed to be empty even if it's in a parameter.

val add_used_in_spec_or_compose : row_var -> Base.unit

Mark a row variable as used in an einsum spec or compose shape update. Meant specifically for input rows, to indicate that the variable should not be guessed empty when it ends up in a parameter.

val add_used_in_pointwise : row_var -> Base.unit

Mark a row variable as used in a pointwise shape update. Meant specifically for input rows, to indicate that the variable can be guessed empty when it ends up in a parameter.

val subst_row : environment -> t -> t
val unify_row : stage:stage -> constraint_origin Base.list -> (t * t) -> environment -> constraint_ Base.list * environment
val empty_env : environment
val get_dim_from_env : environment -> dim_var -> Base.int Base.option
val get_row_from_env : environment -> row_var -> t Base.option
val unsolved_constraints : environment -> constraint_ Base.list
val solve_inequalities : stage:stage -> constraint_ Base.list -> environment -> constraint_ Base.list * environment
val row_to_labels : environment -> t -> Base.string Base.array
type proj
val compare_proj : proj -> proj -> Base.int
val equal_proj : proj -> proj -> Base.bool
val sexp_of_proj : proj -> Sexplib0.Sexp.t
val proj_of_sexp : Sexplib0.Sexp.t -> proj
type proj_env
val sexp_of_proj_env : proj_env -> Sexplib0.Sexp.t
val fresh_row_proj : t -> t
type proj_equation =
  1. | Proj_eq of proj * proj
    (*

    Two projections are the same, e.g. two axes share the same iterator.

    *)
  2. | Iterated of proj
    (*

    The projection needs to be an iterator even if an axis is not matched with another axis, e.g. for broadcasted-to axes of a tensor assigned a constant.

    *)
  3. | Non_iterated of proj
    (*

    The projection is not part of a product space, e.g. for convolution input.

    *)
val compare_proj_equation : proj_equation -> proj_equation -> Base.int
val equal_proj_equation : proj_equation -> proj_equation -> Base.bool
val sexp_of_proj_equation : proj_equation -> Sexplib0.Sexp.t
val proj_equation_of_sexp : Sexplib0.Sexp.t -> proj_equation
val get_proj_equations : constraint_ Base.list -> Ir.Indexing.axis_index dim_map -> environment -> proj_equation Base.list
val solve_proj_equations : proj_equation Base.list -> resolved_padding:(proj_id, axis_padding) Base.List.Assoc.t -> inferred_padding:(proj_id, axis_padding) Base.List.Assoc.t -> proj_env
val get_proj_index : proj_env -> proj -> Ir.Indexing.axis_index
val get_dim_index : proj_env -> dim -> Ir.Indexing.axis_index
val get_product_proj : proj_env -> dim -> (proj_id * Base.int) Base.option
val proj_to_iterator_exn : proj_env -> proj_id -> Ir.Indexing.symbol

proj_to_iterator_exn proj_env p returns the iterator for p in proj_env. Raises Invalid_argument if p is not an iterator.