Module Ocannl.Parallel

module Tn = Ir.Tnode
module Nd = Ir.Ndarray
module Asgns = Ir.Assignments
module Idx = Ir.Indexing
module Task = Ir.Task
module Backends = Context.Backends_deprecated
val _get_local_debug_runtime : unit -> (module Minidebug_runtime.Debug_runtime)
type reduction =
  1. | Sum
  2. | Mean

Gradient reduction mode used by data_parallel when all-reducing parameter gradients across data-parallel shards. Sum adds the per-shard gradients; Mean additionally divides that sum by the number of shards (i.e. it averages the per-shard gradients).

val host_get : Tn.t -> Base__Int.t Base.Array.t -> Base.Float.t
val host_values : Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Base.float Base.array

Returns all (unpadded) values of an ndarray-backed (literal) tensor, read from its host initialization buffer. Use it to inspect host-level shard/gather results.

val shard_along : axis:int -> n_shards:Base__Int.t -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Ocannl_tensor.Operation.DSL_modules.Tensor.t Base.array
type handle = {
  1. n_shards : Base.int;
  2. step : Base.unit -> Base.unit;
    (*

    Runs one synchronized optimizer step: the forward and backward passes on every shard, an all-reduce of the parameter gradients across shards via merge-buffer transfer routines, a single optimizer update on the owner shard, and finally a broadcast of the updated parameters back to the other shards.

    *)
  3. grad_sync : Base.unit -> Base.unit;
    (*

    All-reduces the parameter gradients across shards onto the owner shard via merge-buffer transfer routines, using the configured reduction. Call it after the shards' backward passes and before the optimizer update. step already calls it; it is exposed separately for custom training loops.

    *)
  4. set_batch : inputs:Ocannl_tensor.Operation.DSL_modules.Tensor.t -> targets:Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Base.unit;
    (*

    Scatters a fresh logical batch across the shards: it re-shards the given inputs and targets along the batch axis and copies the pieces into the per-shard input buffers. Use it for multi-step training.

    *)
  5. owner_loss_value : Base.unit -> Base.float;
    (*

    The owner shard's scalar loss after the latest step.

    *)
  6. sync_params_to_host : Base.unit -> Base.unit;
    (*

    Waits for the owner shard's stream to finish. It is retained for API compatibility. Since gh-ocannl-333, host values are read on demand via read_values rather than copied to a host array.

    *)
  7. read_values : Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Base.float Base.array;
    (*

    Reads a tensor's current values from the owner shard's context via an on-demand device-to-host transfer. Use for the owner's parameters / loss.

    *)
  8. owner_params : Ocannl_tensor.Operation.DSL_modules.Tensor.t Base.array;
    (*

    The owner shard's parameter tensors (in stable order). Read their values via read_values.

    *)
  9. shard_seeds : Base.int Base.array;
    (*

    The RNG seed assigned to each shard (shard_seeds.(i) = base_seed + i). These are the exact values passed to set_random_seed when building the shards, so pairwise-distinct entries confirm that each shard's RNG is seeded differently.

    *)
}

A handle to a live data-parallel training session. All raw-backend state is captured by these closures: the shared backend module, the per-shard streams/contexts, the parameter replicas, and the compiled per-shard, gradient-sync, optimizer and broadcast routines. There is no hidden global tensor-to-context lookup. Obtain a handle by calling data_parallel, which passes it to the f callback.

val data_parallel : ?backend_name:Base.string -> ?reduction:reduction -> ?weight_decay:Base.float -> ?momentum:Base.Float.t -> ?base_seed:Base__Int.t -> n_shards:Base__Int.t -> bindings:Idx.unit_bindings -> learning_rate:Ocannl_tensor.Operation.DSL_modules.Tensor.t -> inputs:Ocannl_tensor.Operation.DSL_modules.Tensor.t -> targets:Ocannl_tensor.Operation.DSL_modules.Tensor.t -> loss_of: (Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Ocannl_tensor.Operation.DSL_modules.Tensor.t -> Ocannl_tensor.Operation.DSL_modules.Tensor.t) -> f:(handle -> 'a) -> unit -> 'a