Arrayjit.Indexing
val symbol_of_sexp : Sexplib0.Sexp.t -> symbol
val sexp_of_symbol : symbol -> Sexplib0.Sexp.t
val hash_fold_symbol :
Ppx_hash_lib.Std.Hash.state ->
symbol ->
Ppx_hash_lib.Std.Hash.state
val hash_symbol : symbol -> Ppx_hash_lib.Std.Hash.hash_value
val symbol : Base.int -> symbol
val is_symbol : symbol -> bool
val symbol_val : symbol -> Base.int Stdlib.Option.t
module Variants_of_symbol : sig ... end
val get_symbol : unit -> symbol
module CompareSymbol : sig ... end
module Symbol : sig ... end
val symbol_ident : symbol -> Base.String.t
val environment_of_sexp :
'a. (Sexplib0.Sexp.t -> 'a) ->
Sexplib0.Sexp.t ->
'a environment
val sexp_of_environment :
'a. ('a -> Sexplib0.Sexp.t) ->
'a environment ->
Sexplib0.Sexp.t
val empty_env : 'a environment
val compare_static_symbol : static_symbol -> static_symbol -> Base.int
val equal_static_symbol : static_symbol -> static_symbol -> Base.bool
val static_symbol_of_sexp : Sexplib0.Sexp.t -> static_symbol
val sexp_of_static_symbol : static_symbol -> Sexplib0.Sexp.t
val hash_fold_static_symbol :
Ppx_hash_lib.Std.Hash.state ->
static_symbol ->
Ppx_hash_lib.Std.Hash.state
val hash_static_symbol : static_symbol -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_bindings :
'a. ('a -> Sexplib0.Sexp.t) ->
'a bindings ->
Sexplib0.Sexp.t
val bound_symbols : 'a bindings -> static_symbol Base.List.t
type ('r, 'idcs, 'p1, 'p2) variadic =
| Result of 'r
| Param_idx of Base.int Base.ref
* (Base.int -> 'r, Base.int -> 'idcs, 'p1, 'p2) variadic
| Param_1 of 'p1 Base.option Base.ref * ('p1 -> 'r, 'idcs, 'p1, 'p2) variadic
| Param_2 of 'p2 Base.option Base.ref * ('p2 -> 'r, 'idcs, 'p1, 'p2) variadic
| Param_2f : ('p2f ->
'p2)
* 'p2f Base.option Base.ref
* ('p2 -> 'r, 'idcs, 'p1, 'p2) variadic -> ('r, 'idcs, 'p1, 'p2) variadic
Helps with lowering the bindings.
type unit_bindings = (Base.unit -> Base.unit) bindings
val sexp_of_unit_bindings : unit_bindings -> Sexplib0.Sexp.t
type lowered_bindings = (static_symbol, Base.int Base.ref) Base.List.Assoc.t
val sexp_of_lowered_bindings : lowered_bindings -> Sexplib0.Sexp.t
val apply : 'r 'idcs 'p1 'p2. ('r, 'idcs, 'p1, 'p2) variadic -> 'r
[apply run_variadic ()] applies the parameters in the reverse of the order in which they appear in the [run_variadic] list.
val lowered_bindings :
'a bindings ->
('b, 'a, 'p1, 'p2) variadic ->
(static_symbol * Base.int Base.ref) Base.List.t
val find_exn : lowered_bindings -> static_symbol -> Base.int Base.ref
val get_static_symbol :
?static_range:Base.int ->
(Base.int -> 'a) bindings ->
static_symbol * 'a bindings
Dimensions to string, "x"-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2. Outputs "-" for empty dimensions.
type axis_index =
| Fixed_idx of Base.int
The specific position along an axis.
| Iterator of symbol
The given member of the [product_space] corresponding to some [product_iterators].
val compare_axis_index : axis_index -> axis_index -> Base.int
val equal_axis_index : axis_index -> axis_index -> Base.bool
val axis_index_of_sexp : Sexplib0.Sexp.t -> axis_index
val sexp_of_axis_index : axis_index -> Sexplib0.Sexp.t
val fixed_idx : Base.int -> axis_index
val iterator : symbol -> axis_index
val is_fixed_idx : axis_index -> bool
val is_iterator : axis_index -> bool
val fixed_idx_val : axis_index -> Base.int Stdlib.Option.t
val iterator_val : axis_index -> symbol Stdlib.Option.t
module Variants_of_axis_index : sig ... end
type str_osym_map =
(Base.string, symbol Base.option, Base.String.comparator_witness) Base.Map.t
val sexp_of_str_osym_map : str_osym_map -> Base.Sexp.t
val projections_debug_of_sexp : Sexplib0.Sexp.t -> projections_debug
val sexp_of_projections_debug : projections_debug -> Sexplib0.Sexp.t
type projections = {
product_space : Base.int Base.array;
The product space dimensions that an operation should parallelize (map-reduce) over.
lhs_dims : Base.int Base.array;
The dimensions of the LHS array.
rhs_dims : Base.int Base.array Base.array;
The dimensions of the RHS arrays, needed for deriving projections from other projections.
product_iterators : symbol Base.array;
The product space iterators (concatenation of the relevant batch, output, input axes) for iterating over the [product_space] axes, where same axes are at same array indices.
project_lhs : axis_index Base.array;
A projection that takes a [product_space]-bound index and produces an index into the result of an operation.
project_rhs : axis_index Base.array Base.array;
[project_rhs.(i)] produces an index into the [i+1]th argument of an operation.
debug_info : projections_debug;
}
All the information relevant for code generation.
val compare_projections : projections -> projections -> Base.int
val equal_projections : projections -> projections -> Base.bool
val projections_of_sexp : Sexplib0.Sexp.t -> projections
val sexp_of_projections : projections -> Sexplib0.Sexp.t
val opt_symbol : int -> symbol option
val opt_iterator : symbol option -> axis_index
val is_bijective : projections -> bool
val identity_projections :
?debug_info:projections_debug ->
?derived_for:string ->
lhs_dims:Base.int Base.Array.t ->
unit ->
projections
Projections for a pointwise unary operator. Provide only one of [debug_info] or [derived_for].
val derive_index :
product_syms:Symbol.t Base.Array.t ->
projection:axis_index Base.array ->
product:axis_index Base.Array.t ->
axis_index Base.Array.t
module Pp_helpers : sig ... end