Module Ir.Low_level

A for-loop-based array language and backend-agnostic optimization

Global references

module Scope_id : sig ... end
type scope_id = Scope_id.t = {
  1. tn : Tnode.t;
  2. scope_id : Base.int;
}
val sexp_of_scope_id : scope_id -> Sexplib0.Sexp.t
val equal_scope_id : scope_id -> scope_id -> Base.bool
val hash_fold_scope_id : Ppx_hash_lib.Std.Hash.state -> scope_id -> Ppx_hash_lib.Std.Hash.state
val hash_scope_id : scope_id -> Ppx_hash_lib.Std.Hash.hash_value
val compare_scope_id : scope_id -> scope_id -> Base.int

Low-level representation

type t =
  1. | Noop
  2. | Comment of Base.string
  3. | Staged_compilation of Base.unit -> PPrint.document
  4. | Seq of t * t
  5. | For_loop of {
    1. index : Indexing.symbol;
    2. from_ : Base.int;
    3. to_ : Base.int;
    4. body : t;
    5. trace_it : Base.bool;
    }
  6. | Zero_out of Tnode.t
  7. | Set of {
    1. tn : Tnode.t;
    2. idcs : Indexing.axis_index Base.array;
    3. llsc : scalar_t;
    4. mutable debug : Base.string;
    }
  8. | Set_from_vec of {
    1. tn : Tnode.t;
    2. idcs : Indexing.axis_index Base.array;
    3. length : Base.int;
    4. vec_unop : Ops.vec_unop;
    5. arg : scalar_arg;
    6. mutable debug : Base.string;
    }
  9. | Set_local of scope_id * scalar_t

Cases: t — code; scalar_t — a single number at some precision.

and scalar_t =
  1. | Local_scope of {
    1. id : scope_id;
    2. body : t;
    3. orig_indices : Indexing.axis_index Base.array;
    }
  2. | Get_local of scope_id
  3. | Get of Tnode.t * Indexing.axis_index Base.array
  4. | Get_merge_buffer of Tnode.t * Indexing.axis_index Base.array
  5. | Ternop of Ops.ternop * scalar_arg * scalar_arg * scalar_arg
  6. | Binop of Ops.binop * scalar_arg * scalar_arg
  7. | Unop of Ops.unop * scalar_arg
  8. | Constant of Base.float
  9. | Constant_bits of Base.int64
    (*

    Direct bit representation, primarily for uint4x32

    *)
  10. | Embed_index of Indexing.axis_index
and scalar_arg = scalar_t * Ops.prec

The argument precision is preserved for arguments of heterogeneous-precision operations, and is ignored (overridden) for homogeneous-precision operations.

val sexp_of_t : t -> Sexplib0.Sexp.t
val sexp_of_scalar_t : scalar_t -> Sexplib0.Sexp.t
val sexp_of_scalar_arg : scalar_arg -> Sexplib0.Sexp.t
val equal : t -> t -> Base.bool
val equal_scalar_t : scalar_t -> scalar_t -> Base.bool
val equal_scalar_arg : scalar_arg -> scalar_arg -> Base.bool
val compare : t -> t -> Base.int
val compare_scalar_t : scalar_t -> scalar_t -> Base.int
val compare_scalar_arg : scalar_arg -> scalar_arg -> Base.int
val scalar_precision : scalar_t -> Ops.prec
val apply_op : Ops.op -> scalar_t Base.array -> scalar_t
val flat_lines : t Base.list -> t Base.list
val unflat_lines : t Base.list -> t
val loop_over_dims : Base.int Base.array -> body:(Indexing.axis_index Base.array -> t) -> t
val unroll_dims : Base.int Base.array -> body:(Indexing.axis_index Base.array -> offset:Base.int -> t) -> t

Optimization

type virtualize_settings = {
  1. mutable enable_device_only : Base.bool;
  2. mutable max_visits : Base.int;
  3. mutable max_tracing_dim : Base.int;
  4. mutable inline_scalar_constexprs : Base.bool;
  5. mutable inline_simple_computations : Base.bool;
  6. mutable inline_complex_computations : Base.bool;
}
val virtualize_settings : virtualize_settings
type visits =
  1. | Visits of Base.int
  2. | Recurrent
    (*

    A Recurrent visit occurs when there is an access prior to any assignment in an update.

    *)
val sexp_of_visits : visits -> Sexplib0.Sexp.t
val visits_of_sexp : Sexplib0.Sexp.t -> visits
val equal_visits : visits -> visits -> Base.bool
val visits : Base.int -> visits
val recurrent : visits
val is_visits : visits -> Base.bool
val is_recurrent : visits -> Base.bool
val visits_val : visits -> Base.int Base.option
val recurrent_val : visits -> Base.unit Base.option
module Variants_of_visits : sig ... end
type traced_array = {
  1. tn : Tnode.t;
  2. assignments : Base.int Base.array Base.Hash_set.t;
  3. accesses : (Base.int Base.array, visits) Base.Hashtbl.t;
  4. mutable zero_initialized_by_code : Base.bool;
  5. mutable zeroed_out : Base.bool;
  6. mutable read_before_write : Base.bool;
    (*

    The node is read before it is written (i.e. it is recurrent).

    *)
  7. mutable read_only : Base.bool;
    (*

    Surprisingly, the notions of read-only and of constant memory mode come apart: small hosted constants are not read-only, because they are initialized on devices by being assigned to; conversely, a node in volatile memory mode is read-only from the devices' perspective.

    *)
  8. mutable is_scalar_constexpr : Base.bool;
    (*

    True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned before being accessed, is assigned at most once, and is assigned from an expression involving only constants or tensor nodes that were, at that time, is_scalar_constexpr.

    *)
  9. mutable is_accessing : Base.bool;
    (*

    False only if the tensor node is built solely from index embeddings and scalar constant expressions.

    *)
  10. mutable is_complex : Base.bool;
    (*

    True only if the tensor node is built by accessing computations that are not a single getter.

    *)
}
val sexp_of_traced_array : traced_array -> Sexplib0.Sexp.t
val get_node : (Tnode.t, traced_array) Base.Hashtbl.t -> Tnode.t -> traced_array
val optimize_integer_pow : Base.bool Base.ref
type traced_store = (Tnode.t, traced_array) Base.Hashtbl.t
val sexp_of_traced_store : traced_store -> Sexplib0.Sexp.t
type optimize_ctx = {
  1. computations : (Tnode.t, (Indexing.axis_index Base.array Base.option * t) Base.list) Base.Hashtbl.t;
    (*

    The computations (of the tensor node) are retrieved for optimization just as they are populated, so that the inlined code corresponds precisely to the changes to the arrays that would happen up to that point. Within the code blocks paired with an index tuple, all assignments and accesses must happen via the index tuple; if this is not the case for some assignment, the node cannot be virtual. Currently, we only allow for-loop symbols in assignment indices of virtual nodes.

    *)
}
val sexp_of_optimize_ctx : optimize_ctx -> Sexplib0.Sexp.t
type optimized = {
  1. traced_store : traced_store;
  2. optimize_ctx : optimize_ctx;
  3. llc : t;
  4. merge_node : Tnode.t Base.option;
}
val sexp_of_optimized : optimized -> Sexplib0.Sexp.t
val optimize : optimize_ctx -> unoptim_ll_source:(PPrint.document -> Base.unit) Base.option -> ll_source:(PPrint.document -> Base.unit) Base.option -> name:Base.string -> Indexing.static_symbol Base.list -> t -> optimized
val input_and_output_nodes : optimized -> (Base.Set.M(Ir.Tnode).t * Base.Set.M(Ir.Tnode).t) * Tnode.t Base.option

Inputs are the materialized read-only and read-before-write (within the code) non-constant non-merge nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters. Outputs are all the materialized nodes written-to by the code. The last returned component is the input merge node, if used in the code.

Printing

val code_hum_margin : Base.int Base.ref
val function_header_doc : ?name:Base.string -> ?static_indices:Indexing.static_symbol Base.list -> Base.unit -> PPrint.document
val get_ident_within_code : ?no_dots:Base.bool -> ?blacklist:Base.string Base.list -> t Base.array -> Tnode.t -> Base.string
val to_doc_cstyle : ?name:Base.string -> ?static_indices:Indexing.static_symbol Base.list -> Base.unit -> t -> PPrint.document

Adheres more closely to C syntax, and outputs implicit type casts.

val to_doc : ?name:Base.string -> ?static_indices:Indexing.static_symbol Base.list -> Base.unit -> t -> PPrint.document

Adheres to the %cd syntax.