Module Ocannl.Train

User-facing modules
module Ops = Ir.Ops
module Tn = Ir.Tnode
module Nd = Ir.Ndarray
module Asgns = Ir.Assignments
module Idx = Ir.Indexing
module Task = Ir.Task

module CDSL : sig ... end
module IDX : sig ... end

val run : Context.t -> Context.routine -> Base.unit

val set_on_host : ?from_device:bool -> Tn.t -> unit

val set_materialized : Tn.t -> unit

val set_hosted : Tn.t -> unit

val forward :
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Ir.Assignments.comp

Sets the tensor's value as "fully on host" and returns the tensor's forward code, wrapped with a label-derived comment.
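A minimal sketch of compiling and running a forward pass, assuming a context ctx : Context.t, a tensor t built elsewhere, and that IDX.empty (from the elided IDX signature) denotes empty unit_bindings; to_routine is documented further below:

  (* Compile the tensor's forward code into a routine, then execute it. *)
  let routine = Train.to_routine ctx IDX.empty (Train.forward t) in
  Train.run ctx routine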
val grad_update :
  ?setup_for_parallel:bool ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Ocannl_tensor.Operation.DSL_modules.Ir.Assignments.comp

Returns the tensor's forward code, gradient-zeroing code, and backprop code, wrapped with label-derived comments. Sets the tensor's value as "fully on host". If setup_for_parallel is true (false by default), sets the parameters and their gradients as "non-local" (on-device). See the combined sketch after sgd_update below.
val sgd_one :
  learning_rate:Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  ?momentum:Base.Float.t ->
  ?weight_decay:Base.float ->
  ?nesterov:bool ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Ocannl_tensor.Operation.DSL_modules.Ir.Assignments.comp

See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py
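For reference, the update rule of the linked tinygrad optimizer (a plausible reading; not verified against the code this function generates) is, for a parameter w with gradient g, momentum buffer b, momentum m, weight decay wd, and learning rate lr:

  g <- g + wd * w
  b <- m * b + g
  w <- w - lr * (g + m * b)    if nesterov
  w <- w - lr * b              otherwise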
val sgd_update :
  learning_rate:Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  ?momentum:Base.Float.t ->
  ?weight_decay:Base.float ->
  ?nesterov:bool ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Asgns.comp
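A minimal sketch of one optimization step combining grad_update and sgd_update, assuming a scalar loss tensor, a learning_rate tensor, ctx : Context.t, and IDX.empty as above:

  (* Compile forward + gradient-zeroing + backprop code, and the SGD update. *)
  let grad_routine = Train.to_routine ctx IDX.empty (Train.grad_update loss) in
  let sgd_routine =
    Train.to_routine ctx IDX.empty (Train.sgd_update ~learning_rate loss)
  in
  (* One step: compute gradients, then update the parameters. *)
  Train.run ctx grad_routine;
  Train.run ctx sgd_routine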
val sequential_loop :
  f:(unit -> Base.unit) ->
  (Idx.static_symbol * int Base.ref) list ->
  Base.unit

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
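A minimal sketch, assuming a routine compiled with a range-bound (e.g. batch) index, and assuming Context.routine exposes its lowered bindings as a bindings field compatible with the expected (Idx.static_symbol * int ref) list:

  (* Run the routine once per value of each range-bound index. *)
  Train.sequential_loop routine.bindings ~f:(fun () -> Train.run ctx routine)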
val round_robin :
  (unit -> unit) Base.Array.t ->
  Idx.lowered_bindings Base.Array.t ->
  (Idx.static_symbol * Base.int Base.ref) list ->
  sync:(Base.int -> Base.unit) ->
  Base.unit

Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. sync is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round.
val round_robin_dry_run :
  num_streams:Base.int ->
  (Idx.static_symbol * int Base.ref) list ->
  dry_sync:(Base.int -> Base.unit) ->
  Base.unit
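A minimal sketch: preview how many worker calls each round of round_robin would make, without executing anything; here bindings is assumed to be built the same way as for sequential_loop:

  Train.round_robin_dry_run ~num_streams:4 bindings
    ~dry_sync:(fun called -> Stdio.printf "synced after %d calls\n" called)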
val set_virtual : Tn.t -> unit

val every_non_literal_on_host :
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Base.unit

module Lazy = Utils.Lazy

val to_routine :
  Context.t ->
  ?hosted:bool ->
  Ir.Indexing.unit_bindings ->
  Asgns.comp ->
  Context.routine
val init_params :
  ?reinit_all:bool ->
  ?hosted:bool ->
  Context.t ->
  Ir.Indexing.unit_bindings ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Context.t

init_params initializes the parameters of t, by running their forward code or by copying from the host as appropriate. If reinit_all is true, all parameters are reinitialized; otherwise only the parameters that are not yet in ctx.ctx_arrays are initialized.
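A minimal sketch wrapping initialization and a gradient computation, assuming IDX.empty and a loss tensor as above:

  let train_step ctx loss =
    (* Only parameters missing from ctx.ctx_arrays get initialized here. *)
    let ctx = Train.init_params ctx IDX.empty loss in
    Train.run ctx (Train.to_routine ctx IDX.empty (Train.grad_update loss));
    ctx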
type example_train_result = {
  inputs : Ocannl_tensor.Operation.DSL_modules.Tensor.t;
  outputs : Ocannl_tensor.Operation.DSL_modules.Tensor.t;
  model_result : Ocannl_tensor.Operation.DSL_modules.Tensor.t;
      Do not use model_result for deriving gradients.
  infer_callback : Base.float Base.array -> Base.float Base.array;
      Computes the output for the given input via the model_result tensor. Note: infer_callback is inefficient as it is not batched.
  rev_batch_losses : Base.float Base.list;
  rev_epoch_losses : Base.float Base.list;
  learning_rates : Base.float Base.list;
  used_memory : Base.int;
}

val run_once :
  ?output_cd_file:bool ->
  ?hosted:bool ->
  ?skip_init:Base.bool ->
  ?reinit_all:bool ->
  ?bindings:(Base.unit -> Base.unit) Idx.bindings ->
  f:(Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Asgns.comp) ->
  Context.t ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Context.t

run_once is a wrapper around init_params that additionally runs the code of f t and returns the resulting context. If skip_init is true (false by default), no initialization is performed. If reinit_all is true (false by default), all parameters are reinitialized; otherwise only the parameters that are not yet in ctx.ctx_arrays are initialized.

If output_cd_file is true, the global setting output_debug_files_in_build_directory must be true; the update code is then written to a file before init_params runs, so the file is available even if shape inference crashes there.
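A minimal sketch: with f set to forward, run_once behaves like forward_once below; a custom f compiles and runs arbitrary per-tensor code. ctx and t are assumed as above:

  let run_forward ctx t = Train.run_once ~f:Train.forward ctx t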
Context-based versions of training functions for the new simplified API
val forward_once :
  ?output_cd_file:bool ->
  ?hosted:bool ->
  ?skip_init:Base.bool ->
  ?reinit_all:bool ->
  ?bindings:(Base.unit -> Base.unit) Idx.bindings ->
  Context.t ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Context.t

forward_once is a wrapper around run_once that runs the forward code of t.
val update_once :
  ?output_cd_file:bool ->
  ?hosted:bool ->
  ?skip_init:Base.bool ->
  ?reinit_all:bool ->
  ?bindings:(Base.unit -> Base.unit) Idx.bindings ->
  Context.t ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Context.t

update_once is a wrapper around run_once that runs the gradient update code of t: both forward and backprop.
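A minimal sketch of two consecutive gradient updates, assuming a scalar loss tensor and ctx as above:

  (* The first call initializes parameters, then runs forward + backprop. *)
  let ctx = Train.update_once ctx loss in
  (* Subsequent calls can skip initialization. *)
  let ctx = Train.update_once ~skip_init:true ctx loss in
  ignore ctx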
val printf :
  ?here:Ppx_here_lib.position ->
  ?with_grad:Base.bool ->
  ?with_code:Base.bool ->
  ?with_low_level:Base.bool ->
  ?style:Ocannl_tensor.Operation.DSL_modules.Tensor.array_print_style ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Base.unit

printf is a wrapper around Tensor.print that assumes ~force:true, and by default sets ~with_code:false, ~with_grad:true, and ~style:`Default.
val printf_tree :
  ?here:Ppx_here_lib.position ->
  ?with_value:Base.bool ->
  ?with_grad:Base.bool ->
  ?depth:Base.int ->
  Ocannl_tensor.Operation.DSL_modules.Tensor.t ->
  Base.unit

printf_tree is a wrapper around Tensor.print_tree that assumes ~force:true, and by default sets ~with_value:true, ~with_grad:true, and ~depth:9.
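Typical usage, assuming a tensor loss whose values are already on host (e.g. after forward_once or update_once):

  Train.printf loss;                (* prints values and gradient, no code *)
  Train.printf_tree ~depth:3 loss   (* prints the computation tree, 3 levels deep *)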