Ocannl.Row
The row type, shape inference related types and constraint solving.
type axis_padding = Ir.Ops.axis_padding
val equal_axis_padding : axis_padding -> axis_padding -> Base.bool
val sexp_of_axis_padding : axis_padding -> Sexplib0.Sexp.t
val axis_padding_of_sexp : Sexplib0.Sexp.t -> axis_padding
val sexp_of_kind : kind -> Sexplib0.Sexp.t
val kind_of_sexp : Sexplib0.Sexp.t -> kind
val __kind_of_sexp__ : Sexplib0.Sexp.t -> kind
val hash_fold_kind :
Ppx_hash_lib.Std.Hash.state ->
kind ->
Ppx_hash_lib.Std.Hash.state
val hash_kind : kind -> Ppx_hash_lib.Std.Hash.hash_value
val batch : kind
val input : kind
val output : kind
val is_batch : kind -> Base.bool
val is_input : kind -> Base.bool
val is_output : kind -> Base.bool
val batch_val : kind -> Base.unit Base.option
val input_val : kind -> Base.unit Base.option
val output_val : kind -> Base.unit Base.option
module Variants_of_kind : sig ... end
val hash_fold_dim_var :
Ppx_hash_lib.Std.Hash.state ->
dim_var ->
Ppx_hash_lib.Std.Hash.state
val hash_dim_var : dim_var -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_dim_var : dim_var -> Sexplib0.Sexp.t
val dim_var_of_sexp : Sexplib0.Sexp.t -> dim_var
val hash_fold_proj_id :
Ppx_hash_lib.Std.Hash.state ->
proj_id ->
Ppx_hash_lib.Std.Hash.state
val hash_proj_id : proj_id -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_proj_id : proj_id -> Sexplib0.Sexp.t
val proj_id_of_sexp : Sexplib0.Sexp.t -> proj_id
val equal_dim_var_set : dim_var_set -> dim_var_set -> Base.bool
val sexp_of_dim_var_set : dim_var_set -> Sexplib0.Sexp.t
val dim_var_set_of_sexp : Sexplib0.Sexp.t -> dim_var_set
val sexp_of_dim_map : ('a -> Sexplib0.Sexp.t) -> 'a dim_map -> Sexplib0.Sexp.t
val dim_map_of_sexp : (Sexplib0.Sexp.t -> 'a) -> Sexplib0.Sexp.t -> 'a dim_map
val equal_proj_var_set : proj_var_set -> proj_var_set -> Base.bool
val sexp_of_proj_var_set : proj_var_set -> Sexplib0.Sexp.t
val proj_var_set_of_sexp : Sexplib0.Sexp.t -> proj_var_set
val sexp_of_proj_map :
('a -> Sexplib0.Sexp.t) ->
'a proj_map ->
Sexplib0.Sexp.t
val proj_map_of_sexp :
(Sexplib0.Sexp.t -> 'a) ->
Sexplib0.Sexp.t ->
'a proj_map
val get_var : ?label:Base.string -> Base.unit -> dim_var
val dim_var_set_empty : dim_var_set
val dim_map_empty : 'a dim_map
val proj_var_set_empty : proj_var_set
val proj_map_empty : 'a proj_map
A single axis in a shape.
val equal_solved_dim : solved_dim -> solved_dim -> Base.bool
val hash_fold_solved_dim :
Ppx_hash_lib.Std.Hash.state ->
solved_dim ->
Ppx_hash_lib.Std.Hash.state
val hash_solved_dim : solved_dim -> Ppx_hash_lib.Std.Hash.hash_value
val compare_solved_dim : solved_dim -> solved_dim -> Base.int
val sexp_of_solved_dim : solved_dim -> Sexplib0.Sexp.t
val solved_dim_of_sexp : Sexplib0.Sexp.t -> solved_dim
type dim =
| Var of dim_var
| Dim of solved_dim
| Conv_input of {
}
The offset is implicit, automatically derived. If [!use_padding] is [true], the offset is the left part of the dimensionality-preserving symmetric padding, otherwise it is 0.
val hash_fold_dim :
Ppx_hash_lib.Std.Hash.state ->
dim ->
Ppx_hash_lib.Std.Hash.state
val hash_dim : dim -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_dim : dim -> Sexplib0.Sexp.t
val dim_of_sexp : Sexplib0.Sexp.t -> dim
val dim : solved_dim -> dim
val is_var : dim -> Base.bool
val is_dim : dim -> Base.bool
val is_conv_input : dim -> Base.bool
val dim_val : dim -> solved_dim Base.option
module Variants_of_dim : sig ... end
val get_dim : d:Base.int -> ?label:Base.string -> Base.unit -> dim
val dim_to_int_exn : dim -> Base.int
val equal_print_style : print_style -> print_style -> Base.bool
val compare_print_style : print_style -> print_style -> Base.int
val sexp_of_print_style : print_style -> Sexplib0.Sexp.t
val print_style_of_sexp : Sexplib0.Sexp.t -> print_style
val solved_dim_to_string : print_style -> solved_dim -> Base.string
val dim_to_string : print_style -> dim -> Base.string
val sexp_of_row_id : row_id -> Sexplib0.Sexp.t
val row_id_of_sexp : Sexplib0.Sexp.t -> row_id
val hash_fold_row_id :
Ppx_hash_lib.Std.Hash.state ->
row_id ->
Ppx_hash_lib.Std.Hash.state
val hash_row_id : row_id -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_row_var : row_var -> Sexplib0.Sexp.t
val row_var_of_sexp : Sexplib0.Sexp.t -> row_var
val hash_fold_row_var :
Ppx_hash_lib.Std.Hash.state ->
row_var ->
Ppx_hash_lib.Std.Hash.state
val hash_row_var : row_var -> Ppx_hash_lib.Std.Hash.hash_value
val get_row_var : Base.unit -> row_var
A bcast specifies how axes of a single kind in a shape (i.e. the row) can adapt to other shapes.
val hash_fold_bcast :
Ppx_hash_lib.Std.Hash.state ->
bcast ->
Ppx_hash_lib.Std.Hash.state
val hash_bcast : bcast -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_bcast : bcast -> Sexplib0.Sexp.t
val bcast_of_sexp : Sexplib0.Sexp.t -> bcast
val broadcastable : bcast
val is_row_var : bcast -> Base.bool
val is_broadcastable : bcast -> Base.bool
val broadcastable_val : bcast -> Base.unit Base.option
module Variants_of_bcast : sig ... end
include Ppx_compare_lib.Equal.S with type t := t
val equal : t Base__Ppx_compare_lib.equal
include Ppx_compare_lib.Comparable.S with type t := t
val compare : t Base__Ppx_compare_lib.compare
val sexp_of_environment : environment -> Sexplib0.Sexp.t
type error_trace +=
| Row_mismatch of t Base.list
| Dim_mismatch of dim Base.list
| Index_mismatch of Ir.Indexing.axis_index Base.list
val sexp_of_error_trace : error_trace -> Base.Sexp.t
exception Shape_error of Base.string * error_trace Base.list
val equal_dim_constraint : dim_constraint -> dim_constraint -> Base.bool
val hash_fold_dim_constraint :
Ppx_hash_lib.Std.Hash.state ->
dim_constraint ->
Ppx_hash_lib.Std.Hash.state
val hash_dim_constraint : dim_constraint -> Ppx_hash_lib.Std.Hash.hash_value
val compare_dim_constraint : dim_constraint -> dim_constraint -> Base.int
val sexp_of_dim_constraint : dim_constraint -> Sexplib0.Sexp.t
val unconstrained_dim : dim_constraint
val at_least_dim : Base.int -> dim_constraint
val is_unconstrained_dim : dim_constraint -> Base.bool
val is_at_least_dim : dim_constraint -> Base.bool
val unconstrained_dim_val : dim_constraint -> Base.unit Base.option
val at_least_dim_val : dim_constraint -> Base.int Base.option
module Variants_of_dim_constraint : sig ... end
type total_elems =
| Num_elems of Base.int
| Strided_var of {
coeff : Base.int Utils.safe_lazy;
var : dim_var;
denom : Base.int;
}
The total number of elements is (coefficient * variable) / denominator.
val equal_total_elems : total_elems -> total_elems -> Base.bool
val hash_fold_total_elems :
Ppx_hash_lib.Std.Hash.state ->
total_elems ->
Ppx_hash_lib.Std.Hash.state
val hash_total_elems : total_elems -> Ppx_hash_lib.Std.Hash.hash_value
val compare_total_elems : total_elems -> total_elems -> Base.int
val sexp_of_total_elems : total_elems -> Sexplib0.Sexp.t
type row_constraint =
| Unconstrained
| Total_elems of {
numerator : total_elems;
divided_by : dim_var Base.list;
}
The rows, inclusive of the further row spec, have this many elements. The total is numerator / (product of divided_by variables). divided_by has multiset semantics: the same variable can appear multiple times.
| Exact of dim Base.list
The concatenated rows have these axes.
val equal_row_constraint : row_constraint -> row_constraint -> Base.bool
val hash_fold_row_constraint :
Ppx_hash_lib.Std.Hash.state ->
row_constraint ->
Ppx_hash_lib.Std.Hash.state
val hash_row_constraint : row_constraint -> Ppx_hash_lib.Std.Hash.hash_value
val compare_row_constraint : row_constraint -> row_constraint -> Base.int
val sexp_of_row_constraint : row_constraint -> Sexplib0.Sexp.t
type dim_entry =
| Solved_dim of dim
| Bounds_dim of {
cur : dim_var Base.list;
subr : dim_var Base.list;
lub : dim Base.option;
constr : dim_constraint;
}
An entry implements inequalities [cur >= v >= subr] and/or an equality [v = solved]. [cur] and [subr] must be sorted using the [@@deriving compare] comparison.
val sexp_of_dim_entry : dim_entry -> Sexplib0.Sexp.t
type row_entry =
| Solved_row of t
| Bounds_row of {
cur : row_var Base.list;
subr : row_var Base.list;
lub : t Base.option;
constr : row_constraint;
}
val sexp_of_row_entry : row_entry -> Sexplib0.Sexp.t
type constraint_ =
| Dim_eq of {
}
| Row_eq of {
}
| Dim_ineq of {
}
| Row_ineq of {
}
| Dim_constr of {
d : dim;
constr : dim_constraint;
}
| Rows_constr of {
r : t Base.list;
constr : row_constraint;
}
The constraint applies to the concatenation of the rows. Note: broadcasting does not affect the constraint (i.e. there is no "subtyping", it resembles Row_eq).
| Terminal_dim of dim
| Terminal_row of t
A row of the shape of a terminal tensor (i.e. a tensor that does not have sub-tensors).
| Shape_row of t
A row of a shape of interest.
val compare_constraint_ : constraint_ -> constraint_ -> Base.int
val equal_constraint_ : constraint_ -> constraint_ -> Base.bool
val sexp_of_constraint_ : constraint_ -> Sexplib0.Sexp.t
val dim_eq : d1:dim -> d2:dim -> constraint_
val row_eq : r1:t -> r2:t -> constraint_
val dim_ineq : cur:dim -> subr:dim -> constraint_
val row_ineq : cur:t -> subr:t -> constraint_
val dim_constr : d:dim -> constr:dim_constraint -> constraint_
val rows_constr : r:t Base.list -> constr:row_constraint -> constraint_
val terminal_dim : dim -> constraint_
val terminal_row : t -> constraint_
val shape_row : t -> constraint_
val is_dim_eq : constraint_ -> Base.bool
val is_row_eq : constraint_ -> Base.bool
val is_dim_ineq : constraint_ -> Base.bool
val is_row_ineq : constraint_ -> Base.bool
val is_dim_constr : constraint_ -> Base.bool
val is_rows_constr : constraint_ -> Base.bool
val is_terminal_dim : constraint_ -> Base.bool
val is_terminal_row : constraint_ -> Base.bool
val is_shape_row : constraint_ -> Base.bool
val dim_eq_val : constraint_ -> ([ `d1 of dim ] * [ `d2 of dim ]) Base.option
val row_eq_val : constraint_ -> ([ `r1 of t ] * [ `r2 of t ]) Base.option
val dim_ineq_val :
constraint_ ->
([ `cur of dim ] * [ `subr of dim ]) Base.option
val row_ineq_val : constraint_ -> ([ `cur of t ] * [ `subr of t ]) Base.option
val dim_constr_val :
constraint_ ->
([ `d of dim ] * [ `constr of dim_constraint ]) Base.option
val rows_constr_val :
constraint_ ->
([ `r of t Base.list ] * [ `constr of row_constraint ]) Base.option
val terminal_dim_val : constraint_ -> dim Base.option
val terminal_row_val : constraint_ -> t Base.option
val shape_row_val : constraint_ -> t Base.option
module Variants_of_constraint_ : sig ... end
val sexp_of_stage : stage -> Sexplib0.Sexp.t
val stage_of_sexp : Sexplib0.Sexp.t -> stage
val subst_row : environment -> t -> t
val unify_row :
stage:stage ->
(t * t) ->
environment ->
constraint_ Base.list * environment
val empty_env : environment
val get_dim_from_env : environment -> dim_var -> Base.int Base.option
val get_row_from_env : environment -> row_var -> t Base.option
val solve_inequalities :
stage:stage ->
constraint_ Base.list ->
environment ->
constraint_ Base.list * environment
val row_to_labels : environment -> t -> Base.string Base.array
val sexp_of_proj : proj -> Sexplib0.Sexp.t
val proj_of_sexp : Sexplib0.Sexp.t -> proj
val sexp_of_proj_env : proj_env -> Sexplib0.Sexp.t
type proj_equation =
| Proj_eq of proj * proj
Two projections are the same, e.g. two axes share the same iterator.
| Iterated of proj
The projection needs to be an iterator even if an axis is not matched with another axis, e.g. for broadcasted-to axes of a tensor assigned a constant.
| Non_iterated of proj
The projection is not part of a product space, e.g. for convolution input.
val compare_proj_equation : proj_equation -> proj_equation -> Base.int
val equal_proj_equation : proj_equation -> proj_equation -> Base.bool
val sexp_of_proj_equation : proj_equation -> Sexplib0.Sexp.t
val proj_equation_of_sexp : Sexplib0.Sexp.t -> proj_equation
val get_proj_equations :
constraint_ Base.list ->
Ir.Indexing.axis_index dim_map ->
environment ->
proj_equation Base.list
val solve_proj_equations :
proj_equation Base.list ->
resolved_padding:(proj_id, axis_padding) Base.List.Assoc.t ->
inferred_padding:(proj_id, axis_padding) Base.List.Assoc.t ->
proj_env
val get_proj_index : proj_env -> proj -> Ir.Indexing.axis_index
val get_dim_index : proj_env -> dim -> Ir.Indexing.axis_index
val proj_to_iterator_exn : proj_env -> proj_id -> Ir.Indexing.symbol
[proj_to_iterator_exn proj_env p] returns the iterator for [p] in [proj_env]. Raises [Invalid_argument] if [p] is not an iterator.