Module Ocannl.Train

module Ops = Arrayjit.Ops
module Tn = Arrayjit.Tnode
module Nd = Arrayjit.Ndarray
module NTDSL = Operation.NTDSL
module Asgns = Arrayjit.Assignments
module Idx = Arrayjit.Indexing
module Task = Arrayjit.Task
module Utils = Arrayjit.Utils
module Rand = Arrayjit.Rand.Lib
module type Backend = Arrayjit.Backend_intf.Backend
val _get_local_debug_runtime : unit -> (module Minidebug_runtime.Debug_runtime)
module CDSL : sig ... end
module IDX : sig ... end
val run : 'a BT.routine -> Base.unit
val is_param : Tensor.t -> Base.bool
val get_params : Tensor.t -> (Tensor.t, Tensor.comparator_witness) Base.Set.t
val set_on_host : ?from_device:bool -> Tn.t -> unit
val set_materialized : Tn.t -> unit
val set_hosted : Tn.t -> unit
val forward : ?disable_rootness_check:bool -> Tensor.t -> Arrayjit.Assignments.comp

Sets the tensor's value as "fully on host" and returns the tensor's forward code, wrapped with a label-derived comment.

type updaten = {
  1. loss : Tensor.t;
  2. params : (Tensor.t, Tensor.comparator_witness) Base.Set.t;
  3. fwd_bprop : Asgns.comp;
}
val diff_or_error : Tensor.t -> Base.String.t -> Tensor.diff
val grad_update_nochecks : Tensor.t -> updaten
val grad_update : ?disable_rootness_check:bool -> ?setup_for_parallel:bool -> Tensor.t -> updaten

Returns the tensor's combined forward, gradient-zeroing, and backprop code, wrapped with label-derived comments. Sets the tensor's value as "fully on host". If setup_for_parallel is true (false by default), also sets the parameters and their gradients as "non-local" (on-device).

val sgd_one : learning_rate:Tensor.t -> ?momentum:Base.Float.t -> ?weight_decay:Base.float -> ?nesterov:bool -> Tensor.t -> Arrayjit.Assignments.comp

See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py

val sgd_update : learning_rate:Tensor.t -> ?momentum:Base.Float.t -> ?weight_decay:Base.float -> ?nesterov:bool -> updaten -> Arrayjit.Assignments.comp
val sequential_loop : f:(unit -> Base.unit) -> (Idx.static_symbol * Base.int Base.ref) list -> Base.unit

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.

val round_robin : (unit -> unit) Base.Array.t -> Idx.lowered_bindings Base.Array.t -> (Idx.static_symbol * Base.int Base.ref) list -> sync:(Base.int -> unit) -> Base.unit

Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. The sync callback is called after each round in which all workers have been called, and once more at the end if needed. Each call receives the number of workers called during that round.

val round_robin_dry_run : num_streams:Base__Int.t -> (Idx.static_symbol * Base.int Base.ref) list -> dry_sync:(Base__Int.t -> unit) -> Base.unit
val set_virtual : Tn.t -> unit
val every_non_literal_on_host : Tensor.t -> Base.unit
module Lazy = Utils.Lazy
val parallel_update : (module Backend with type buffer_ptr = 'buffer_ptr and type dev = 'dev and type event = 'event and type runner = 'runner) -> grad_updates: ('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Arrayjit__Backend_intf.stream_ref) Arrayjit__Backend_intf.context BT.routine Base.array -> sgd_update: ('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Arrayjit__Backend_intf.stream_ref) Arrayjit__Backend_intf.context BT.routine -> copy_to_merge:bool -> post_sync:(num_synced_devices:Base.int -> Base.unit) -> updaten -> Base.unit -> Base.unit

Performs one optimization step, potentially in parallel (if grad_updates are linked with different streams or devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of grad_updates in a round-robin fashion, and performs the following synchronization each time all grad_updates have been called:

  • merges all gradients into the device of grad_updates.(0),
  • calls sgd_update,
  • copies all parameters from the grad_updates.(0) device to the other devices, if needed,
  • calls post_sync with the number of devices synced since the previous sync.

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.

val get_all_suggested_streams : ?max_num_streams:Base.int -> (module Backend with type buffer_ptr = 'buffer_ptr and type dev = 'dev and type event = 'event and type runner = 'runner) -> ('buffer_ptr, 'dev, 'runner, 'event) Arrayjit__Backend_intf.device_ref Base.Array.t * ('buffer_ptr, 'dev, 'runner, 'event) Arrayjit__Backend_intf.stream_ref Base.Array.t
val to_routine : (module Backend with type buffer_ptr = 'buffer_ptr and type dev = 'dev and type event = 'event and type runner = 'runner) -> ('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Arrayjit__Backend_intf.stream_ref) Arrayjit__Backend_intf.context -> ?name:Base.string -> Arrayjit.Indexing.unit_bindings -> Arrayjit.Assignments.comp -> ('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Arrayjit__Backend_intf.stream_ref) Arrayjit__Backend_intf.context Arrayjit__Backend_intf.routine
type example_train_result = {
  1. inputs : Tensor.t;
  2. outputs : Tensor.t;
  3. model_result : Tensor.t;
  4. infer_callback : Base.float Base.array -> Base.float Base.array;
    (*

    Note: infer_callback is significantly less efficient than using the model via arrayjit.

    *)
  5. rev_batch_losses : Base.float Base.list;
  6. rev_epoch_losses : Base.float Base.list;
  7. learning_rates : Base.float Base.list;
  8. used_memory : Base.int;
}
val example_train_loop : ?disable_rootness_check:Base.bool -> seed:Base.int -> batch_size:Base__Int.t -> init_lr:Base.float -> ?lr_schedule: (batch_n:Idx.static_symbol -> step_n:Idx.static_symbol -> Tensor.t) -> ?copy_to_merge:bool -> ?max_num_streams:Base.int -> data_len:Base__Int.t -> epochs:Base__Int.t -> inputs:(b:Base__Int.t list -> Tensor.t) -> outputs:(b:Base__Int.t list -> Tensor.t) -> model:(Tensor.t -> Tensor.t) -> loss_fn:(output:Tensor.t -> expectation:Tensor.t -> Tensor.t) -> weight_decay:Base.float -> ?per_batch_callback: (at_batch:Base.int -> at_step:Base.int -> learning_rate:Base.Float.t -> batch_loss:Base.Float.t -> epoch_loss:float -> unit) -> ?per_epoch_callback: (at_step:Base.int -> at_epoch:int -> learning_rate:Base.Float.t -> epoch_loss:float -> unit) -> ?per_epoch_debug_streams:bool -> (module Backend) -> unit -> example_train_result
val forward_and_ctx : ?disable_rootness_check:Base.bool -> (module Backend with type buffer_ptr = 'buffer_ptr and type dev = 'dev and type event = 'event and type runner = 'runner) -> ('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Arrayjit__Backend_intf.stream_ref) Arrayjit__Backend_intf.context -> ?bindings:(Base.unit -> Base.unit) Idx.bindings -> Tensor.t -> ('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Arrayjit__Backend_intf.stream_ref) Arrayjit__Backend_intf.context
val forward_and_forget : ?disable_rootness_check:Base.bool -> (module Backend with type buffer_ptr = 'a and type dev = 'b and type event = 'c and type runner = 'd) -> ('a, ('a, 'b, 'd, 'c) Arrayjit__Backend_intf.stream_ref) Arrayjit__Backend_intf.context -> ?bindings:(Base.unit -> Base.unit) Idx.bindings -> Tensor.t -> Base.unit