Arrayjit.Low_level
module Scope_id : sig ... end
val sexp_of_scope_id : scope_id -> Sexplib0.Sexp.t
val hash_fold_scope_id :
Ppx_hash_lib.Std.Hash.state ->
scope_id ->
Ppx_hash_lib.Std.Hash.state
val hash_scope_id : scope_id -> Ppx_hash_lib.Std.Hash.hash_value
type t =
| Noop
| Comment of Base.string
| Staged_compilation of Base.unit -> Base.unit
| Seq of t * t
| For_loop of {
index : Indexing.symbol;
from_ : Base.int;
to_ : Base.int;
body : t;
trace_it : Base.bool;
}
| Zero_out of Tnode.t
| Set of {
tn : Tnode.t;
idcs : Indexing.axis_index Base.array;
llv : float_t;
mutable debug : Base.string;
}
| Set_local of scope_id * float_t
(** Cases: [t] -- code, [float_t] -- a single number at some precision. *)
and float_t =
| Local_scope of {
id : scope_id;
body : t;
orig_indices : Indexing.axis_index Base.array;
}
| Get_local of scope_id
| Get_global of Ops.global_identifier
* Indexing.axis_index Base.array Base.option
| Get of Tnode.t * Indexing.axis_index Base.array
| Ternop of Ops.ternop * float_t * float_t * float_t
| Binop of Ops.binop * float_t * float_t
| Unop of Ops.unop * float_t
| Constant of Base.float
| Embed_index of Indexing.axis_index
val sexp_of_t : t -> Sexplib0.Sexp.t
val sexp_of_float_t : float_t -> Sexplib0.Sexp.t
val loop_over_dims :
Base.int Base.array ->
body:(Indexing.axis_index Base.array -> t) ->
t
val virtualize_settings : virtualize_settings
val sexp_of_visits : visits -> Sexplib0.Sexp.t
val visits_of_sexp : Sexplib0.Sexp.t -> visits
val visits : Base.int -> visits
val recurrent : visits
val is_visits : visits -> Base.bool
val is_recurrent : visits -> Base.bool
val visits_val : visits -> Base.int Base.option
val recurrent_val : visits -> Base.unit Base.option
module Variants_of_visits : sig ... end
type traced_array = {
tn : Tnode.t;
mutable computations : (Indexing.axis_index Base.array Base.option * t)
Base.list;
(** The computations (of the tensor node) are retrieved for optimization just as they are populated, so that the inlined code corresponds precisely to the changes to the arrays that would happen up to that point. Within the code blocks paired with an index tuple, all assignments and accesses must happen via the index tuple; if this is not the case for some assignment, the node cannot be virtual. Currently, we only allow for-loop symbols in assignment indices of virtual nodes.
*)assignments : Base.int Base.array Base.Hash_set.t;
accesses : (Base.int Base.array, visits) Base.Hashtbl.t;
mutable zero_initialized : Base.bool;
mutable zeroed_out : Base.bool;
mutable read_before_write : Base.bool;
(** The node is read before it is written (i.e. it is recurrent).
*)mutable read_only : Base.bool;
mutable is_scalar_constexpr : Base.bool;
(** True only if the tensor node has all axes of dimension 1, is either zeroed out or assigned before it is accessed, is assigned at most once, and is assigned from an expression involving only constants or tensor nodes that were [is_scalar_constexpr] at that time.
*)}
val sexp_of_traced_array : traced_array -> Sexplib0.Sexp.t
val get_node :
(Tnode.t, traced_array) Base.Hashtbl.t ->
Tnode.t ->
traced_array
type traced_store = (Tnode.t, traced_array) Base.Hashtbl.t
val sexp_of_traced_store : traced_store -> Sexplib0.Sexp.t
val sexp_of_optimized : optimized -> Sexplib0.Sexp.t
val optimize :
unoptim_ll_source:Stdlib.Format.formatter Base.option ->
ll_source:Stdlib.Format.formatter Base.option ->
name:Base.string ->
Indexing.static_symbol Base.list ->
t ->
optimized
val input_and_output_nodes :
optimized ->
(Base.Set.M(Arrayjit.Tnode).t * Base.Set.M(Arrayjit.Tnode).t)
* Tnode.t Base.option
(** Inputs are the materialized read-only and read-before-write (within the code) non-constant non-merge nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters. Outputs are all the materialized nodes written to by the code. The last returned component is the input merge node, if used in the code. *)
val fprint_function_header :
?name:Base.string ->
?static_indices:Indexing.static_symbol Base.list ->
Base.unit ->
Stdlib.Format.formatter ->
Base.unit
val fprint_cstyle :
?name:Base.string ->
?static_indices:Indexing.static_symbol Base.list ->
Base.unit ->
Stdlib.Format.formatter ->
t ->
Base.unit
(** Adheres more closely to C syntax; outputs implicit type casts. *)
val fprint_hum :
?name:Base.string ->
?static_indices:Indexing.static_symbol Base.list ->
Base.unit ->
Stdlib.Format.formatter ->
t ->
Base.unit
(** Adheres more closely to the [%cd] syntax; does not output implicit type casts. *)