Module Einsum_parser

Entry point for the einsum parser library.

This module provides functions to parse einsum notation specifications using a Menhir-based parser.

type use_padding_spec = [
  | `True
  | `False
  | `Unspecified
]

Use_padding specification for convolutions.

val compare_use_padding_spec : use_padding_spec -> use_padding_spec -> Base.int
val __use_padding_spec_of_sexp__ : Sexplib0.Sexp.t -> use_padding_spec
val use_padding_spec_of_sexp : Sexplib0.Sexp.t -> use_padding_spec
val sexp_of_use_padding_spec : use_padding_spec -> Sexplib0.Sexp.t
type conv_spec = {
  dilation : Base.string;
  kernel_label : Base.string;
  use_padding : use_padding_spec;
}

Convolution component for affine axis specifications. Note: dilation is a string because it can be an identifier at parse time, and is resolved to an int at runtime.

val compare_conv_spec : conv_spec -> conv_spec -> Base.int
val conv_spec_of_sexp : Sexplib0.Sexp.t -> conv_spec
val sexp_of_conv_spec : conv_spec -> Sexplib0.Sexp.t
type axis_spec =
  1. | Label of Base.string
    (*

    A variable axis label.

    *)
  2. | Fixed_index of Base.int
    (*

    A fixed index, used for projection.

    *)
  3. | Affine_spec of {
    1. stride : Base.string;
      (*

      Coefficient for the over dimension (string to allow identifiers).

      *)
    2. over_label : Base.string;
      (*

      The output/iteration dimension label.

      *)
    3. conv : conv_spec Base.option;
      (*

      Optional convolution: dilation*kernel.

      *)
    4. stride_offset : Base.int;
      (*

      Constant offset added after stride*over.

      *)
    }
    (*

    Affine axis specification: stride*over + stride_offset + dilation*kernel. Corresponds to Row.Affine in shape inference.

    *)

Specification for individual axes in the einsum notation. Note: stride is a string because it can be an identifier at parse time, and is resolved to an int at runtime.

val compare_axis_spec : axis_spec -> axis_spec -> Base.int
val axis_spec_of_sexp : Sexplib0.Sexp.t -> axis_spec
val sexp_of_axis_spec : axis_spec -> Sexplib0.Sexp.t
module AxisKey : sig ... end

An index pointing to any of a shape's axes, including the kind of the axis (Batch, Input, Output) and the position (which is counted from the end to facilitate broadcasting).

type axis_key = AxisKey.t
val equal_axis_key : axis_key -> axis_key -> Base.bool
val compare_axis_key : axis_key -> axis_key -> Base.int
val axis_key_of_sexp : Sexplib0.Sexp.t -> axis_key
val sexp_of_axis_key : axis_key -> Sexplib0.Sexp.t
type 'a axis_map = 'a Base.Map.M(AxisKey).t
val compare_axis_map : 'a. ('a -> 'a -> Base.int) -> 'a axis_map -> 'a axis_map -> Base.int
val axis_map_of_sexp : 'a. (Sexplib0.Sexp.t -> 'a) -> Sexplib0.Sexp.t -> 'a axis_map
val sexp_of_axis_map : 'a. ('a -> Sexplib0.Sexp.t) -> 'a axis_map -> Sexplib0.Sexp.t
type parsed_axis_labels = {
  bcast_batch : Base.string Base.option;
  bcast_input : Base.string Base.option;
  bcast_output : Base.string Base.option;
  given_batch : axis_spec Base.list;
  given_input : axis_spec Base.list;
  given_output : axis_spec Base.list;
  given_beg_batch : axis_spec Base.list;
  given_beg_input : axis_spec Base.list;
  given_beg_output : axis_spec Base.list;
  labels : axis_spec axis_map;
}

The labels field maps AxisKey axes to their axis specs. Moreover, the bcast_ fields represent whether additional leading/middle axes are allowed (corresponding to the dot-ellipsis syntax for broadcasting). When present, the string can be used to identify a row variable; it defaults to "batch", "input", and "output" respectively when parsing "...". The given_ fields list the axis specs of the corresponding kind in labels whose key has from_end=true; the given_beg_ fields list those whose key has from_end=false.

val compare_parsed_axis_labels : parsed_axis_labels -> parsed_axis_labels -> Base.int
val parsed_axis_labels_of_sexp : Sexplib0.Sexp.t -> parsed_axis_labels
val sexp_of_parsed_axis_labels : parsed_axis_labels -> Sexplib0.Sexp.t
val given_beg_output : parsed_axis_labels -> axis_spec Base.list
val given_beg_input : parsed_axis_labels -> axis_spec Base.list
val given_beg_batch : parsed_axis_labels -> axis_spec Base.list
val given_output : parsed_axis_labels -> axis_spec Base.list
val given_input : parsed_axis_labels -> axis_spec Base.list
val given_batch : parsed_axis_labels -> axis_spec Base.list
val bcast_output : parsed_axis_labels -> Base.string Base.option
val bcast_input : parsed_axis_labels -> Base.string Base.option
val bcast_batch : parsed_axis_labels -> Base.string Base.option
module Fields_of_parsed_axis_labels : sig ... end
exception Parse_error of Base.string

Exception raised when parsing fails.

val is_multichar : Base.string -> Base.bool

Determines whether a spec uses multichar mode. Multichar mode is triggered by the presence of any of: ',', '*', '+', '^', '&'.

val axis_labels_of_spec : Base.string -> parsed_axis_labels

Parse an axis labels specification.

Examples:

  • "abc" (single-char mode)
  • "a, b, c" (multichar mode, triggered by comma)
  • "batch|input->output"
  • "...a..b"
val einsum_of_spec : Base.string -> parsed_axis_labels * parsed_axis_labels Base.option * parsed_axis_labels

Parse an einsum specification.

Examples:

  • "ij;jk=>ik" (matrix multiplication)
  • "ij=>ji" (transpose/permute)
  • "i,j->2*i+j" (convolution)