Ir.Ops
Operation types shared by all backends, and precision types.
module Lazy = Utils.Lazy
type ('ocaml, 'impl) precision =
| Byte : (Base.char, uint8_elt) precision
| Uint16 : (Base.int, uint16_elt) precision
| Int32 : (Base.int32, int32_elt) precision
| Uint32 : (Base.int32, int32_elt) precision
Using int32_elt representation but treating as unsigned.
| Int64 : (Base.int64, int64_elt) precision
| Uint64 : (Base.int64, int64_elt) precision
Using int64_elt representation but treating as unsigned.
| Uint4x32 : (Stdlib.Complex.t, Stdlib.Bigarray.complex64_elt) precision
A 128-bit value that corresponds to, e.g., CUDA's uint4 type. Luckily, the OCaml Bigarray library supports complex64_elt, which is a 128-bit value, so we avoid dims conversions.
| Half : (Base.float, float16_elt) precision
| Bfloat16 : (Base.int, uint16_elt) precision
Using uint16 representation for now.
| Fp8 : (Base.char, uint8_elt) precision
Using uint8 representation for now.
| Single : (Base.float, float32_elt) precision
| Double : (Base.float, float64_elt) precision
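The Uint4x32 case relies on a complex64 Bigarray cell being exactly 128 bits wide. A minimal sketch of how four 32-bit words could be packed into, and recovered from, one Stdlib.Complex.t is shown below. The word order (low word first, first pair in `re`) and the helper names `pack2`, `unpack2`, `pack4`, `unpack4` are hypothetical; the backends' actual bit layout is not specified here.

```ocaml
(* Hypothetical layout: words (a, b) go into [re], (c, d) into [im], low word first.
   Only the bit pattern matters; the floats are never used arithmetically.
   Caveat: bit patterns that happen to form NaNs may not survive float moves
   on every platform; this sketch does not handle that. *)
let pack2 (lo : int32) (hi : int32) : float =
  (* Zero-extend the low word so its sign bit does not spill into the high half. *)
  let lo64 = Int64.logand (Int64.of_int32 lo) 0xFFFF_FFFFL in
  let hi64 = Int64.shift_left (Int64.of_int32 hi) 32 in
  Int64.float_of_bits (Int64.logor hi64 lo64)

let unpack2 (f : float) : int32 * int32 =
  let bits = Int64.bits_of_float f in
  (* Int64.to_int32 keeps the low 32 bits. *)
  (Int64.to_int32 bits, Int64.to_int32 (Int64.shift_right_logical bits 32))

let pack4 ((a, b, c, d) : int32 * int32 * int32 * int32) : Complex.t =
  { Complex.re = pack2 a b; im = pack2 c d }

let unpack4 (z : Complex.t) : int32 * int32 * int32 * int32 =
  let a, b = unpack2 z.Complex.re in
  let c, d = unpack2 z.Complex.im in
  (a, b, c, d)
```

Storing the packed value in a `Bigarray.complex64` array and reading it back returns the same four words, since each cell is 16 bytes.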
val sexp_of_precision :
'ocaml 'impl. ('ocaml -> Sexplib0.Sexp.t) ->
('impl -> Sexplib0.Sexp.t) ->
('ocaml, 'impl) precision ->
Sexplib0.Sexp.t
type prec =
| Void_prec
| Byte_prec of (Base.char, uint8_elt) precision
| Uint16_prec of (Base.int, uint16_elt) precision
| Int32_prec of (Base.int32, int32_elt) precision
| Uint32_prec of (Base.int32, int32_elt) precision
| Int64_prec of (Base.int64, int64_elt) precision
| Uint64_prec of (Base.int64, int64_elt) precision
| Uint4x32_prec of (Stdlib.Complex.t, Stdlib.Bigarray.complex64_elt) precision
| Half_prec of (Base.float, float16_elt) precision
| Bfloat16_prec of (Base.int, uint16_elt) precision
| Fp8_prec of (Base.char, uint8_elt) precision
| Single_prec of (Base.float, float32_elt) precision
| Double_prec of (Base.float, float64_elt) precision
val byte : prec
val uint16 : prec
val int32 : prec
val uint32 : prec
val int64 : prec
val uint64 : prec
val uint4x32 : prec
val half : prec
val bfloat16 : prec
val fp8 : prec
val single : prec
val double : prec
val index_prec : unit -> prec
Returns the precision to use for indexing arithmetic based on the big_models setting.
val is_up_to_fp16 : prec -> bool
val sexp_of_prec : prec -> Base.Sexp.t
val prec_of_sexp : Base.Sexp.t -> prec
val precision_to_string : ('ocaml, 'elt_t) precision -> string
val prec_string : prec -> string
val prec_of_string : Base.String.t -> prec
val prec_in_bytes : prec -> int
val is_float : prec -> bool
Prefer the precision that is more likely to remain functional in the resulting computations. uint4x32 always dominates, because operations that work on uint4x32 do not support other precisions. Otherwise, fractional-number precisions dominate; within them, precisions with a larger dynamic range dominate.
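The promotion rule described above can be sketched over a simplified, hypothetical mirror of the precisions (type `p`, functions `rank` and `promote` are illustrative names, not this module's API). The ordering of floats by dynamic range (Fp8 < Half < Bfloat16 < Single < Double) is an assumption: Bfloat16 and Single share an exponent width, so placing Single above Bfloat16 is a tie-break guess. The text does not say how integer precisions compare, so ranking them by bit width is also an assumption.

```ocaml
(* Hypothetical mirror of the precisions; not the library's [prec] type. *)
type p =
  | Byte | Uint16 | Int32 | Uint32 | Int64 | Uint64 | Uint4x32
  | Fp8 | Half | Bfloat16 | Single | Double

let is_float = function
  | Fp8 | Half | Bfloat16 | Single | Double -> true
  | Byte | Uint16 | Int32 | Uint32 | Int64 | Uint64 | Uint4x32 -> false

(* Assumed ranks: floats by dynamic range, integers by bit width. *)
let rank = function
  | Fp8 | Byte -> 0
  | Half | Uint16 -> 1
  | Bfloat16 | Int32 | Uint32 -> 2
  | Single | Int64 | Uint64 -> 3
  | Double | Uint4x32 -> 4

let promote a b =
  match (a, b) with
  (* uint4x32 always wins: its operations support no other precision. *)
  | Uint4x32, _ | _, Uint4x32 -> Uint4x32
  (* Otherwise a fractional precision beats an integer one. *)
  | _ when is_float a && not (is_float b) -> a
  | _ when is_float b && not (is_float a) -> b
  (* Same kind: the higher-ranked precision wins. *)
  | _ -> if rank a >= rank b then a else b
```

For example, `promote Int64 Fp8` yields `Fp8`: a fractional precision dominates even a much wider integer one.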
val apply_prec : ?default:'a -> 'a apply_prec -> prec -> 'a
val c_typ_of_prec : prec -> string
val c_vec_typ_of_prec : length:int -> prec -> string
val hum_typ_of_prec : prec -> string
See: tinygrad ops, CUDA Math API (intrinsics).
This is a redundant set of operations that aims to expose hardware-supported "intrinsics", reducing the need for backends to pattern-match and optimize. The redundancy also adds convenience.
type binop =
| Arg1
| Arg2
| Add
| Sub
| Mul
| Div
| ToPowOf
| Relu_gate
| Satur01_gate
| Max
| Min
| Mod
| Cmplt
| Cmpeq
| Cmpne
| Or
| And
| Threefry4x32_crypto
4x32-bit Threefry PRNG, 20-round cryptographic version. Requires a 128-bit key and a 128-bit counter and outputs a 128-bit value (precision Uint4x32).
| Threefry4x32_light
4x32-bit Threefry PRNG, 2-round light version (as in JAX/XLA). Requires a 128-bit key and a 128-bit counter and outputs a 128-bit value (precision Uint4x32).
val binop_of_sexp : Sexplib0.Sexp.t -> binop
val sexp_of_binop : binop -> Sexplib0.Sexp.t
type unop =
| Identity
| Relu
| Satur01
Saturate (truncate) to within the interval [0; 1].
| Exp
| Log
| Exp2
| Log2
| Sin
| Cos
| Sqrt
| Recip
| Recip_sqrt
| Neg
| Tanh_approx
| Not
0. -> 1. | _ -> 0.
| Uint4x32_to_prec_uniform1
Non-vectorized variant of Uint4x32_to_prec_uniform that converts the given Uint4x32 to a single value of the output precision. Less bit-efficient, but operates pointwise. For random bits, the result is uniform over the range of the precision for integer precisions, and over the range [0.0, 1.0) for floating point precisions.
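To illustrate what "uniform over [0.0, 1.0)" means for a floating point output, here is one common technique, shown as an assumption rather than what the backends actually do: keep the top 24 bits of a 32-bit random word (the float32 significand width) and scale by 2^-24. The function name `uniform_float32_of_word` is hypothetical, and the sketch assumes a 64-bit OCaml so that the word fits in a native int as an unsigned value.

```ocaml
(* Hypothetical helper: map one 32-bit random word to a float in [0.0, 1.0).
   Assumes 64-bit OCaml so the unsigned 32-bit value fits in [int]. *)
let uniform_float32_of_word (w : int32) : float =
  (* Reinterpret the word as unsigned. *)
  let u = Int32.to_int w land 0xFFFF_FFFF in
  (* Keep the top 24 bits; the largest result is 1 - 2^-24, so 1.0 is excluded. *)
  float_of_int (u lsr 8) *. 0x1p-24
```

Dropping the low 8 bits keeps every result exactly representable in float32, so rounding can never push a value up to 1.0.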
val unop_of_sexp : Sexplib0.Sexp.t -> unop
val sexp_of_unop : unop -> Sexplib0.Sexp.t
type vec_unop =
| Uint4x32_to_prec_uniform
Converts the given Uint4x32 to the precision of the output in a bit-efficient manner. For random bits, the result is uniform over the range of the precision for integer precisions, and over the range [0.0, 1.0) for floating point precisions. When used in an access pattern, the indices are converted to a byte offset depending on the given precision. NOTE: unlike all other operations, this one impacts projections and shape inference (one input cell corresponds to several output cells).
val vec_unop_of_sexp : Sexplib0.Sexp.t -> vec_unop
val sexp_of_vec_unop : vec_unop -> Sexplib0.Sexp.t
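Because Uint4x32_to_prec_uniform is bit-efficient, one 16-byte input cell fills several output cells, and output indices map to byte offsets. The following sketch assumes outputs are laid out contiguously and never straddle an input cell; `outputs_per_cell` and `source_of_output` are hypothetical names, and the actual projection logic in shape inference may differ.

```ocaml
(* A Uint4x32 cell holds 128 bits = 16 bytes of random bits. *)
let uint4x32_bytes = 16

(* How many output values one input cell can fill at a given output width. *)
let outputs_per_cell ~bytes_per_output = uint4x32_bytes / bytes_per_output

(* Assumed contiguous layout: the input cell holding output [i]'s bits,
   and the byte offset of those bits within that cell. *)
let source_of_output ~bytes_per_output i =
  let byte = i * bytes_per_output in
  (byte / uint4x32_bytes, byte mod uint4x32_bytes)
```

Under this layout, a Byte output gets 16 values per input cell, while a Double output gets only 2.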
val ternop_of_sexp : Sexplib0.Sexp.t -> ternop
val sexp_of_ternop : ternop -> Sexplib0.Sexp.t
val op_of_sexp : Sexplib0.Sexp.t -> op
val sexp_of_op : op -> Sexplib0.Sexp.t
val neutral_elem : binop -> Base.Float.t
Either the left-neutral or right-neutral element of the operation. Unspecified if the operation does not have a neutral element.
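A neutral element lets a reduction start from an initial accumulator that does not change the result. The sketch below uses a hypothetical subset of binary ops (type `op`, functions `neutral`, `apply`, `reduce` are illustrative, not this module's API); the constants are the standard mathematical identities, which the library's `neutral_elem` may or may not match exactly for every case.

```ocaml
(* Hypothetical subset of binary ops; the module's [binop] has many more cases. *)
type op = Add | Mul | Max | Min

(* Identity elements: apply op (neutral op) x = x for every ordinary float x. *)
let neutral = function
  | Add -> 0.0
  | Mul -> 1.0
  | Max -> Float.neg_infinity
  | Min -> Float.infinity

let apply op a b =
  match op with
  | Add -> a +. b
  | Mul -> a *. b
  | Max -> Float.max a b
  | Min -> Float.min a b

(* Starting from the neutral element also gives a sensible result for empty input. *)
let reduce op xs = Array.fold_left (apply op) (neutral op) xs
```

For example, `reduce Max [| -5.0; -2.0 |]` is `-2.0`, whereas starting a Max reduction from `0.0` would wrongly return `0.0`.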
val interpret_binop : binop -> Base.Float.t -> Base.Float.t -> Base.Float.t
val interpret_unop : unop -> Base.Float.t -> Base.Float.t
val interpret_ternop :
ternop ->
Base.Float.t ->
Base.Float.t ->
Base.Float.t ->
Base.Float.t
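The interpreters map an operation to a plain float function. A minimal sketch for a hypothetical subset of unary ops follows, using the semantics stated in the unop documentation (Satur01 clamps to [0; 1], Not maps 0. to 1. and everything else to 0.); the type `u` and function `interpret` are illustrative names, not `interpret_unop` itself.

```ocaml
(* Hypothetical subset of unary ops, interpreted on floats. *)
type u = Relu | Satur01 | Neg | Recip | Not

let interpret : u -> float -> float = function
  | Relu -> fun x -> if x > 0.0 then x else 0.0
  (* Saturate: clamp into the interval [0; 1]. *)
  | Satur01 -> fun x -> Float.min 1.0 (Float.max 0.0 x)
  | Neg -> fun x -> -.x
  | Recip -> fun x -> 1.0 /. x
  (* Logical not on floats: 0. -> 1. | _ -> 0. *)
  | Not -> fun x -> if x = 0.0 then 1.0 else 0.0
```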
Note: currently the %cd syntax only supports infix binops as assignment ops.
val is_binop_nice_infix : binop -> bool
val binop_cd_syntax : binop -> string
val binop_cd_fallback_syntax : binop -> string
In the %cd syntax, we support uncurried notation for binary ops in addition to the infix notation.
val is_assign_op : binop -> bool
val assign_op_cd_syntax : initialize_neutral:bool -> binop -> string
val unop_cd_syntax : unop -> string
Note: currently we do not support unary prefix symbols.
val vec_unop_cd_syntax : vec_unop -> string
val ternop_cd_syntax : ternop -> string
In the %cd syntax, we use uncurried notation for ternary ops.
val c_rawptr_to_string : Base.nativeint -> prec -> Base.String.t
val rawptr_to_string_hum : Base.nativeint -> prec -> Base.String.t
val c_ptr_to_string : 'elem Ctypes.ptr -> prec -> Base.String.t
val ptr_to_string_hum : 'elem Ctypes.ptr -> prec -> Base.String.t
val axis_padding_of_sexp : Sexplib0.Sexp.t -> axis_padding
val sexp_of_axis_padding : axis_padding -> Sexplib0.Sexp.t
val equal_axis_padding : axis_padding -> axis_padding -> Base.bool
val copy_with_padding_c :
('a, 'b, Stdlib.Bigarray.c_layout) Stdlib.Bigarray.Genarray.t ->
('a, 'b, Stdlib.Bigarray.c_layout) Stdlib.Bigarray.Genarray.t ->
axis_padding Base.array ->
Base.unit
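The idea behind copying with padding can be sketched with the standard Bigarray API: walk every index of the source and write it into the destination shifted by each axis's left padding. The `pad` record below is a hypothetical stand-in for `axis_padding`, whose fields are not shown here, and this sketch assumes the first array is the source and the second the padded destination; the real argument order of `copy_with_padding_c` is not documented above.

```ocaml
(* Hypothetical stand-in for [axis_padding]. *)
type pad = { left : int; right : int }

(* Copy [src] into the interior of [dst], offsetting each axis by its left padding.
   Requires dst dim i = src dim i + left + right on every axis. *)
let copy_into_padded (src : ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t)
    (dst : ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t) (pads : pad array) =
  let open Bigarray in
  let dims = Genarray.dims src in
  let n = Array.length dims in
  if Array.length pads <> n || Genarray.num_dims dst <> n then
    invalid_arg "copy_into_padded: rank mismatch";
  Array.iteri
    (fun i d ->
      if Genarray.nth_dim dst i <> d + pads.(i).left + pads.(i).right then
        invalid_arg "copy_into_padded: padded size mismatch")
    dims;
  let src_idx = Array.make n 0 and dst_idx = Array.make n 0 in
  (* Recurse over axes; at the innermost level, copy one element. *)
  let rec go axis =
    if axis = n then Genarray.set dst dst_idx (Genarray.get src src_idx)
    else
      for i = 0 to dims.(axis) - 1 do
        src_idx.(axis) <- i;
        dst_idx.(axis) <- i + pads.(axis).left;
        go (axis + 1)
      done
  in
  go 0
```

The padding cells themselves are left untouched, so the caller decides their value (for example by filling the destination first).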
Threefry4x32 PRNG - 20-round cryptographic version
Threefry4x32 PRNG - 2-round light version
Threefry4x32 PRNG - default version
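Threefry is built from the add-rotate-xor (ARX) MIX step of the Threefish block cipher; the 20-round and 2-round variants differ in how many rounds of such steps they apply. The rotation constants and key schedule are not given here, so this sketch shows only the 32-bit MIX primitive, not a full Threefry4x32. The names `rotl32` and `mix` are hypothetical.

```ocaml
(* 32-bit left rotation; valid for 0 < r < 32. *)
let rotl32 (x : int32) (r : int) : int32 =
  Int32.logor (Int32.shift_left x r) (Int32.shift_right_logical x (32 - r))

(* One MIX step on a pair of words: add, rotate, xor.
   Int32.add wraps modulo 2^32, which matches unsigned addition bit for bit. *)
let mix (x0 : int32) (x1 : int32) (r : int) : int32 * int32 =
  let x0 = Int32.add x0 x1 in
  let x1 = Int32.logxor (rotl32 x1 r) x0 in
  (x0, x1)
```

A 4x32 round applies two such MIX steps to the word pairs and then permutes the words; key injections are interleaved every few rounds.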
Conversion from uint4x32 to various uniform distributions
val is_homogeneous_prec_unop : unop -> bool
Returns true if the unary operation is homogeneous in precision, meaning its argument should be converted to the result precision.
val is_homogeneous_prec_vec_unop : vec_unop -> bool
Returns true if the vec_unop operation is homogeneous in precision, meaning its argument should be converted to the result precision.
Returns true if the binary operation is homogeneous in precision, meaning its arguments should be converted to the result precision.
val is_homogeneous_prec_ternop : ternop -> bool
Returns true if the ternary operation is homogeneous in precision, meaning its arguments should be converted to the result precision.