Module Ir.Ops

Operation types shared by all backends, and precision types.

module Lazy = Utils.Lazy

*** Precision ***

type uint8_elt = Stdlib.Bigarray.int8_unsigned_elt
type uint16_elt = Stdlib.Bigarray.int16_unsigned_elt
type int32_elt = Stdlib.Bigarray.int32_elt
type float16_elt = Stdlib.Bigarray.float16_elt
type float32_elt = Stdlib.Bigarray.float32_elt
type float64_elt = Stdlib.Bigarray.float64_elt
type int64_elt = Stdlib.Bigarray.int64_elt
type ('ocaml, 'impl) precision =
  1. | Byte : (Base.char, uint8_elt) precision
  2. | Uint16 : (Base.int, uint16_elt) precision
  3. | Int32 : (Base.int32, int32_elt) precision
  4. | Uint32 : (Base.int32, int32_elt) precision
    (*

    Using int32_elt representation but treating as unsigned

    *)
  5. | Int64 : (Base.int64, int64_elt) precision
  6. | Uint64 : (Base.int64, int64_elt) precision
    (*

    Using int64_elt representation but treating as unsigned

    *)
  7. | Uint4x32 : (Stdlib.Complex.t, Stdlib.Bigarray.complex64_elt) precision
    (*

    A 128-bit value that corresponds to e.g. CUDA's uint4 type. Luckily, the OCaml Bigarray library supports complex64_elt, which is a 128-bit value, so we avoid having to convert array dimensions.

    *)
  8. | Half : (Base.float, float16_elt) precision
  9. | Bfloat16 : (Base.int, uint16_elt) precision
    (*

    Using uint16 representation for now

    *)
  10. | Fp8 : (Base.char, uint8_elt) precision
    (*

    Using uint8 representation for now

    *)
  11. | Single : (Base.float, float32_elt) precision
  12. | Double : (Base.float, float64_elt) precision
val sexp_of_precision : 'ocaml 'impl. ('ocaml -> Sexplib0.Sexp.t) -> ('impl -> Sexplib0.Sexp.t) -> ('ocaml, 'impl) precision -> Sexplib0.Sexp.t
type prec =
  1. | Void_prec
  2. | Byte_prec of (Base.char, uint8_elt) precision
  3. | Uint16_prec of (Base.int, uint16_elt) precision
  4. | Int32_prec of (Base.int32, int32_elt) precision
  5. | Uint32_prec of (Base.int32, int32_elt) precision
  6. | Int64_prec of (Base.int64, int64_elt) precision
  7. | Uint64_prec of (Base.int64, int64_elt) precision
  8. | Uint4x32_prec of (Stdlib.Complex.t, Stdlib.Bigarray.complex64_elt) precision
  9. | Half_prec of (Base.float, float16_elt) precision
  10. | Bfloat16_prec of (Base.int, uint16_elt) precision
  11. | Fp8_prec of (Base.char, uint8_elt) precision
  12. | Single_prec of (Base.float, float32_elt) precision
  13. | Double_prec of (Base.float, float64_elt) precision
val byte : prec
val uint16 : prec
val int32 : prec
val uint32 : prec
val int64 : prec
val uint64 : prec
val uint4x32 : prec
val half : prec
val bfloat16 : prec
val fp8 : prec
val single : prec
val double : prec
val index_prec : unit -> prec

Returns the precision to use for indexing arithmetic based on the big_models setting.

val is_up_to_fp16 : prec -> bool
val exceeds_fp16_cutoff : Base.Float.t -> bool
val sexp_of_prec : prec -> Base.Sexp.t
val prec_of_sexp : Base.Sexp.t -> prec
val precision_to_string : ('ocaml, 'elt_t) precision -> string
val prec_string : prec -> string
val prec_of_string : Base.String.t -> prec
val equal_prec : prec -> prec -> bool
val compare_prec : prec -> prec -> int
val prec_in_bytes : prec -> int
val is_float : prec -> bool
val promote_prec : prec -> prec -> prec

Prefer precision which is more likely to remain functional in the resulting computations. uint4x32 always dominates, because operations that work on uint4x32 do not support other precisions. Otherwise, fractional number precisions dominate; within them, larger dynamic range precisions dominate.

val pack_prec : ('ocaml, 'elt_t) precision -> prec
type 'r apply_prec = {
  1. f : 'ocaml 'elt_t. ('ocaml, 'elt_t) precision -> 'r;
}
val apply_prec : ?default:'a -> 'a apply_prec -> prec -> 'a
val c_typ_of_prec : prec -> string
val c_vec_typ_of_prec : length:int -> prec -> string
val hum_typ_of_prec : prec -> string

*** Operations ***

See: tinygrad ops, CUDA Math API (intrinsics).

This is a deliberately redundant set of operations: it aims to expose hardware-supported "intrinsics" in order to reduce the need for backends to pattern-match and optimize, and it is also provided for convenience.

type binop =
  1. | Arg1
  2. | Arg2
  3. | Add
  4. | Sub
  5. | Mul
  6. | Div
  7. | ToPowOf
  8. | Relu_gate
  9. | Satur01_gate
  10. | Max
  11. | Min
  12. | Mod
  13. | Cmplt
  14. | Cmpeq
  15. | Cmpne
  16. | Or
  17. | And
  18. | Threefry4x32_crypto
    (*

    4x32-bit Threefry PRNG, 20-round cryptographic version. Requires a 128-bit key and a 128-bit counter and outputs a 128-bit value (precision Uint4x32).

    *)
  19. | Threefry4x32_light
    (*

    4x32-bit Threefry PRNG, 2-round light version (as in JAX/XLA). Requires a 128-bit key and a 128-bit counter and outputs a 128-bit value (precision Uint4x32).

    *)
val binop_of_sexp : Sexplib0.Sexp.t -> binop
val sexp_of_binop : binop -> Sexplib0.Sexp.t
val compare_binop : binop -> binop -> Base.int
val equal_binop : binop -> binop -> Base.bool
type unop =
  1. | Identity
  2. | Relu
  3. | Satur01
    (*

    Saturate (truncate) to within the interval [0; 1].

    *)
  4. | Exp
  5. | Log
  6. | Exp2
  7. | Log2
  8. | Sin
  9. | Cos
  10. | Sqrt
  11. | Recip
  12. | Recip_sqrt
  13. | Neg
  14. | Tanh_approx
  15. | Not
    (*

    0. -> 1. | _ -> 0.

    *)
  16. | Uint4x32_to_prec_uniform1
    (*

    Non-vectorized variant of Uint4x32_to_prec_uniform that converts the given Uint4x32 to a single value of the output precision. Less bit-efficient but operates pointwise. For random bits, the result is uniform over the range of the precision for integer precisions, and over the range [0.0, 1.0) for floating point precisions.

    *)
val unop_of_sexp : Sexplib0.Sexp.t -> unop
val sexp_of_unop : unop -> Sexplib0.Sexp.t
val compare_unop : unop -> unop -> Base.int
val equal_unop : unop -> unop -> Base.bool
type vec_unop =
  1. | Uint4x32_to_prec_uniform
    (*

    Converts the given Uint4x32 to the precision of the output in a bit-efficient manner. For random bits, the result is uniform over the range of the precision for integer precisions, and over the range [0.0, 1.0) for floating point precisions. When used in an access pattern, the indices are converted to a byte offset depending on the given precision. NOTE: this operation, unlike any others, impacts projections and shape inference (one input cell corresponds to a few output cells).

    *)
val vec_unop_of_sexp : Sexplib0.Sexp.t -> vec_unop
val sexp_of_vec_unop : vec_unop -> Sexplib0.Sexp.t
val compare_vec_unop : vec_unop -> vec_unop -> Base.int
val equal_vec_unop : vec_unop -> vec_unop -> Base.bool
type ternop =
  1. | Where
    (*

    Where(a,b,c): if a then b else c

    *)
  2. | FMA
    (*

    FMA(a,b,c): (a * b) + c, non-accumulating

    *)
val ternop_of_sexp : Sexplib0.Sexp.t -> ternop
val sexp_of_ternop : ternop -> Sexplib0.Sexp.t
val compare_ternop : ternop -> ternop -> Base.int
val equal_ternop : ternop -> ternop -> Base.bool
type op =
  1. | Ternop of ternop
  2. | Binop of binop
  3. | Unop of unop
val op_of_sexp : Sexplib0.Sexp.t -> op
val sexp_of_op : op -> Sexplib0.Sexp.t
val compare_op : op -> op -> Base.int
val equal_op : op -> op -> Base.bool
val neutral_elem : binop -> Base.Float.t

Either the left-neutral or right-neutral element of the operation. Unspecified if the operation does not have a neutral element.

val interpret_binop : binop -> Base.Float.t -> Base.Float.t -> Base.Float.t
val interpret_unop : unop -> Base.Float.t -> Base.Float.t
val interpret_ternop : ternop -> Base.Float.t -> Base.Float.t -> Base.Float.t -> Base.Float.t
val is_binop_infix : 'a -> bool

Note: currently the %cd syntax only supports infix binops as assignment ops.

val is_binop_nice_infix : binop -> bool
val binop_cd_syntax : binop -> string
val binop_cd_fallback_syntax : binop -> string

In the %cd syntax, we support uncurried notation for binary ops in addition to the infix notation.

val binop_c_syntax : prec -> binop -> string * string * string
val is_assign_op : binop -> bool
val assign_op_cd_syntax : initialize_neutral:bool -> binop -> string
val unop_cd_syntax : unop -> string

Note: currently we do not support unary prefix symbols.

val vec_unop_cd_syntax : vec_unop -> string
val unop_c_syntax : prec -> unop -> Base.String.t * string
val vec_unop_c_syntax : prec -> vec_unop -> Base.String.t * string
val ternop_cd_syntax : ternop -> string

In the %cd syntax, we use uncurried notation for ternary ops.

val ternop_c_syntax : prec -> ternop -> string * string * string * string
val c_convert_precision : from:prec -> to_:prec -> Base.String.t * string

*** Pointer representation ***

type voidptr = Base.unit Ctypes.ptr
val sexp_of_voidptr : unit Ctypes_static.ptr -> Base.Sexp.t
val compare_voidptr : 'a Ctypes.ptr -> 'a Ctypes.ptr -> int
val equal_voidptr : voidptr -> voidptr -> Base.bool
val c_rawptr_to_string : Base.nativeint -> prec -> Base.String.t
val rawptr_to_string_hum : Base.nativeint -> prec -> Base.String.t
val c_ptr_to_string : 'elem Ctypes.ptr -> prec -> Base.String.t
val ptr_to_string_hum : 'elem Ctypes.ptr -> prec -> Base.String.t

*** External FFI declarations ***

type axis_padding = {
  1. left : Base.int;
  2. right : Base.int;
}
val axis_padding_of_sexp : Sexplib0.Sexp.t -> axis_padding
val sexp_of_axis_padding : axis_padding -> Sexplib0.Sexp.t
val equal_axis_padding : axis_padding -> axis_padding -> Base.bool
val bfloat16_to_single : Base.int -> Base.float

Original conversion functions

val single_to_bfloat16 : Base.float -> Base.int
val half_to_single : Base.int -> Base.float
val single_to_half : Base.float -> Base.int
val fp8_to_single : Base.int -> Base.float
val single_to_fp8 : Base.float -> Base.int
val copy_with_padding_c : ('a, 'b, Stdlib.Bigarray.c_layout) Stdlib.Bigarray.Genarray.t -> ('a, 'b, Stdlib.Bigarray.c_layout) Stdlib.Bigarray.Genarray.t -> axis_padding Base.array -> Base.unit
val threefry4x32_crypto : Base.int Base.array -> Base.int Base.array -> Base.int Base.array

Threefry4x32 PRNG - 20 round cryptographic version

val threefry4x32_light : Base.int Base.array -> Base.int Base.array -> Base.int Base.array

Threefry4x32 PRNG - 2 round light version

val threefry4x32 : Base.int Base.array -> Base.int Base.array -> Base.int Base.array

Threefry4x32 PRNG - default version

val uint4x32_to_single_uniform : Base.int Base.array -> Base.float

Conversion from uint4x32 to various uniform distributions

val uint4x32_to_double_uniform : Base.int Base.array -> Base.float
val uint4x32_to_int32_uniform : Base.int Base.array -> Base.int
val uint4x32_to_int64_uniform : Base.int Base.array -> Base.int64
val uint4x32_to_uint32_uniform : Base.int Base.array -> Base.int
val uint4x32_to_uint64_uniform : Base.int Base.array -> Base.int64
val uint4x32_to_byte_uniform : Base.int Base.array -> Base.int
val uint4x32_to_uint16_uniform : Base.int Base.array -> Base.int
val uint4x32_to_bfloat16_uniform : Base.int Base.array -> Base.int
val uint4x32_to_half_uniform : Base.int Base.array -> Base.int
val uint4x32_to_fp8_uniform : Base.int Base.array -> Base.int
val single_to_uint4x32 : Base.float -> Base.int Base.array

Conversion to uint4x32 from various types

val double_to_uint4x32 : Base.float -> Base.int Base.array
val int32_to_uint4x32 : Base.int -> Base.int Base.array
val int64_to_uint4x32 : Base.int64 -> Base.int Base.array
val uint32_to_uint4x32 : Base.int -> Base.int Base.array
val uint64_to_uint4x32 : Base.int64 -> Base.int Base.array
val byte_to_uint4x32 : Base.int -> Base.int Base.array
val uint16_to_uint4x32 : Base.int -> Base.int Base.array
val bfloat16_to_uint4x32 : Base.int -> Base.int Base.array
val half_to_uint4x32 : Base.int -> Base.int Base.array
val fp8_to_uint4x32 : Base.int -> Base.int Base.array

*** Precision homogeneity classification ***

val is_homogeneous_prec_unop : unop -> bool

Returns true if the unary operation is homogeneous in precision, meaning its argument should be converted to the result precision.

val is_homogeneous_prec_vec_unop : vec_unop -> bool

Returns true if the vec_unop operation is homogeneous in precision, meaning its argument should be converted to the result precision.

val is_homogeneous_prec_binop : 'a -> bool

Returns true if the binary operation is homogeneous in precision, meaning its arguments should be converted to the result precision.

val is_homogeneous_prec_ternop : ternop -> bool

Returns true if the ternary operation is homogeneous in precision, meaning its arguments should be converted to the result precision.