Module Ir.Assignments

The code for operating on n-dimensional arrays.

module Lazy = Utils.Lazy
module Tn = Tnode
module Nd = Ndarray
val _get_local_debug_runtime : unit -> (module Minidebug_runtime.Debug_runtime)
type init_data =
  1. | Reshape of Ndarray.t
  2. | Keep_shape_no_padding of Ndarray.t
  3. | Padded of {
    1. data : Nd.t;
    2. padding : Ops.axis_padding Base.array;
    3. padded_value : Base.float;
    }
val sexp_of_init_data : init_data -> Sexplib0.Sexp.t
val equal_init_data : init_data -> init_data -> Base.bool
type buffer =
  1. | Node of Tn.t
  2. | Merge_buffer of Tn.t
val sexp_of_buffer : buffer -> Sexplib0.Sexp.t
val equal_buffer : buffer -> buffer -> Base.bool
type fetch_op =
  1. | Constant of Base.float
  2. | Constant_bits of Base.int64
    (*

    Direct bit representation, primarily for uint4x32

    *)
  3. | Constant_fill of Base.float Base.array
    (*

    Fills in the numbers in row-major order, i.e. with the rightmost axis contiguous. Primes shape inference to require the assigned tensor to have the same number of elements as the array; but in case of "leaky" shape inference, it will loop over the values. This unrolls all assignments and should be used only for small arrays. For larger arrays, consider using Tnode.set_values instead.

    *)
  4. | Range_over_offsets
    (*

    Fills in the offset number of each cell, i.e. how many cells away it is from the beginning, in the logical representation of the tensor node. (The actual in-memory positions in a buffer instantiating the node can differ.)

    *)
  5. | Slice of {
    1. batch_idx : Indexing.static_symbol;
    2. sliced : Tn.t;
    }
  6. | Embed_symbol of Indexing.static_symbol
  7. | Embed_self_id
    (*

    Embeds the id of the array field of the Fetch constructor.

    *)
  8. | Embed_dim of Indexing.variable_ref

Resets an array by performing the specified computation or data fetching.

val sexp_of_fetch_op : fetch_op -> Sexplib0.Sexp.t
val equal_fetch_op : fetch_op -> fetch_op -> Base.bool
type accum_rhs =
  1. | Ternop of {
    1. op : Ops.ternop;
    2. rhs1 : buffer;
    3. rhs2 : buffer;
    4. rhs3 : buffer;
    }
  2. | Binop of {
    1. op : Ops.binop;
    2. rhs1 : buffer;
    3. rhs2 : buffer;
    }
  3. | Unop of {
    1. op : Ops.unop;
    2. rhs : buffer;
    }
val sexp_of_accum_rhs : accum_rhs -> Sexplib0.Sexp.t
val equal_accum_rhs : accum_rhs -> accum_rhs -> Base.bool
type t =
  1. | Noop
  2. | Seq of t * t
  3. | Block_comment of Base.string * t
    (*

    Same as the given code, with a comment.

    *)
  4. | Accum_op of {
    1. initialize_neutral : Base.bool;
    2. accum : Ops.binop;
    3. lhs : Tn.t;
    4. rhs : accum_rhs;
    5. projections : Indexing.projections Lazy.t;
    6. projections_debug : Base.string;
    }
  5. | Set_vec_unop of {
    1. op : Ops.vec_unop;
    2. lhs : Tn.t;
    3. rhs : buffer;
    4. projections : Indexing.projections Lazy.t;
    5. projections_debug : Base.string;
    }
  6. | Fetch of {
    1. array : Tn.t;
    2. fetch_op : fetch_op;
    3. dims : Base.int Base.array Lazy.t;
    }
val sexp_of_t : t -> Sexplib0.Sexp.t
type comp = {
  1. asgns : t;
  2. embedded_nodes : Base.Set.M(Tn).t;
    (*

    The nodes in asgns that are not in embedded_nodes need to already be in contexts linked with the comp.

    *)
}

Computations based on assignments. Note: the arrayjit library makes use of, but neither produces nor verifies, the embedded_nodes associated with the given asgns.

val sexp_of_comp : comp -> Sexplib0.Sexp.t
val to_comp : t -> comp
val empty_comp : comp
val is_total : initialize_neutral:Base.bool -> projections:Indexing.projections -> Base.bool
val can_skip_accumulation : projections:Indexing.projections -> bool
val context_nodes : use_host_memory:'a Base.option -> t -> Tn.t_set

Returns materialized nodes in the sense of Tnode.is_in_context_force. NOTE: it must be called after compilation; otherwise, it will disrupt memory mode inference.

val collect_nodes_guess_output : t -> Tn.t_set * Tn.t_set

Returns a pair of sets. The first set contains the nodes that are ever read from. The second set contains the nodes that are not read from after being written to.

val sequential : t Base.List.t -> t
val sequence : comp Base.List.t -> comp
val to_low_level : t -> Low_level.t
val flatten : t -> t Base__List.t
val is_noop : t -> bool
val get_ident_within_code : ?no_dots:bool -> t -> Tn.t -> Base.String.t
val to_doc : ?name:Base.String.t -> ?static_indices:Ir__Indexing.static_symbol Base.List.t -> unit -> t -> PPrint.document
val to_string : t -> string
val get_name_exn : t -> Base.String.t
val lower : Low_level.optimize_ctx -> unoptim_ll_source:(PPrint.document -> Base.unit) Base.option -> ll_source:(PPrint.document -> Base.unit) Base.option -> cd_source:(PPrint.document -> unit) option -> name:Base.String.t -> Indexing.static_symbol Base.List.t -> t -> Low_level.optimized