Module Ocannl.Train

User-facing modules

module Ops = Ir.Ops
module Tn = Ir.Tnode
module Nd = Ir.Ndarray
module Asgns = Ir.Assignments
module Idx = Ir.Indexing
module Task = Ir.Task
val _get_local_debug_runtime : unit -> (module Minidebug_runtime.Debug_runtime)
module CDSL : sig ... end
module IDX : sig ... end
val run : Context.t -> Context.routine -> Base.unit
val set_on_host : ?from_device:bool -> Tn.t -> unit
val set_materialized : Tn.t -> unit
val set_hosted : Tn.t -> unit
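
For example, to inspect a tensor's value after running, mark its value node as hosted first. A minimal sketch, assuming an existing tensor t whose value node is the record field t.Tensor.value (field name used here as an assumption):

  (* Keep t's value on the host so it can be read back after running. *)
  let () = Train.set_hosted t.Tensor.value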

Sets the tensor's value as "fully on host" and returns the tensor's forward code wrapped with a label-derived comment.

Returns the tensor's forward code, gradient-zeroing code, and backprop code, wrapped with label-derived comments. Sets the tensor's value as "fully on host". If setup_for_parallel is true (false by default), sets the parameters and their gradients as "non-local" (on-device).

val sgd_one : learning_rate:Ocannl_tensor.Operation.DSL_modules.Tensor.t -> ?momentum:Base.Float.t -> ?weight_decay:Base.float -> ?nesterov:bool -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Ocannl_tensor.Operation.DSL_modules.Ir.Assignments.comp

See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py

val sgd_update : learning_rate:Ocannl_tensor.Operation.DSL_modules.Tensor.t -> ?momentum:Base.Float.t -> ?weight_decay:Base.float -> ?nesterov:bool -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Asgns.comp
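
A hedged usage sketch: given an existing scalar loss tensor and a scalar learning-rate tensor lr (both assumptions, not part of this module), build the SGD update code for all parameters of the loss:

  (* Plain SGD with momentum and weight decay; nesterov defaults to false. *)
  let sgd = Train.sgd_update ~learning_rate:lr ~momentum:0.9 ~weight_decay:0.0001 loss
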
val sequential_loop : f:(unit -> Base.unit) -> (Idx.static_symbol * int Base.ref) list -> Base.unit

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
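
A minimal sketch, assuming ctx : Context.t and routine : Context.routine, and that the routine exposes its lowered bindings as routine.bindings (field name assumed):

  (* Run the routine once for every combination of the range-carrying bindings. *)
  let () = Train.sequential_loop ~f:(fun () -> Train.run ctx routine) routine.bindings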

val round_robin : (unit -> unit) Base.Array.t -> Idx.lowered_bindings Base.Array.t -> (Idx.static_symbol * Base.int Base.ref) list -> sync:(Base.int -> Base.unit) -> Base.unit

Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. sync is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round.
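
A hedged sketch with assumed values: workers is an array of per-stream callbacks, per_stream_bindings the corresponding array of lowered bindings, and iterated the bindings to distribute:

  let () =
    Train.round_robin workers per_stream_bindings iterated
      ~sync:(fun n -> Stdlib.Printf.printf "synced after %d workers\n%!" n)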

val round_robin_dry_run : num_streams:Base.int -> (Idx.static_symbol * int Base.ref) list -> dry_sync:(Base.int -> Base.unit) -> Base.unit
val set_virtual : Tn.t -> unit
val every_non_literal_on_host : Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Base.unit
module Lazy = Utils.Lazy
val to_routine : Context.t -> ?hosted:bool -> Ir.Indexing.unit_bindings -> Asgns.comp -> Context.routine
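
A minimal sketch, assuming ctx : Context.t, a computation comp : Asgns.comp, and that IDX.empty provides empty unit bindings:

  (* Compile comp into a routine that can be run repeatedly. *)
  let routine = Train.to_routine ctx IDX.empty comp
  let () = Train.run ctx routine
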
val init_params : ?reinit_all:bool -> ?hosted:bool -> Context.t -> Ir.Indexing.unit_bindings -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Context.t

init_params initializes the parameters of t, by running their forward code or copying from the host, as appropriate. If reinit_all is true, all parameters are reinitialized; otherwise only the parameters that are not in ctx.ctx_arrays are initialized.
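
A minimal sketch, assuming ctx : Context.t, a tensor t, and that IDX.empty provides empty unit bindings:

  (* Initialize only the parameters not yet present in ctx.ctx_arrays. *)
  let ctx = Train.init_params ctx IDX.empty t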

type example_train_result = {
  inputs : Ocannl_tensor.Operation.DSL_modules.Tensor.t;
  outputs : Ocannl_tensor.Operation.DSL_modules.Tensor.t;
  model_result : Ocannl_tensor.Operation.DSL_modules.Tensor.t;
    (* Do not use model_result for deriving gradients. *)
  infer_callback : Base.float Base.array -> Base.float Base.array;
    (* Computes the output for the given input via the model_result tensor. Note: infer_callback is inefficient as it is not batched. *)
  rev_batch_losses : Base.float Base.list;
  rev_epoch_losses : Base.float Base.list;
  learning_rates : Base.float Base.list;
  used_memory : Base.int;
}
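
A hedged sketch of using this record, assuming result : example_train_result was returned by a training helper:

  (* Predict the outputs for a single (unbatched) input sample. *)
  let prediction = result.infer_callback [| 0.5; -1.2 |]
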
val run_once : ?output_cd_file:bool -> ?hosted:bool -> ?skip_init:Base.bool -> ?reinit_all:bool -> ?bindings:(Base.unit -> Base.unit) Idx.bindings -> f:(Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Asgns.comp) -> Context.t -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Context.t

run_once is a wrapper around init_params that additionally runs the code of f t and returns the context. If skip_init is true (false by default), no initialization is performed. If reinit_all is true (false by default), all parameters are reinitialized; otherwise only the parameters that are not in ctx.ctx_arrays are initialized.

If output_cd_file is true, the global setting output_debug_files_in_build_directory must be true, and the update code is written to a file before init_params runs, so that the code is available even if shape inference crashes there.
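
A contrived sketch for brevity, assuming an existing loss tensor and a scalar learning-rate tensor lr: here f builds only the SGD update code for loss's parameters:

  let ctx = Train.run_once ~f:(fun t -> Train.sgd_update ~learning_rate:lr t) ctx loss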

Context-based versions of training functions for the new simplified API

val forward_once : ?output_cd_file:bool -> ?hosted:bool -> ?skip_init:Base.bool -> ?reinit_all:bool -> ?bindings:(Base.unit -> Base.unit) Idx.bindings -> Context.t -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Context.t

forward_once is a wrapper around run_once that runs the forward code of t.
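
A minimal sketch of the simplified API, assuming ctx : Context.t and a tensor t:

  (* Compile and run the forward pass of t once; keep the updated context. *)
  let ctx = Train.forward_once ctx t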

val update_once : ?output_cd_file:bool -> ?hosted:bool -> ?skip_init:Base.bool -> ?reinit_all:bool -> ?bindings:(Base.unit -> Base.unit) Idx.bindings -> Context.t -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Context.t

update_once is a wrapper around run_once that runs the gradient update code of t: both forward and backprop.
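
A minimal sketch, assuming ctx : Context.t and a scalar loss tensor:

  (* One forward plus backprop step over loss. *)
  let ctx = Train.update_once ctx loss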

val printf : ?here:Ppx_here_lib.position -> ?with_grad:Base.bool -> ?with_code:Base.bool -> ?with_low_level:Base.bool -> ?style:Ocannl_tensor.Operation.DSL_modules.Tensor.array_print_style -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Base.unit

printf is a wrapper around Tensor.print that assumes ~force:true, and by default sets ~with_code:false, ~with_grad:true, and ~style:`Default.
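
For example, assuming a tensor loss whose value has already been computed:

  (* Print the value and gradient of loss in the default style. *)
  let () = Train.printf loss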

val printf_tree : ?here:Ppx_here_lib.position -> ?with_value:Base.bool -> ?with_grad:Base.bool -> ?depth:Base.int -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Base.unit

printf_tree is a wrapper around Tensor.print_tree that assumes ~force:true, and by default sets ~with_value:true, ~with_grad:true, and ~depth:9.
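
For example, assuming a tensor loss:

  (* Print loss's computation tree to depth 3, with values and gradients. *)
  let () = Train.printf_tree ~depth:3 loss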