Module Arrayjit.Tnode

module Lazy = Utils.Lazy
module Nd = Ndarray
val _get_local_debug_runtime : unit -> (module Minidebug_runtime.Debug_runtime)
type sharing =
  1. | Unset
    (*

    One of: Per_stream, Shared_cross_streams.

    *)
  2. | Per_stream
    (*

    The tensor node has separate arrays for each stream.

    *)
  3. | Shared_cross_streams
    (*

    The tensor node has a single array per device, and that array can appear in multiple contexts. There is one exception: on backends where `Option.is_some use_host_memory` holds, a node whose memory mode is already `Hosted (Changed_on_devices Shared_cross_streams)` before it is first linked on a device has only the on-host array. In that case, the on-host array is registered in the context, to avoid misleading behavior from `device_to_device`.

    *)

A possible algorithm for deciding sharing within a single device:

  • If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a cross-stream sharing candidate.
  • If a cross-stream sharing candidate is read-only for another context, whose parent does not have the corresponding array (i.e. it is a different stream), it is recorded as cross-stream shared, and the same array is reused.
  • If a tensor node is writable by a context and it is not cross-stream shared, it is marked as non-cross-stream, and its array is removed from the cross-stream sharing candidates if present. If it is cross-stream shared, it is recorded as owned by the corresponding stream; it is an error if the node was already owned by a different stream.

If a tensor node is shared cross-stream, within-device copying is a no-op, since the source and destination pointers are then identical.

val sharing_of_sexp : Sexplib0.Sexp.t -> sharing
val sexp_of_sharing : sharing -> Sexplib0.Sexp.t
val compare_sharing : sharing -> sharing -> Base.int
val equal_sharing : sharing -> sharing -> Base.bool
type memory_type =
  1. | Constant
    (*

    The tensor node does not change after initialization.

    *)
  2. | Nonconstant
    (*

    One of: Changed_on_devices, Volatile.

    *)
  3. | Changed_on_devices of sharing
    (*

    On the host, the tensor node only changes via a `to_host` call.

    *)
  4. | Volatile
    (*

    On any device, the tensor node only changes via a `from_host` call, possibly followed by `device_to_device`.

    *)
val memory_type_of_sexp : Sexplib0.Sexp.t -> memory_type
val sexp_of_memory_type : memory_type -> Sexplib0.Sexp.t
val compare_memory_type : memory_type -> memory_type -> Base.int
val equal_memory_type : memory_type -> memory_type -> Base.bool
type memory_mode =
  1. | Effectively_constant
    (*

    Either Hosted Constant, or a subset of Virtual.

    *)
  2. | Virtual
    (*

    The tensor node's computations are inlined on a per-scalar basis.

    *)
  3. | Never_virtual
    (*

    One of: Local, On_device, Hosted.

    *)
  4. | Local
    (*

    The full tensor node is cached for the duration of a computation but not persisted across calls to compiled functions. It is not available for merging across devices.

    *)
  5. | Device_only
    (*

    One of: Local, On_device.

    *)
  6. | On_device of sharing
    (*

    The tensor node is stored on the devices that compute with it and persisted across function calls. It is available for merging across devices (for devices that support merging / P2P), but not (directly) for visualization or storing to disk.

    *)
  7. | Materialized
    (*

    One of: On_device, Hosted.

    *)
  8. | Hosted of memory_type
    (*

    The tensor node is stored in globally addressable memory, in addition to the devices it is computed on (or, for some backends, only on the host and not on the device). It is available for all operations, and it is visible to OCaml programs as an Ndarray (the optional array of t).

    *)
val memory_mode_of_sexp : Sexplib0.Sexp.t -> memory_mode
val sexp_of_memory_mode : memory_mode -> Sexplib0.Sexp.t
val compare_memory_mode : memory_mode -> memory_mode -> Base.int
val equal_memory_mode : memory_mode -> memory_mode -> Base.bool
type delayed_prec =
  1. | Not_specified
  2. | Default_spec of Ops.prec Lazy.t
  3. | Specified of Ops.prec
val delayed_prec_of_sexp : Sexplib0.Sexp.t -> delayed_prec
val sexp_of_delayed_prec : delayed_prec -> Sexplib0.Sexp.t
val equal_delayed_prec : delayed_prec -> delayed_prec -> Base.bool
type prepare = {
  1. is_done : Base.unit -> Base.bool;
  2. sync : Base.unit -> Base.unit;
  3. transfer : Base.unit -> Base.unit;
}
val sexp_of_prepare : prepare -> Sexplib0.Sexp.t
type t = {
  1. array : Nd.t Base.option Lazy.t;
  2. prec : Ops.prec Lazy.t;
  3. dims : Base.int Base.array Lazy.t;
  4. size_in_bytes : Base.int Lazy.t;
  5. id : Base.int;
  6. label : Base.string Base.list;
    (*

    Display information. Preferably, the last element of the list is the narrowest or most alphanumeric one, e.g. an identifier.

    *)
  7. mutable delayed_prec_unsafe : delayed_prec;
    (*

    Participates in the computation of prec.

    *)
  8. mutable memory_mode : (memory_mode * Base.int) Base.option;
  9. mutable backend_info : Base.Sexp.t;
  10. mutable code_name : Base.string Base.option;
  11. mutable prepare_read : prepare Base.option;
  12. mutable prepare_write : prepare Base.option;
  13. mutable host_read_by_devices : Base.Hash_set.M(Base.Int).t;
    (*

    The unique ids of devices that read the most recent modification of the host array.

    *)
}
val sexp_of_t : t -> Sexplib0.Sexp.t
val compare : t -> t -> Base.int
val num_elems : t -> Base.int
val id : t -> Base.String.t
val label : t -> Base.String.t
val is_alphanum_ : Base.String.t -> bool
val get_debug_name : ?code_name:Base.String.t -> id:Base__Int.t -> label:Base.String.t Base.List.t -> unit -> Base.String.t
val prepare : is_done:(Base.unit -> Base.bool) -> sync:(Base.unit -> Base.unit) -> transfer:(Base.unit -> Base.unit) -> prepare option -> prepare
val prepare_read : is_done:(Base.unit -> Base.bool) -> sync:(Base.unit -> Base.unit) -> transfer:(Base.unit -> Base.unit) -> t -> unit
val prepare_write : is_done:(Base.unit -> Base.bool) -> sync:(Base.unit -> Base.unit) -> t -> unit
val debug_name : t -> Base.String.t
val debug_memory_mode : (memory_mode * Base.Int.t) option -> Base.String.t
val log_debug_info : from_log_level:int -> t -> unit
val default_to_most_local : t -> Base.int -> unit

The one exception to "most local" is that the sharing property is kept at Unset.

val is_virtual_force : t -> Base.int -> bool
val is_hosted_force : t -> Base.int -> bool
val is_materialized_force : t -> Base.int -> bool
val is_in_context_force : use_host_memory:'a Base.option -> t -> Base.int -> Base.bool
val known_not_materialized : t -> bool
val known_constant : t -> bool
val known_volatile : t -> bool
val known_non_virtual : t -> bool
val known_not_param : t -> bool
val known_shared_cross_streams : t -> bool
val known_non_cross_stream : t -> bool
val potentially_cross_stream : t -> Base.bool
val mode_is_unspecified : t -> bool
val update_memory_mode : t -> memory_mode -> Base.int -> unit
val update_memory_sharing : t -> sharing -> Base.int -> unit

update_memory_sharing tn sharing provenance preserves the memory mode of tn while updating the cross-stream sharing property, except that Hosted Nonconstant is further specialized to Hosted (Changed_on_devices sharing).

val update_prec : ?only_if:(Ops.prec -> bool) -> t -> Ops.prec -> unit
val exceeds_fp16_cutoff : t -> Base.Float.t -> Base.bool
include sig ... end
type comparator_witness
val comparator : (t, comparator_witness) Base__Comparator.comparator
val equal : t -> t -> Base.bool
val hash : t -> Base__Ppx_hash_lib.Std.Hash.hash_value
val hash_fold_t : Base__.Ppx_hash_lib.Std.Hash.state -> t -> Base__.Ppx_hash_lib.Std.Hash.state
val hash_t : t -> Base__Ppx_hash_lib.Std.Hash.hash_value
module Comp : sig ... end
type t_set = Base.Set.M(Comp).t
val sexp_of_t_set : (t, 'a) Base.Set.t -> Sexplib0.Sexp.t
val get_exn : t -> Nd.t
val has : t -> bool
val dims_to_string : ?with_axis_numbers:bool -> t -> Base.String.t
val no_grad_ident_label : t -> bool * Base.String.t option
val styled_ident : repeating_nograd_idents:(Base.String.t, 'a) Base.Hashtbl.t -> repeating_grad_idents:(Base.String.t, 'a) Base.Hashtbl.t -> [< `Heuristic_ocannl of [< `Dot_grad | `Under_grad ] | `Name_and_label | `Name_only ] -> t -> Base.String.t
val update_code_name : t -> Base.string -> unit
val get_style : ?arg_name:Base.string -> ?no_dots:bool -> unit -> [> `Heuristic_ocannl of [> `Dot_grad | `Under_grad ] | `Name_and_label | `Name_only ]
val header : t -> string
module Registry : sig ... end
val registry : Registry.t
val create : ?default_prec:Ops.prec Lazy.t -> id:Base__Int.t -> label:Base.String.t Base.List.t -> dims:Base.int Base.array Lazy.t -> Ops.init_op -> t
val find : id:Base.int -> Registry.data option

Accessors

val do_read : t -> unit
val do_write : t -> unit
val points_1d : ?from_axis:Base__Int.t -> xdim:int -> t -> Base.Float.t Base.Array.t
val points_2d : ?from_axis:Base__Int.t -> xdim:int -> ydim:int -> t -> (Base.Float.t * Base.Float.t) Base.Array.t
val set_value : t -> int array -> Base.float -> unit
val get_value : t -> int array -> Base.Float.t
val set_values : t -> Base.float Base.array -> unit
val get_values : t -> Base.Float.t Base.Array.t
val print_accessible_headers : unit -> Base.unit
val log_accessible_headers : unit -> unit