Module Arrayjit.Ops

Operation types shared by all backends, and precision types.

module Lazy = Utils.Lazy

*** Precision ***

type uint8_elt = Stdlib.Bigarray.int8_unsigned_elt
type float16_elt = Stdlib.Bigarray.float16_elt
type float32_elt = Stdlib.Bigarray.float32_elt
type float64_elt = Stdlib.Bigarray.float64_elt
type ('ocaml, 'impl) precision =
  1. | Byte : (Base.char, uint8_elt) precision
  2. | Half : (Base.float, float16_elt) precision
  3. | Single : (Base.float, float32_elt) precision
  4. | Double : (Base.float, float64_elt) precision
val sexp_of_precision : 'ocaml 'impl. ('ocaml -> Sexplib0.Sexp.t) -> ('impl -> Sexplib0.Sexp.t) -> ('ocaml, 'impl) precision -> Sexplib0.Sexp.t
type prec =
  1. | Void_prec
  2. | Byte_prec of (Base.char, uint8_elt) precision
  3. | Half_prec of (Base.float, float16_elt) precision
  4. | Single_prec of (Base.float, float32_elt) precision
  5. | Double_prec of (Base.float, float64_elt) precision
val byte : prec
val half : prec
val single : prec
val double : prec
val is_up_to_fp16 : prec -> bool
val sexp_of_prec : prec -> Base.Sexp.t
val prec_of_sexp : Base.Sexp.t -> prec
val precision_to_string : ('ocaml, 'elt_t) precision -> string
val prec_string : prec -> string
val equal_prec : prec -> prec -> bool
val prec_in_bytes : prec -> int
val promote_prec : prec -> prec -> prec
val pack_prec : ('ocaml, 'elt_t) precision -> prec
type 'r map_prec = {
  1. f : 'ocaml 'elt_t. ('ocaml, 'elt_t) precision -> 'r;
}
val map_prec : ?default:'a -> 'a map_prec -> prec -> 'a
val c_typ_of_prec : prec -> string
val hum_typ_of_prec : prec -> string

*** Operations ***

See: tinygrad ops, CUDA Math API (intrinsics).

This is a deliberately redundant set of operations. It aims to expose hardware-supported "intrinsics", which reduces the need for backends to pattern-match and optimize on their own. Some operations are also included purely for convenience.

type init_op =
  1. | Constant_fill of {
    1. values : Base.float Base.array;
    2. strict : Base.bool;
    }
    (*

    Fills in the numbers where the rightmost axis is contiguous. If strict=true, loops over the provided values.

    *)
  2. | Range_over_offsets
    (*

    Fills in the offset number of each cell (i.e. how many cells away it is from the beginning).

    *)
  3. | Standard_uniform
    (*

    Draws the values from U(0,1).

    *)
  4. | File_mapped of Base.string * prec
    (*

    Reads the data using Unix.openfile and Unix.map_file.

    *)

Initializes or resets an array by filling in the corresponding numbers, at the appropriate precision.

val equal_init_op : init_op -> init_op -> Base.bool
val init_op_of_sexp : Sexplib0.Sexp.t -> init_op
val sexp_of_init_op : init_op -> Sexplib0.Sexp.t
type binop =
  1. | Arg1
  2. | Arg2
  3. | Add
  4. | Sub
  5. | Mul
  6. | Div
  7. | ToPowOf
  8. | Relu_gate
  9. | Satur01_gate
  10. | Max
  11. | Min
  12. | Mod
  13. | Cmplt
  14. | Cmpeq
  15. | Cmpne
  16. | Or
  17. | And
val binop_of_sexp : Sexplib0.Sexp.t -> binop
val sexp_of_binop : binop -> Sexplib0.Sexp.t
val compare_binop : binop -> binop -> Base.int
val equal_binop : binop -> binop -> Base.bool
type unop =
  1. | Identity
  2. | Relu
  3. | Satur01
    (*

    Saturate (clamp) to within the interval [0; 1].

    *)
  4. | Exp
  5. | Log
  6. | Exp2
  7. | Log2
  8. | Sin
  9. | Cos
  10. | Sqrt
  11. | Recip
  12. | Recip_sqrt
  13. | Neg
  14. | Tanh_approx
  15. | Not
    (*

    Logical negation on floats: 0. -> 1. | _ -> 0. (zero maps to one; any other value maps to zero).

    *)
val unop_of_sexp : Sexplib0.Sexp.t -> unop
val sexp_of_unop : unop -> Sexplib0.Sexp.t
val compare_unop : unop -> unop -> Base.int
val equal_unop : unop -> unop -> Base.bool
type ternop =
  1. | Where
    (*

    Where(a,b,c): if a then b else c

    *)
  2. | FMA
    (*

    FMA(a,b,c): (a * b) + c, non-accumulating

    *)
val ternop_of_sexp : Sexplib0.Sexp.t -> ternop
val sexp_of_ternop : ternop -> Sexplib0.Sexp.t
val compare_ternop : ternop -> ternop -> Base.int
val equal_ternop : ternop -> ternop -> Base.bool
type op =
  1. | Ternop of ternop
  2. | Binop of binop
  3. | Unop of unop
val op_of_sexp : Sexplib0.Sexp.t -> op
val sexp_of_op : op -> Sexplib0.Sexp.t
val compare_op : op -> op -> Base.int
val equal_op : op -> op -> Base.bool
val neutral_elem : binop -> Base.Float.t

Either the left-neutral or right-neutral element of the operation. Unspecified if the operation does not have a neutral element.

val interpret_binop : binop -> Base.Float.t -> Base.Float.t -> Base.Float.t
val interpret_unop : unop -> Base.Float.t -> Base.Float.t
val interpret_ternop : ternop -> Base.Float.t -> Base.Float.t -> Base.Float.t -> Base.Float.t
val is_binop_infix : 'a -> bool

Note: currently the %cd syntax only supports infix binops as assignment ops.

val is_binop_nice_infix : binop -> bool
val binop_cd_syntax : binop -> string
val binop_cd_fallback_syntax : binop -> string

In the %cd syntax, we support uncurried notation for binary ops in addition to the infix notation.

val binop_c_syntax : prec -> binop -> string * string * string
val is_assign_op : binop -> bool
val assign_op_cd_syntax : initialize_neutral:bool -> binop -> string
val unop_cd_syntax : unop -> string

Note: currently we do not support unary prefix symbols.

val unop_c_syntax : prec -> unop -> Base.String.t * string
val ternop_cd_syntax : ternop -> string

In the %cd syntax, we use uncurried notation for ternary ops.

val ternop_c_syntax : prec -> ternop -> string * string * string * string
val c_convert_precision : from:prec -> to_:prec -> Base.String.t * string

*** Global references ***

type voidptr = Base.unit Ctypes.ptr
val sexp_of_voidptr : unit Ctypes_static.ptr -> Base.Sexp.t
val compare_voidptr : 'a Ctypes.ptr -> 'a Ctypes.ptr -> int
val equal_voidptr : voidptr -> voidptr -> Base.bool
val c_rawptr_to_string : Base.nativeint -> prec -> Base.String.t
val rawptr_to_string_hum : Base.nativeint -> prec -> Base.String.t
val c_ptr_to_string : 'elem Ctypes.ptr -> prec -> Base.String.t
val ptr_to_string_hum : 'elem Ctypes.ptr -> prec -> Base.String.t
type global_identifier =
  1. | C_function of Base.string
    (*

    Calls a C function that takes either no arguments or only index arguments.

    *)
  2. | External_unsafe of {
    1. ptr : voidptr;
    2. prec : prec;
    3. dims : Base.int Base.array Lazy.t;
    }
  3. | Merge_buffer of {
    1. source_node_id : Base.int;
    }
    (*

    Each device has at most one merge buffer, which merge operations re-use and re-allocate as needed. The merge buffer is associated with the source node of the device's most recent device_to_device ~into_merge_buffer:true operation.

    *)
val sexp_of_global_identifier : global_identifier -> Sexplib0.Sexp.t
val equal_global_identifier : global_identifier -> global_identifier -> Base.bool
val compare_global_identifier : global_identifier -> global_identifier -> Base.int