Module Arrayjit.Backend_intf

The interface types for backends

User-facing backend API.

type 'buffer_ptr buffer = {
  1. ptr : 'buffer_ptr;
  2. size_in_bytes : Base.int;
}
val sexp_of_buffer : 'buffer_ptr. ('buffer_ptr -> Sexplib0.Sexp.t) -> 'buffer_ptr buffer -> Sexplib0.Sexp.t
type 'buffer_ptr ctx_arrays = 'buffer_ptr Base.Map.M(Arrayjit.Tnode).t
val sexp_of_ctx_arrays : 'buffer_ptr. ('buffer_ptr -> Sexplib0.Sexp.t) -> 'buffer_ptr ctx_arrays -> Sexplib0.Sexp.t
module Buffer_types (Buffer_ptr : sig ... end) : sig ... end
module type Buffer = sig ... end
module type Alloc_buffer = sig ... end
type config =
  1. | Only_devices_parallel
  2. | For_parallel_copying
  3. | Most_parallel_streams

For now, we only configure a backend with regard to how many streams it should suggest using (where applicable).

val equal_config : config -> config -> Base.bool
val config_of_sexp : Sexplib0.Sexp.t -> config
val sexp_of_config : config -> Sexplib0.Sexp.t
val only_devices_parallel : config
val for_parallel_copying : config
val most_parallel_streams : config
val is_only_devices_parallel : config -> bool
val is_for_parallel_copying : config -> bool
val is_most_parallel_streams : config -> bool
val only_devices_parallel_val : config -> unit Stdlib.Option.t
val for_parallel_copying_val : config -> unit Stdlib.Option.t
val most_parallel_streams_val : config -> unit Stdlib.Option.t
module Variants_of_config : sig ... end
type merge_buffer_use =
  1. | No
  2. | Streaming_for of Task.t
  3. | Copy
val sexp_of_merge_buffer_use : merge_buffer_use -> Sexplib0.Sexp.t
type param_source =
  1. | Log_file_name
  2. | Merge_buffer
  3. | Param_ptr of Tnode.t
  4. | Static_idx of Indexing.static_symbol
val sexp_of_param_source : param_source -> Sexplib0.Sexp.t
type 'context routine = {
  1. context : 'context;
  2. schedule : Task.t;
  3. bindings : Indexing.lowered_bindings;
  4. name : Base.string;
  5. inputs : Base.Set.M(Arrayjit.Tnode).t;
    (*

    The materialized read-only and read-before-write (within the routine) non-constant nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters.

    *)
  6. merge_buffer_input : Tnode.t Base.option;
    (*

    Similar to inputs, for the merge buffer.

    *)
  7. outputs : Base.Set.M(Arrayjit.Tnode).t;
    (*

    All the materialized nodes written-to by the routine.

    *)
}
val sexp_of_routine : 'context. ('context -> Sexplib0.Sexp.t) -> 'context routine -> Sexplib0.Sexp.t
module type Device_config = sig ... end
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
  1. dev : 'dev;
  2. ordinal : Base.int;
  3. device_id : Base.int;
  4. cross_stream_candidates : 'buffer_ptr Base.Hashtbl.M(Arrayjit.Tnode).t;
  5. owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Base.Hashtbl.M(Arrayjit.Tnode).t;
  6. shared_writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) Base.list Base.Hashtbl.M(Arrayjit.Tnode).t;
  7. host_reading_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) Base.list Base.Hashtbl.M(Arrayjit.Tnode).t;
  8. host_writing_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) Base.list Base.Hashtbl.M(Arrayjit.Tnode).t;
  9. mutable streams : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Utils.weak_dynarray;
}
and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
  1. device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
  2. runner : 'runner;
  3. merge_buffer : 'buffer_ptr buffer Base.option Base.ref;
  4. stream_id : Base.int;
  5. mutable allocated_buffer : 'buffer_ptr buffer Base.option;
  6. updating_for : 'event Base.Hashtbl.M(Arrayjit.Tnode).t;
  7. mutable updating_for_merge_buffer : (Tnode.t * 'event Base.option) Base.option;
  8. reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) Base.list Base.Hashtbl.M(Arrayjit.Tnode).t;
}
val sexp_of_device_ref : 'a -> 'b -> 'c -> 'd -> ('e, 'f, 'g, 'h) device_ref -> Sexplib0.Sexp.t
val sexp_of_stream_ref : 'a -> 'b -> 'c -> 'd -> ('e, 'f, 'g, 'h) stream_ref -> Sexplib0.Sexp.t
val equal_stream_ref : ('a, 'b, 'c, 'd) stream_ref -> ('e, 'f, 'g, 'h) stream_ref -> Base.bool
type ('buffer_ptr, 'dev, 'runner, 'event) device = ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
  1. dev : 'dev;
  2. ordinal : Base.int;
    (*

    The number of the represented backend's device, in the range from 0 to the number of the backend's devices - 1.

    *)
  3. device_id : Base.int;
    (*

    A unique identifier among all device instances of all backends. Note that multiple device_id values (distinct device instances) might refer to the same physical device.

    *)
  4. cross_stream_candidates : 'buffer_ptr Base.Hashtbl.M(Arrayjit.Tnode).t;
    (*

    Freshly created arrays that might be shared across streams. The map can both grow and shrink.

    *)
  5. owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Base.Hashtbl.M(Arrayjit.Tnode).t;
    (*

    The stream owning a given node. This map can only grow. Currently, if the memory mode of a node is inferred, only this stream will modify a cross-stream shared array. But memory modes can also be set manually.

    *)
  6. shared_writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) Base.list Base.Hashtbl.M(Arrayjit.Tnode).t;
    (*

    The streams that most recently have been scheduled to update (write to) a cross-stream-shared node, and the associated update completion event. The completed events are removed opportunistically.

    *)
  7. host_reading_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) Base.list Base.Hashtbl.M(Arrayjit.Tnode).t;
    (*

    The streams that most recently have been reading from a node's on-host array. The completed events are removed opportunistically.

    *)
  8. host_writing_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) Base.list Base.Hashtbl.M(Arrayjit.Tnode).t;
    (*

    The streams that most recently have been writing to a node's on-host array. The completed events are removed opportunistically.

    *)
  9. mutable streams : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Utils.weak_dynarray;
    (*

    All (live) streams created on the device. Used by With_buffer_retrieval_and_syncing.sync_device. Warning: stream_id fields of garbage collected streams can be reused!

    *)
}
val sexp_of_device : 'buffer_ptr 'dev 'runner 'event. ('buffer_ptr -> Sexplib0.Sexp.t) -> ('dev -> Sexplib0.Sexp.t) -> ('runner -> Sexplib0.Sexp.t) -> ('event -> Sexplib0.Sexp.t) -> ('buffer_ptr, 'dev, 'runner, 'event) device -> Sexplib0.Sexp.t
type ('buffer_ptr, 'dev, 'runner, 'event) stream = ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
  1. device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
  2. runner : 'runner;
  3. merge_buffer : 'buffer_ptr buffer Base.option Base.ref;
    (*

    Depending on backend implementations, either the currently used merge buffer, or the one most recently scheduled. Note that the pointer can be reused for nodes that fit in an already allocated buffer.

    *)
  4. stream_id : Base.int;
    (*

    An ID unique within the device for the lifetime of the stream.

    *)
  5. mutable allocated_buffer : 'buffer_ptr buffer Base.option;
  6. updating_for : 'event Base.Hashtbl.M(Arrayjit.Tnode).t;
  7. mutable updating_for_merge_buffer : (Tnode.t * 'event Base.option) Base.option;
    (*

    The tensor node that was most recently scheduled to be in the stream's merge buffer. The event finishes after the task of a Streaming_for merge-buffer use. See also updating_for.

    *)
  8. reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) Base.list Base.Hashtbl.M(Arrayjit.Tnode).t;
    (*

    The streams, other than this stream, that most recently have been reading from a node in this stream's context, and the associated use completion events. The completed events are removed opportunistically.

    *)
}
val sexp_of_stream : 'buffer_ptr 'dev 'runner 'event. ('buffer_ptr -> Sexplib0.Sexp.t) -> ('dev -> Sexplib0.Sexp.t) -> ('runner -> Sexplib0.Sexp.t) -> ('event -> Sexplib0.Sexp.t) -> ('buffer_ptr, 'dev, 'runner, 'event) stream -> Sexplib0.Sexp.t
val equal_stream : ('a, 'b, 'c, 'd) stream_ref -> ('e, 'f, 'g, 'h) stream_ref -> Base.bool
type ('buffer_ptr, 'stream) context = {
  1. stream : 'stream;
  2. parent : ('buffer_ptr, 'stream) context Base.option;
  3. ctx_arrays : 'buffer_ptr ctx_arrays;
    (*

    This map contains arrays used in this context or an ancestor context (they might be unique but might also be cross-stream shared).

    *)
  4. finalized : Utils.atomic_bool;
}
val sexp_of_context : 'buffer_ptr 'stream. ('buffer_ptr -> Sexplib0.Sexp.t) -> ('stream -> Sexplib0.Sexp.t) -> ('buffer_ptr, 'stream) context -> Sexplib0.Sexp.t
module type Device_types = sig ... end
module type Device = sig ... end
module type Backend_any_common = sig ... end

Parts shared by both assignments-level and lowered-level backend interfaces.

module type Backend_common = sig ... end

Parts shared by assignments-level backend interfaces.

module type Backend_device_common = sig ... end

Parts shared by both assignments-level and lowered-level backend interfaces providing streams and devices, both user-facing and implementation-facing. Does not include: compilation and linking (different for assignments-level and lowered-level); copying and tensor-node-level synchronization (copying is different for user-facing and implementation-facing APIs, synchronization is provided by a component outside of backend implementations).

module type With_buffer_retrieval_and_syncing = sig ... end
module type Backend = sig ... end