Ocannl.Train
module Ops = Ir.Ops
module Tn = Ir.Tnode
module Nd = Ir.Ndarray
module Asgns = Ocannl.Operation.DSL_modules.Ir.Assignments
module Idx = Ocannl.Operation.DSL_modules.Ir.Indexing
module Task = Ocannl.Operation.DSL_modules.Ir.Task
module BT = Ocannl.Operation.DSL_modules.Ir.Backend_intf
module type Backend = Ocannl.Operation.DSL_modules.Ir.Backend_intf.Backend
module CDSL : sig ... end
module IDX : sig ... end
val run : 'a BT.routine -> Base.unit
val set_on_host : ?from_device:bool -> Tn.t -> unit
val set_materialized : Tn.t -> unit
val set_hosted : Tn.t -> unit
val forward : Ocannl.Operation.DSL_modules.Tensor.t -> Ir.Assignments.comp
Sets the tensor's value as "fully on host", returns the tensor's forward code with a label-derived comment.
val grad_update :
?setup_for_parallel:bool ->
Ocannl.Operation.DSL_modules.Tensor.t ->
Ocannl.Operation.DSL_modules.Ir.Assignments.comp
Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived comments. Sets the tensor's value as "fully on host". If setup_for_parallel is true (false by default), sets the parameters and their gradients as "non-local" (on-device).
val sgd_one :
learning_rate:Ocannl.Operation.DSL_modules.Tensor.t ->
?momentum:Base.Float.t ->
?weight_decay:Base.float ->
?nesterov:bool ->
Ocannl.Operation.DSL_modules.Tensor.t ->
Ocannl.Operation.DSL_modules.Ir.Assignments.comp
See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py
val sgd_update :
learning_rate:Ocannl.Operation.DSL_modules.Tensor.t ->
?momentum:Base.Float.t ->
?weight_decay:Base.float ->
?nesterov:bool ->
Ocannl.Operation.DSL_modules.Tensor.t ->
Asgns.comp
val sequential_loop :
f:(unit -> Base.unit) ->
(Idx.static_symbol * int Base.ref) list ->
Base.unit
All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
val round_robin :
(unit -> unit) Base.Array.t ->
Idx.lowered_bindings Base.Array.t ->
(Idx.static_symbol * Base.int Base.ref) list ->
sync:(Base.int -> Base.unit) ->
Base.unit
Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. sync is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round.
val round_robin_dry_run :
num_streams:Base__Int.t ->
(Idx.static_symbol * int Base.ref) list ->
dry_sync:(Base__Int.t -> Base.unit) ->
Base.unit
val set_virtual : Tn.t -> unit
val every_non_literal_on_host :
Ocannl.Operation.DSL_modules.Tensor.t ->
Base.unit
module Lazy = Utils.Lazy
val parallel_update :
(module Backend
with type buffer_ptr = 'buffer_ptr
and type dev = 'dev
and type event = 'event
and type optimize_ctx = 'optimize_ctx
and type runner = 'runner) ->
grad_updates:
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context
BT.routine
Base.array ->
sgd_update:
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context
BT.routine ->
copy_to_merge:bool ->
post_sync:(num_synced_devices:Base.int -> Base.unit) ->
Ocannl.Operation.DSL_modules.Tensor.t ->
unit ->
Base.unit
Performs one optimization step, potentially in parallel (if grad_updates are linked with different streams or devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of grad_updates in a round-robin fashion, and performs the following synchronization each time all grad_updates have been called: merges all gradients into the context of grad_updates.(0); runs sgd_update; copies the updated parameters from the grad_updates.(0) device to the other devices, if needed; and calls post_sync with the number of devices synced since the previous sync. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
val get_all_suggested_streams :
?max_num_streams:Base.int ->
(module Backend
with type buffer_ptr = 'buffer_ptr
and type dev = 'dev
and type event = 'event
and type runner = 'runner) ->
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.device_ref Base.Array.t
* ('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref
Base.Array.t
val to_routine :
(module Backend
with type buffer_ptr = 'buffer_ptr
and type dev = 'dev
and type event = 'event
and type optimize_ctx = 'optimize_ctx
and type runner = 'runner) ->
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context ->
?hosted:bool ->
?name:Base.string ->
Ir.Indexing.unit_bindings ->
Asgns.comp ->
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context
Ir__Backend_intf.routine
val init_params :
(module Backend
with type buffer_ptr = 'buffer_ptr
and type dev = 'dev
and type event = 'event
and type optimize_ctx = 'optimize_ctx
and type runner = 'runner) ->
?ctx:
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context ->
?reinit_all:bool ->
?hosted:bool ->
?name:Base.string ->
Ir.Indexing.unit_bindings ->
Ocannl.Operation.DSL_modules.Tensor.t ->
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context
init_params initializes the parameters of t, via running their forward code or copying from the host as appropriate. If reinit_all is true, all parameters are reinitialized; otherwise only the parameters that are not in ctx.ctx_arrays are initialized.
type example_train_result = {
inputs : Ocannl.Operation.DSL_modules.Tensor.t;
outputs : Ocannl.Operation.DSL_modules.Tensor.t;
model_result : Ocannl.Operation.DSL_modules.Tensor.t;
Do not use model_result for deriving gradients.
infer_callback : Base.float Base.array -> Base.float Base.array;
Computes the output for the given input via the model_result tensor. Note: infer_callback is inefficient as it is not batched.
rev_batch_losses : Base.float Base.list;
rev_epoch_losses : Base.float Base.list;
learning_rates : Base.float Base.list;
used_memory : Base.int;
}
val example_train_loop :
seed:Base.int ->
batch_size:Base__Int.t ->
init_lr:Base.float ->
?lr_schedule:
(batch_n:Idx.static_symbol -> step_n:Idx.static_symbol -> Tensor.t) ->
?copy_to_merge:bool ->
?max_num_streams:Base.int ->
data_len:Base__Int.t ->
epochs:Base__Int.t ->
inputs:Tensor.t ->
outputs:Tensor.t ->
model:(Tensor.t -> Ocannl.Operation.DSL_modules.Tensor.t) ->
loss_fn:
(output:Ocannl.Operation.DSL_modules.Tensor.t ->
expectation:Tensor.t ->
Tensor.t) ->
weight_decay:Base.float ->
?per_batch_callback:
(at_batch:Base.int ->
at_step:Base.int ->
learning_rate:Base.Float.t ->
batch_loss:Base.Float.t ->
epoch_loss:float ->
unit) ->
?per_epoch_callback:
(at_step:Base.int ->
at_epoch:int ->
learning_rate:Base.Float.t ->
epoch_loss:float ->
unit) ->
?per_epoch_debug_streams:bool ->
(module Backend) ->
unit ->
example_train_result
val run_once :
?hosted:bool ->
?skip_init:Base.bool ->
?reinit_all:bool ->
(module Backend
with type buffer_ptr = 'buffer_ptr
and type dev = 'dev
and type event = 'event
and type optimize_ctx = 'optimize_ctx
and type runner = 'runner) ->
?ctx:
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context ->
?bindings:(Base.unit -> Base.unit) Idx.bindings ->
f:(Ocannl.Operation.DSL_modules.Tensor.t -> Ir.Assignments.comp) ->
Ocannl.Operation.DSL_modules.Tensor.t ->
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context
run_once is a wrapper around init_params that additionally runs the code of f t and returns the context. If skip_init is true (false by default), no initialization is performed. If reinit_all is true (false by default), all parameters are reinitialized; otherwise only the parameters that are not in ctx.ctx_arrays are initialized.
val forward_once :
?hosted:bool ->
?skip_init:Base.bool ->
?reinit_all:bool ->
(module Backend
with type buffer_ptr = 'buffer_ptr
and type dev = 'dev
and type event = 'event
and type optimize_ctx = 'optimize_ctx
and type runner = 'runner) ->
?ctx:
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context ->
?bindings:(Base.unit -> Base.unit) Idx.bindings ->
Ocannl.Operation.DSL_modules.Tensor.t ->
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context
forward_once is a wrapper around run_once that runs the forward code of t.
val update_once :
?hosted:bool ->
?skip_init:Base.bool ->
?reinit_all:bool ->
(module Backend
with type buffer_ptr = 'buffer_ptr
and type dev = 'dev
and type event = 'event
and type optimize_ctx = 'optimize_ctx
and type runner = 'runner) ->
?ctx:
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context ->
?bindings:(Base.unit -> Base.unit) Idx.bindings ->
Ocannl.Operation.DSL_modules.Tensor.t ->
('buffer_ptr,
('buffer_ptr, 'dev, 'runner, 'event) Ir__Backend_intf.stream_ref,
'optimize_ctx)
Ir__Backend_intf.context
update_once is a wrapper around run_once that runs the gradient update code of t: both forward and backprop.
val printf :
?here:Ppx_here_lib.position ->
?with_grad:Base.bool ->
?with_code:Base.bool ->
?with_low_level:Base.bool ->
?style:Ocannl.Operation.DSL_modules.Tensor.array_print_style ->
Ocannl.Operation.DSL_modules.Tensor.t ->
Base.unit
printf is a wrapper around Tensor.print that assumes ~force:true, and by default sets ~with_code:false, ~with_grad:true, and ~style:`Default.
val printf_tree :
?here:Ppx_here_lib.position ->
?with_value:Base.bool ->
?with_grad:Base.bool ->
?depth:Base.int ->
Ocannl.Operation.DSL_modules.Tensor.t ->
Base.unit
printf_tree is a wrapper around Tensor.print_tree that assumes ~force:true, and by default sets ~with_value:true, ~with_grad:true, and ~depth:9.