Ocannl_tensor.Row: the row type, shape-inference-related types, and constraint solving.
type axis_padding = Ir.Ops.axis_paddingval equal_axis_padding : axis_padding -> axis_padding -> Base.boolval sexp_of_axis_padding : axis_padding -> Sexplib0.Sexp.tval axis_padding_of_sexp : Sexplib0.Sexp.t -> axis_paddingval sexp_of_kind : kind -> Sexplib0.Sexp.tval kind_of_sexp : Sexplib0.Sexp.t -> kindval __kind_of_sexp__ : Sexplib0.Sexp.t -> kindval hash_fold_kind :
Ppx_hash_lib.Std.Hash.state ->
kind ->
Ppx_hash_lib.Std.Hash.stateval hash_kind : kind -> Ppx_hash_lib.Std.Hash.hash_valueval batch : kindval input : kindval output : kindval is_batch : kind -> Base.boolval is_input : kind -> Base.boolval is_output : kind -> Base.boolval batch_val : kind -> Base.unit Base.optionval input_val : kind -> Base.unit Base.optionval output_val : kind -> Base.unit Base.optionmodule Variants_of_kind : sig ... endval hash_fold_dim_var :
Ppx_hash_lib.Std.Hash.state ->
dim_var ->
Ppx_hash_lib.Std.Hash.stateval hash_dim_var : dim_var -> Ppx_hash_lib.Std.Hash.hash_valueval sexp_of_dim_var : dim_var -> Sexplib0.Sexp.tval dim_var_of_sexp : Sexplib0.Sexp.t -> dim_varval hash_fold_proj_id :
Ppx_hash_lib.Std.Hash.state ->
proj_id ->
Ppx_hash_lib.Std.Hash.stateval hash_proj_id : proj_id -> Ppx_hash_lib.Std.Hash.hash_valueval sexp_of_proj_id : proj_id -> Sexplib0.Sexp.tval proj_id_of_sexp : Sexplib0.Sexp.t -> proj_idval equal_dim_var_set : dim_var_set -> dim_var_set -> Base.boolval sexp_of_dim_var_set : dim_var_set -> Sexplib0.Sexp.tval dim_var_set_of_sexp : Sexplib0.Sexp.t -> dim_var_setval sexp_of_dim_map : ('a -> Sexplib0.Sexp.t) -> 'a dim_map -> Sexplib0.Sexp.tval dim_map_of_sexp : (Sexplib0.Sexp.t -> 'a) -> Sexplib0.Sexp.t -> 'a dim_mapval equal_proj_var_set : proj_var_set -> proj_var_set -> Base.boolval sexp_of_proj_var_set : proj_var_set -> Sexplib0.Sexp.tval proj_var_set_of_sexp : Sexplib0.Sexp.t -> proj_var_setval sexp_of_proj_map :
('a -> Sexplib0.Sexp.t) ->
'a proj_map ->
Sexplib0.Sexp.tval proj_map_of_sexp :
(Sexplib0.Sexp.t -> 'a) ->
Sexplib0.Sexp.t ->
'a proj_mapval get_var : ?label:Base.string -> Base.unit -> dim_varval dim_var_set_empty : dim_var_setval dim_map_empty : 'a dim_mapval proj_var_set_empty : proj_var_setval proj_map_empty : 'a proj_mapA single axis in a shape. proj_id is used for projection inference, and abused for provenance tracking during shape inference.
val equal_solved_dim : solved_dim -> solved_dim -> Base.boolval hash_fold_solved_dim :
Ppx_hash_lib.Std.Hash.state ->
solved_dim ->
Ppx_hash_lib.Std.Hash.stateval hash_solved_dim : solved_dim -> Ppx_hash_lib.Std.Hash.hash_valueval compare_solved_dim : solved_dim -> solved_dim -> Base.intval sexp_of_solved_dim : solved_dim -> Sexplib0.Sexp.tval solved_dim_of_sexp : Sexplib0.Sexp.t -> solved_dimtype dim = | Var of dim_var| Dim of solved_dim| Conv_input of {}The offset is implicit and derived automatically: if !use_padding is true, it is the left part of the dimensionality-preserving symmetric padding; otherwise it is 0.
val hash_fold_dim :
Ppx_hash_lib.Std.Hash.state ->
dim ->
Ppx_hash_lib.Std.Hash.stateval hash_dim : dim -> Ppx_hash_lib.Std.Hash.hash_valueval sexp_of_dim : dim -> Sexplib0.Sexp.tval dim_of_sexp : Sexplib0.Sexp.t -> dimval dim : solved_dim -> dimval is_var : dim -> Base.boolval is_dim : dim -> Base.boolval is_conv_input : dim -> Base.boolval dim_val : dim -> solved_dim Base.optionmodule Variants_of_dim : sig ... endval get_dim :
d:Base.int ->
?label:Base.string ->
?proj_id:Base.int ->
Base.unit ->
dimval dim_to_int_exn : dim -> Base.intval equal_print_style : print_style -> print_style -> Base.boolval compare_print_style : print_style -> print_style -> Base.intval sexp_of_print_style : print_style -> Sexplib0.Sexp.tval print_style_of_sexp : Sexplib0.Sexp.t -> print_styleval solved_dim_to_string : print_style -> solved_dim -> Base.stringval dim_to_string : print_style -> dim -> Base.stringval sexp_of_provenance : provenance -> Sexplib0.Sexp.tval provenance_of_sexp : Sexplib0.Sexp.t -> provenanceval compare_provenance : provenance -> provenance -> Base.intval equal_provenance : provenance -> provenance -> Base.boolval hash_fold_provenance :
Ppx_hash_lib.Std.Hash.state ->
provenance ->
Ppx_hash_lib.Std.Hash.stateval hash_provenance : provenance -> Ppx_hash_lib.Std.Hash.hash_valueval provenance : sh_id:Base.int -> kind:kind -> provenanceval empty_provenance : provenanceval merge_provenance : provenance -> provenance -> provenanceval sexp_of_row_var : row_var -> Sexplib0.Sexp.tval row_var_of_sexp : Sexplib0.Sexp.t -> row_varval hash_fold_row_var :
Ppx_hash_lib.Std.Hash.state ->
row_var ->
Ppx_hash_lib.Std.Hash.stateval hash_row_var : row_var -> Ppx_hash_lib.Std.Hash.hash_valueval get_row_var : Base.unit -> row_varA bcast specifies how axes of a single kind in a shape (i.e. the row) can adapt to other shapes.
val hash_fold_bcast :
Ppx_hash_lib.Std.Hash.state ->
bcast ->
Ppx_hash_lib.Std.Hash.stateval hash_bcast : bcast -> Ppx_hash_lib.Std.Hash.hash_valueval sexp_of_bcast : bcast -> Sexplib0.Sexp.tval bcast_of_sexp : Sexplib0.Sexp.t -> bcastval broadcastable : bcastval is_row_var : bcast -> Base.boolval is_broadcastable : bcast -> Base.boolval broadcastable_val : bcast -> Base.unit Base.optionmodule Variants_of_bcast : sig ... endinclude Ppx_compare_lib.Equal.S with type t := tval equal : t Base__Ppx_compare_lib.equalinclude Ppx_compare_lib.Comparable.S with type t := tval compare : t Base__Ppx_compare_lib.compareval get_row_for_var : provenance -> row_var -> tval row_shapes : t -> Base.int Base.listval sexp_of_environment : environment -> Sexplib0.Sexp.tval equal_dim_constraint : dim_constraint -> dim_constraint -> Base.boolval hash_fold_dim_constraint :
Ppx_hash_lib.Std.Hash.state ->
dim_constraint ->
Ppx_hash_lib.Std.Hash.stateval hash_dim_constraint : dim_constraint -> Ppx_hash_lib.Std.Hash.hash_valueval compare_dim_constraint : dim_constraint -> dim_constraint -> Base.intval sexp_of_dim_constraint : dim_constraint -> Sexplib0.Sexp.tval unconstrained_dim : dim_constraintval at_least_dim : Base.int -> dim_constraintval is_unconstrained_dim : dim_constraint -> Base.boolval is_at_least_dim : dim_constraint -> Base.boolval unconstrained_dim_val : dim_constraint -> Base.unit Base.optionval at_least_dim_val : dim_constraint -> Base.int Base.optionmodule Variants_of_dim_constraint : sig ... endtype total_elems = | Num_elems of Base.int| Strided_var of {coeff : Base.int Utils.safe_lazy;var : dim_var;denom : Base.int;}The total number of elements is (coefficient * variable) / denominator.
*)val equal_total_elems : total_elems -> total_elems -> Base.boolval hash_fold_total_elems :
Ppx_hash_lib.Std.Hash.state ->
total_elems ->
Ppx_hash_lib.Std.Hash.stateval hash_total_elems : total_elems -> Ppx_hash_lib.Std.Hash.hash_valueval compare_total_elems : total_elems -> total_elems -> Base.intval sexp_of_total_elems : total_elems -> Sexplib0.Sexp.ttype row_constraint = | Unconstrained| Total_elems of {numerator : total_elems;divided_by : dim_var Base.list;}The rows, inclusive of the further row spec, have this many elements. The total is numerator / (product of the divided_by variables). divided_by has multiset semantics: the same variable can appear multiple times, and each occurrence contributes its own factor to the divisor.
*)| Exact of dim Base.listThe concatenated rows have these axes.
*)val equal_row_constraint : row_constraint -> row_constraint -> Base.boolval hash_fold_row_constraint :
Ppx_hash_lib.Std.Hash.state ->
row_constraint ->
Ppx_hash_lib.Std.Hash.stateval hash_row_constraint : row_constraint -> Ppx_hash_lib.Std.Hash.hash_valueval compare_row_constraint : row_constraint -> row_constraint -> Base.intval sexp_of_row_constraint : row_constraint -> Sexplib0.Sexp.tval sexp_of_constraint_origin : constraint_origin -> Sexplib0.Sexp.ttype dim_entry = | Solved_dim of dim| Bounds_dim of {is_in_param : Base.bool;has_uniq_constr_unless : dim_var_set Base.option;If set, the variable should not be guessed 1 unless a variable from the set is also prevented from being guessed 1.
*)cur : dim_var Base.list;subr : dim_var Base.list;lub : dim Base.option;constr : dim_constraint;origin : constraint_origin Base.list;}An entry implements inequalities cur >= v >= subr and/or an equality v = solved. cur and subr must be sorted using the @@deriving compare comparison.
val sexp_of_dim_entry : dim_entry -> Sexplib0.Sexp.ttype row_entry = | Solved_row of t| Bounds_row of {is_in_param : Base.bool;cur : row_var Base.list;subr : row_var Base.list;lub : t Base.option;constr : row_constraint;origin : constraint_origin Base.list;}val sexp_of_row_entry : row_entry -> Sexplib0.Sexp.ttype constraint_ = | Dim_eq of {d1 : dim;d2 : dim;origin : constraint_origin Base.list;}| Row_eq of {r1 : t;r2 : t;origin : constraint_origin Base.list;}| Dim_ineq of {cur : dim;subr : dim;origin : constraint_origin Base.list;}| Row_ineq of {cur : t;subr : t;origin : constraint_origin Base.list;}| Dim_constr of {d : dim;constr : dim_constraint;origin : constraint_origin Base.list;}| Rows_constr of {r : t Base.list;constr : row_constraint;origin : constraint_origin Base.list;}The constraint applies to the concatenation of the rows. Note: broadcasting does not affect the constraint (i.e. there is no "subtyping"; in this respect it behaves like Row_eq rather than Row_ineq).
*)| Terminal_dim of Base.bool * dim * constraint_origin Base.listA terminal dimension with is_param flag indicating if it's a parameter requiring gradient.
*)| Terminal_row of Base.bool * t * constraint_origin Base.listA row of the shape of a terminal tensor (i.e. a tensor that does not have sub-tensors). The bool flag indicates if it's a parameter requiring gradient.
*)| Shape_row of t * constraint_origin Base.listA row of a shape of interest. Unlike Terminal_row, it carries no flag indicating whether it is a parameter requiring gradient.
*)val compare_constraint_ : constraint_ -> constraint_ -> Base.intval equal_constraint_ : constraint_ -> constraint_ -> Base.boolval sexp_of_constraint_ : constraint_ -> Sexplib0.Sexp.tval dim_eq :
d1:dim ->
d2:dim ->
origin:constraint_origin Base.list ->
constraint_val row_eq : r1:t -> r2:t -> origin:constraint_origin Base.list -> constraint_val dim_ineq :
cur:dim ->
subr:dim ->
origin:constraint_origin Base.list ->
constraint_val row_ineq :
cur:t ->
subr:t ->
origin:constraint_origin Base.list ->
constraint_val dim_constr :
d:dim ->
constr:dim_constraint ->
origin:constraint_origin Base.list ->
constraint_val rows_constr :
r:t Base.list ->
constr:row_constraint ->
origin:constraint_origin Base.list ->
constraint_val terminal_dim :
Base.bool ->
dim ->
constraint_origin Base.list ->
constraint_val terminal_row : Base.bool -> t -> constraint_origin Base.list -> constraint_val shape_row : t -> constraint_origin Base.list -> constraint_val is_dim_eq : constraint_ -> Base.boolval is_row_eq : constraint_ -> Base.boolval is_dim_ineq : constraint_ -> Base.boolval is_row_ineq : constraint_ -> Base.boolval is_dim_constr : constraint_ -> Base.boolval is_rows_constr : constraint_ -> Base.boolval is_terminal_dim : constraint_ -> Base.boolval is_terminal_row : constraint_ -> Base.boolval is_shape_row : constraint_ -> Base.boolval dim_eq_val :
constraint_ ->
([ `d1 of dim ] * [ `d2 of dim ] * [ `origin of constraint_origin Base.list ])
Base.optionval row_eq_val :
constraint_ ->
([ `r1 of t ] * [ `r2 of t ] * [ `origin of constraint_origin Base.list ])
Base.optionval dim_ineq_val :
constraint_ ->
([ `cur of dim ]
* [ `subr of dim ]
* [ `origin of constraint_origin Base.list ])
Base.optionval row_ineq_val :
constraint_ ->
([ `cur of t ] * [ `subr of t ] * [ `origin of constraint_origin Base.list ])
Base.optionval dim_constr_val :
constraint_ ->
([ `d of dim ]
* [ `constr of dim_constraint ]
* [ `origin of constraint_origin Base.list ])
Base.optionval rows_constr_val :
constraint_ ->
([ `r of t Base.list ]
* [ `constr of row_constraint ]
* [ `origin of constraint_origin Base.list ])
Base.optionval terminal_dim_val :
constraint_ ->
(Base.bool * dim * constraint_origin Base.list) Base.optionval terminal_row_val :
constraint_ ->
(Base.bool * t * constraint_origin Base.list) Base.optionval shape_row_val :
constraint_ ->
(t * constraint_origin Base.list) Base.optionmodule Variants_of_constraint_ : sig ... endtype error_trace += | Row_mismatch of t Base.list| Dim_mismatch of dim Base.list| Index_mismatch of Ir.Indexing.axis_index Base.list| Constraint_failed of constraint_val sexp_of_error_trace : error_trace -> Base.Sexp.texception Shape_error of Base.string * error_trace Base.listRemoved duplicate helper functions
val sexp_of_stage : stage -> Sexplib0.Sexp.tval stage_of_sexp : Sexplib0.Sexp.t -> stageval add_safe_to_guess : row_var -> Base.unitMark a row variable as allowed to be empty even if it's in a parameter.
val add_used_in_spec_or_compose : row_var -> Base.unitMark a row variable as used in an einsum spec or compose shape update. Meant specifically for input rows, to indicate that the variable should not be guessed empty when it ends up in a parameter.
val add_used_in_pointwise : row_var -> Base.unitMark a row variable as used in a pointwise shape update. Meant specifically for input rows, to indicate that the variable can be guessed empty when it ends up in a parameter.
val subst_row : environment -> t -> tval unify_row :
stage:stage ->
constraint_origin Base.list ->
(t * t) ->
environment ->
constraint_ Base.list * environmentval empty_env : environmentval get_dim_val : environment -> dim_var -> Base.int Base.optionval get_row_from_env : environment -> row_var -> t Base.optionval unsolved_constraints : environment -> constraint_ Base.listval solve_inequalities :
stage:stage ->
constraint_ Base.list ->
environment ->
constraint_ Base.list * environmentval row_to_labels : environment -> t -> Base.string Base.arrayval sexp_of_proj : proj -> Sexplib0.Sexp.tval proj_of_sexp : Sexplib0.Sexp.t -> projval sexp_of_proj_env : proj_env -> Sexplib0.Sexp.ttype proj_equation = | Proj_eq of proj * projTwo projections are the same, e.g. two axes share the same iterator.
*)| Iterated of projThe projection needs to be an iterator even if an axis is not matched with another axis, e.g. for broadcasted-to axes of a tensor assigned a constant.
*)| Non_iterated of projThe projection is not part of a product space, e.g. for convolution input.
*)val compare_proj_equation : proj_equation -> proj_equation -> Base.intval equal_proj_equation : proj_equation -> proj_equation -> Base.boolval sexp_of_proj_equation : proj_equation -> Sexplib0.Sexp.tval proj_equation_of_sexp : Sexplib0.Sexp.t -> proj_equationval get_proj_equations :
constraint_ Base.list ->
Ir.Indexing.axis_index dim_map ->
environment ->
proj_equation Base.listval solve_proj_equations :
proj_equation Base.list ->
resolved_padding:(proj_id, axis_padding) Base.List.Assoc.t ->
inferred_padding:(proj_id, axis_padding) Base.List.Assoc.t ->
proj_envval get_proj_index : proj_env -> proj -> Ir.Indexing.axis_indexval get_dim_index : proj_env -> dim -> Ir.Indexing.axis_indexval proj_to_iterator_exn : proj_env -> proj_id -> Ir.Indexing.symbolproj_to_iterator_exn proj_env p returns the iterator for p in proj_env. Raises Invalid_argument if p is not an iterator.