Syntax extensions %cd and %op

Preliminaries

OCANNL, and arrayjit specifically, is built around a fixed number of numeric operations, declared in arrayjit/ops.ml. We assign lexical operators to the binary operations, inventing novel operators if needed. For example, the Rectified Linear Unit operation Relu, which computes f(x) = max(0,x), is called relu, while the ReLU-Gate operation Relu_gate, which computes f(x,y) = if x > 0.0 then y else 0.0, gets the operator -?/ in addition to the name relu_gate. These built-in numeric operations are used to construct assignments (Assignments.t packaged as Assignments.comp). The syntax %cd is needed to build assignments concisely, and the assignment operators always start with = (unlike in C, where they end with =). On the other hand, while the syntax %op helps build tensors (Tensor.t), they can also be expressed concisely in pure OCaml. Unlike for assignments, the building blocks for tensor expressions are easy to extend. The meaningful basic ones are provided in tensor/operation.ml.

In OCANNL, a non-differentiable tensor is a tensor that is prohibited from propagating gradients: it has neither a gradient node nor backprop code. Accordingly, we call the “plain” tensors with a gradient node differentiable tensors. Expressions in the %cd syntax will sometimes build new non-differentiable tensors as components of assignments (they will never build new differentiable tensors). The syntax extensions make the following assumptions:

Functions inside Operation.NTDSL use ~grad_spec:Prohibit_grad when calling into Tensor, making the resulting tensors non-differentiable. Functions inside Operation.TDSL use ~grad_spec:If_needed, which will make the tensors non-differentiable when the gradient is not needed – except for TDSL.param, which internally sets ~grad_spec:Require_grad. Functions inside Operation.PDSL use ~grad_spec:Require_grad.

The extension points open NTDSL.O, resp. TDSL.O, for the scope of the extension point, to expose the corresponding operators.

The %oc anti-quotation and the unit-parameter heuristic

Within %op and %cd contexts, expressions typically undergo transformation to build tensors or assignments. However, OCANNL uses two mechanisms to preserve pure OCaml expressions:

Unit-parameter heuristic (automatic in %op)

In the %op syntax, when a function application contains a unit () argument, all arguments appearing before the unit are automatically preserved as pure OCaml expressions. This aligns with OCANNL’s design pattern where configuration happens before the unit parameter:

(* Arguments before () are automatically preserved as OCaml *)
let%op my_fn ~label x = 
  other_fn ~label:(("prefix_" ^ name) :: label) ~config:value () x
  (* label and config are preserved; x after () is transformed *)

Explicit %oc anti-quotation

For cases where you need explicit control or the heuristic doesn’t apply, the %oc (mnemonic: “OCaml”) anti-quotation escapes from the transformation context:

(* Force preservation even after () or in edge cases *)
let%op special = process_data data [%oc complex_ocaml_expr]

The %oc extension expects a single expression and returns it unchanged. Use cases:
- Overriding the unit-parameter heuristic when needed
- Preserving expressions in contexts without a unit parameter
- Escaping from the DSL in %cd contexts (which don’t use the unit heuristic)

Primitive operations

To accommodate stylistic preferences, OCANNL supports both curried and uncurried syntaxes for primitive operation application. Binary operations are associated with infix operators, in addition to having alphabetic identifiers. This stems from the following restriction: in the %cd syntax, the assignment is always an infix operator, and it needs to pick the accumulation operation.

The unary primitive operations:

Identifier                Default projection  Constructor in Ir.Ops
id                        pointwise           Identity
relu                      pointwise           Relu
sat01                     pointwise           Satur01
exp                       pointwise           Exp
log                       pointwise           Log
exp2                      pointwise           Exp2
log2                      pointwise           Log2
sin                       pointwise           Sin
cos                       pointwise           Cos
sqrt                      pointwise           Sqrt
recip                     pointwise           Recip
recip_sqrt                pointwise           Recip_sqrt
neg                       pointwise           Neg
tanh                      pointwise           Tanh_approx
not                       pointwise           Not
uint4x32_to_prec_uniform  dedicated           Uint4x32_to_prec_uniform

The binary primitive operations:

Identifier    Infix operator  Default projection  Constructor in Ir.Ops  Assignments
fst           -@>             pointwise           Arg1                   none
snd           -/>             pointwise           Arg2                   =:
add           +               pointwise           Add                    =+, =:+
sub           -               pointwise           Sub                    =-, =:-
mul           *               none                Mul                    =*, =:*
div           /               none                Div                    =/, =:/
pow           **              pointwise           ToPowOf                =**, =:**
relu_gate     -?/             pointwise           Relu_gate              =?/, =:?/
sat01_gate    -?^             pointwise           Satur01_gate           =?^, =:?^
lt            <               pointwise           Cmplt                  none
eq            =               pointwise           Cmpeq                  none
ne            <>              pointwise           Cmpne                  none
or_           ||              pointwise           Or                     =||, =:||
and_          &&              pointwise           And                    =&&, =:&&
mod_          %               pointwise           Mod                    none
max           @^              pointwise           Max                    =@^, =:@^
min           @-              pointwise           Min                    =@-, =:@-
threefry4x32  ^^^^            pointwise           Threefry4x32           =^^^^, =:^^^^

The ternary primitive operations:

Identifier  Default projection  Constructor in Ir.Ops
where       pointwise           Where
fma         compose-accumulate  FMA

The interpretation functions also state the semantics:

let interpret_unop op v =
  let open Float in
  match op with
  | Identity -> v
  | Relu when v >= 0. -> v
  | Relu -> 0.
  | Satur01 when v <= 0. -> 0.
  | Satur01 when v >= 1. -> 1.
  | Satur01 -> v
  | Exp -> exp v
  | Log -> log v
  | Exp2 -> 2. ** v
  | Log2 -> log v / log 2.
  | Sin -> sin v
  | Cos -> cos v
  | Sqrt -> sqrt v
  | Recip -> 1. / v
  | Recip_sqrt -> 1. / sqrt v
  | Neg -> ~-.v
  | Tanh_approx -> tanh v
  | Not -> if v = 0. then 1. else 0.
  | Uint4x32_to_prec_uniform -> failwith "NOT IMPLEMENTED"

let interpret_binop op v1 v2 =
  let open Float in
  match op with
  | Arg1 -> v1
  | Arg2 -> v2
  | Add -> v1 + v2
  | Sub -> v1 - v2
  | Mul -> v1 * v2
  | Div -> v1 / v2
  | ToPowOf when is_integer v2 -> int_pow v1 @@ to_int v2
  | ToPowOf -> v1 ** v2
  | Relu_gate -> if v1 > 0.0 then v2 else 0.0
  | Satur01_gate -> if v1 > 0.0 && v1 < 1.0 then v2 else 0.0
  | Max -> max v1 v2
  | Min -> min v1 v2
  | Mod -> v1 % v2
  | Cmplt -> if v1 < v2 then 1. else 0.
  | Cmpeq -> if v1 = v2 then 1. else 0.
  | Cmpne -> if v1 <> v2 then 1. else 0.
  | Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
  | And -> if v1 <> 0. && v2 <> 0. then 1. else 0.
  | Threefry4x32 -> ...

let interpret_ternop op v1 v2 v3 =
  let open Float in
  match op with Where -> if v1 <> 0. then v2 else v3 | FMA -> (v1 * v2) + v3

The syntax for %op

The %op syntax is simpler than the %cd syntax since it relies more on regular OCaml expressions. For example, we can write without syntax extensions:

  let hid_dim = 8 in
  let w = TDSL.param "w" in
  let b = TDSL.param ~output_dims:[ hid_dim ] "b" in
  let layer x = TDSL.O.( relu(w * x + b) ) in
  ...

Since TDSL.O is opened for the scope of an extension point %op:

  let hid_dim = 8 in
  let w = TDSL.param "w" in
  let b = TDSL.param ~output_dims:[ hid_dim ] "b" in
  let%op layer x = relu(w * x + b) in
  ...

Using inline declarations, this becomes more concise:

  let hid_dim = 8 in
  let%op mlp_layer x = relu({ w } * x + { b; o = [ hid_dim ] }) in
  ...

When there is a function directly under the %op extension point, like in the example above, or directly under a function taking a unit parameter (), the function parameter (to the right of ()) should be a tensor. That’s because %op uses this tensor’s (value’s) label to enrich the label of the resulting tensor.

When the declaration is followed by a literal float, the float provides the initial value to initialize the tensor. Otherwise, the tensor value cells are initialized randomly with uniform distribution.

The syntax for %cd

The basic building blocks of the %cd syntax are individual assignments, separated by semicolons. The assignments, represented via Assignments.Accum_binop and Assignments.Accum_unop, are in full generality accumulating:

type Assignments.t =
   ...
  | Accum_binop of {
      initialize_neutral : bool;
      accum : Ops.binop;
      op : Ops.binop;
      lhs : Tnode.t;
      rhs1 : buffer;
      rhs2 : buffer;
      projections : Indexing.projections Lazy.t;
    }
  | Accum_unop of {
      initialize_neutral : bool;
      accum : Ops.binop;
      op : Ops.unop;
      lhs : Tnode.t;
      rhs : buffer;
      projections : Indexing.projections Lazy.t;
    }

For example the binary case in pseudocode: if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2) (assuming the neutral element of accum is 0). The representation also has a field projections which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
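The pseudocode above can be modeled on plain float arrays. The following is a minimal sketch, assuming a pointwise projection (all arrays have the same length). The reduced binop type and the names apply, neutral and accum_binop are hypothetical and are not part of OCANNL; the neutral elements listed are the usual mathematical ones.

```ocaml
(* Hypothetical, reduced model of an accumulating binary assignment.
   It assumes a pointwise projection: every array has the same length. *)
type binop = Add | Mul | Max

let apply op x y =
  match op with Add -> x +. y | Mul -> x *. y | Max -> Float.max x y

(* Neutral element of each accumulation operator. *)
let neutral = function Add -> 0. | Mul -> 1. | Max -> Float.neg_infinity

let accum_binop ~initialize_neutral ~accum ~op ~lhs ~rhs1 ~rhs2 =
  if initialize_neutral then Array.fill lhs 0 (Array.length lhs) (neutral accum);
  Array.iteri
    (fun i _ -> lhs.(i) <- apply accum lhs.(i) (apply op rhs1.(i) rhs2.(i)))
    lhs

let () =
  let lhs = [| 10.; 10. |] in
  (* Like `lhs =+ rhs1 * rhs2`: accumulate onto the existing contents. *)
  accum_binop ~initialize_neutral:false ~accum:Add ~op:Mul ~lhs
    ~rhs1:[| 1.; 2. |] ~rhs2:[| 3.; 4. |];
  assert (lhs = [| 13.; 18. |]);
  (* Like `lhs =:+ rhs1 * rhs2`: reset to the neutral element first. *)
  accum_binop ~initialize_neutral:true ~accum:Add ~op:Mul ~lhs
    ~rhs1:[| 1.; 2. |] ~rhs2:[| 3.; 4. |];
  assert (lhs = [| 3.; 8. |])
```

In the real representation the loops and indexing come from the projections field rather than from a fixed pointwise traversal.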

The basic %cd syntax for assignments has the form: <lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>. See Primitive operations for the syntax of primitive operation application, where <rhs1>, <rhs2> (for binary and ternary ops), <rhs3> (for ternary ops) are subexpressions. <asgn-op> starts with =, followed by : only if initialize_neutral is true, then followed by the operator syntax variant of a binary primitive operation. The fields <lhs>, <rhs1>, <rhs2>, <rhs3> will often be either special-purpose identifiers (specifically v, t, t1, t2, t3, g, g1, g2, g3) or identifiers bound to tensors. <rhs1>, <rhs2>, <rhs3> will also often be (non-differentiable) tensor expressions. The notation <tensor>.grad stands for the gradient node of the given tensor. For more about “slot fillers”, and to learn about the operators +* and ++, see the section further features of the syntax extension %cd.

How is the projections field determined? projections can be given explicitly as a labeled argument ~projections. If they aren’t but %cd realizes there is a ~projections parameter in scope, it uses it – see tensor/operation.ml where this option is used to define tensor operations. If instead of ~projections a ~logic labeled argument is given, the string passed is used to determine projections. ~logic:"." means a pointwise operation. ~logic:"@" means an “output axes of rhs2 match input axes of rhs1” operation (matrix multiplication is a special case). ~logic:"T" means transpose of input and output axes. The string passed to ~logic can also use OCANNL’s generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default depends on the primitive operation, but it is almost always a pointwise operation.

Here we see an example of tensor multiplication – extending matrix multiplication to an arbitrary number of axes – multiplying a by b to get c. In =:+, = is required to separate the assigned-to part from the computation, : clears c before the computation, and + selects addition to accumulate the results.

c =:+ a * b ~logic:"@"
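For ordinary 2D matrices, this assignment corresponds to the familiar triple loop. The sketch below is only an illustration under that restriction: matmul_accum is a hypothetical name, and it assumes a is indexed by (i, k), b by (k, j) and c by (i, j). OCANNL itself derives the loops from projections rather than hard-coding them.

```ocaml
(* Hypothetical model of `c =:+ a * b ~logic:"@"` restricted to 2D matrices:
   `a` is indexed by (i, k), `b` by (k, j), and `c` by (i, j). *)
let matmul_accum ~initialize_neutral c a b =
  let n = Array.length a and m = Array.length b in
  let p = if m = 0 then 0 else Array.length b.(0) in
  (* The `:` in `=:+` clears `c` to 0., the neutral element of addition. *)
  if initialize_neutral then
    Array.iter (fun row -> Array.fill row 0 (Array.length row) 0.) c;
  for i = 0 to n - 1 do
    for j = 0 to p - 1 do
      for k = 0 to m - 1 do
        (* `+` accumulates, `*` is the primitive operation. *)
        c.(i).(j) <- c.(i).(j) +. (a.(i).(k) *. b.(k).(j))
      done
    done
  done

let () =
  let a = [| [| 1.; 2. |]; [| 3.; 4. |] |] in
  let b = [| [| 5.; 6. |]; [| 7.; 8. |] |] in
  (* Stale contents of `c` are discarded because of the `:`. *)
  let c = Array.make_matrix 2 2 100. in
  matmul_accum ~initialize_neutral:true c a b;
  assert (c = [| [| 19.; 22. |]; [| 43.; 50. |] |])
```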

Compare the following two ways of updating a parameter p:

p =+ learning_rate * p.grad ~logic:"."

and:

p =+ learning_rate *. p.grad

In the first case, we have a binary assignment calculated pointwise. The resulting representation is Accum_binop where accum is Add and op is Mul (multiplication). In the second case, *. is not recognized as one of the built-in operators. This leaves the expression learning_rate *. p.grad un-transformed. Since (*.) is bound in NTDSL.O to pointwise tensor multiplication, this creates an intermediate tensor, that is then added onto p. The resulting representation is Accum_unop where accum is Add and op is Identity. Both variants end up with the same result, and even with the same computation, because the second variant’s computation will get optimized (unless configured not to).
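To see why both variants compute the same values, here is a small float-array model. It is a sketch only: fused_update and two_step_update are hypothetical names, and learning_rate is simplified to a plain float, whereas in OCANNL it is a tensor.

```ocaml
(* Hypothetical float-array model; `fused_update` and `two_step_update`
   are illustrative names, not OCANNL functions. *)

(* Variant 1: one Accum_binop-like step, accum = Add and op = Mul. *)
let fused_update p ~learning_rate ~grad =
  Array.iteri (fun i g -> p.(i) <- p.(i) +. (learning_rate *. g)) grad

(* Variant 2: build an intermediate array first, then an Accum_unop-like
   step with accum = Add and op = Identity. *)
let two_step_update p ~learning_rate ~grad =
  let tmp = Array.map (fun g -> learning_rate *. g) grad in
  Array.iteri (fun i t -> p.(i) <- p.(i) +. t) tmp

let () =
  let grad = [| 0.5; -2. |] in
  let p1 = [| 1.; 1. |] and p2 = [| 1.; 1. |] in
  fused_update p1 ~learning_rate:(-0.5) ~grad;
  two_step_update p2 ~learning_rate:(-0.5) ~grad;
  assert (p1 = p2);
  assert (p1 = [| 0.75; 2. |])
```

The only difference in this model is the intermediate array tmp, which plays the role of the intermediate tensor in the second variant.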

Advanced note: when a ~projections parameter is in scope but no assignment-specific ~projections argument is given – the typical case in tensor/operation.ml – the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections. Specifically, the identifiers rhs1, t1, v1, g1 are “slot RHS1” of the projections, rhs2, t2, v2, g2 are “slot RHS2”, and lhs, t, v, g are “slot LHS”. Scalar constants are given the projection directly, to make the automated derivation more expressive; this is supported both for literals and (heuristically) for the !. and !.. embedding operators.

Numeric and N-dimensional array literals

Both %cd and %op extensions use a shared syntax for N-dimensional array literals. %cd uses NTDSL.number and NTDSL.ndarray functions, while %op uses TDSL.number and TDSL.ndarray functions. (This is just for consistency: TDSL.ndarray invokes Tensor.ndarray ~grad_spec:If_needed, which will figure out the gradient is not needed and will make the tensor non-differentiable.)

Numbers are a special case: an array of (output) dimension 1.

N-dimensional array literals combine the list, tuple and array syntaxes to strictly distinguish between output, input and batch axes:

For example, [ (1, 2, 3); (4, 5, 6) ] is a mathematical matrix converting 3D vectors into 2D vectors.
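As a plain OCaml illustration of this reading, the sketch below assumes (as the example suggests) that the outer list enumerates the output axis and each tuple enumerates the input axis. The names matrix and apply are hypothetical; this is ordinary OCaml data, not the %op or %cd literal syntax.

```ocaml
(* Hypothetical model: the outer list enumerates the output axis (2 entries),
   each tuple enumerates the input axis (3 entries). *)
let matrix = [ (1., 2., 3.); (4., 5., 6.) ]

(* Apply the matrix to a 3D vector, yielding one number per output entry. *)
let apply (x, y, z) =
  List.map (fun (a, b, c) -> (a *. x) +. (b *. y) +. (c *. z)) matrix

(* A 3D input produces a 2D output. *)
let () = assert (apply (1., 0., -1.) = [ -2.; -2. ])
```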

OCANNL supports dimension labels. The syntax for number allows prefixing a number by a character that stands for the dimension label of the resulting output dimension 1. These labels can then propagate to specify labels of other dimensions in other tensors, via shape inference. Example: let%op y = ({ hey } * 'q' 2.0) + 'p' 1.0 in ...

Wildcard bindings

When an extension is over a wildcard (ignore result) binding: let%cd _ = ... and let%op _ = ..., the generated code is wrapped in Tensor.with_unchanged_roots, to prevent it from upsetting rootness checks. The use-case for writing %op and %cd notations with ignored result is to generate additional shape inference constraints.

Inline declarations

Both %cd and %op syntaxes support inline declarations of tensors. For %op these are differentiable, for %cd non-differentiable tensors.

A declaration site uses the record syntax. The key difference between the two extensions:

Both syntaxes support additional record fields that map directly to labeled arguments of the tensor creation functions (see Tensor module signatures):

Note: for the %op declarations, if the root operation comes from TDSL.O and is not qualified with a module name, it becomes qualified with PDSL, which ensures that the created tensor will be differentiable (will have gradients) and will be able to take the additional arguments. There are also special cases for literal constants to ensure the resulting tensor is initialized with these constants but is differentiable.

Examples:

The tensor name is bound to the newly created tensor, and the record expression itself evaluates to the tensor. The scope of the binding is the full scope of the extension point, even if the declaring record appeared in the body of a function that’s inside the extension point scope (except for %op there is a special case of functions taking a unit parameter () discussed below – inline definitions are introduced once () is applied). The first element of the label of the created tensor is the name that introduced it.

For %cd, inline declarations are allowed both in the assigned-to position (left-hand side) of assignments and in standalone tensor expressions. When used in assignments, one of the tensors on the right-hand-side is picked to provide additional label information if possible. In particular, tensors that are function parameters inside the scope of the extension point, cannot be picked to provide label information, as they would escape their scope at the point the tensor is created. Inline declarations are still prohibited within the right-hand side of assignments to discourage over-use in locations with less label information. Example showing two tensor nodes declared inline, both of them include the label of the param p in their labels:

let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false) p =
  [%cd
    { sgd_delta } =: p.grad + (!.weight_decay *. p);
    if Float.(momentum > 0.0) then (
      { sgd_momentum } =: (!.momentum *. sgd_momentum) + sgd_delta;
      if nesterov then sgd_delta =+ !.momentum *. sgd_momentum else sgd_delta =: sgd_momentum);
    p =- learning_rate *. sgd_delta]

Inline declarations can also be used outside of assignments for creating non-differentiable tensors, to mimic the behavior of %op but without the burden of initialization that a parameter would introduce:

  let%cd mlp_result = mlp { point } in
  let result_routine =
    Train.to_routine (Context.context sgd_routine) IDX.empty
      [%cd ~~("mlp infer"; mlp_result.forward)]
  in
  let callback (x, y) =
    Tn.set_values point [| x; y |];
    Train.run ctx result_routine;
    Float.(mlp_result.@[0] >= 0.)
  in

For %op, the declaration is allowed anywhere. If there is a unit () parameter in the function, the scope of inline-declared tensors is delimited at that parameter. The tensors are defined right after the unit parameter. If there is a labeled parameter with label label before the unit parameter (e.g., ~label), the inline-declared tensors will use that parameter (which should be of type string list) to enrich their labels. Example showing two param tensors declared inline, with scope delimited by () and labels enriched by the label parameter:

let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] })

Implementation strategy for the initialization syntax

To maintain the familiar concise syntax, yet allow for configurability during initialization, the %op syntax substitutes the operator function applied at the root of the initialization expression by prefixing the function identifier with PDSL (or with NTDSL when invoked from the %%extend_dsls syntax). Only unqualified identifiers get prefixed, and %oc is an escape hatch to prevent prefixing even for unqualified identifiers.

Using OCANNL’s generalized einsum notation

As we mentioned above, in the %cd syntax you can set up an arbitrary assignment with projections derived from a generalized einsum specification, by passing the specification as a string with the ~logic label. However, both the %cd and %op syntaxes support built-in operators that take an einsum specification: +* binding to NTDSL.einsum resp. TDSL.einsum, and ++ binding to NTDSL.einsum1 resp. TDSL.einsum1. +* is a “ternary” operator, binary wrt. tensor arguments, and ++ is a binary operator, unary postfix wrt. tensor arguments. There are even more einsum operators: binary @^+ and +++; unary @^^. When the einsum specification is a literal string, we support two syntax patterns: the string can either directly follow the operator (infix-style notation), or the string can follow the second argument (mixfix-style notation). When the spec string is an identifier, it must directly follow the operator.

+*, +++ and ++ use addition for the accumulation operation; @^+ and @^^ use maximum. You can verify this by looking at the definitions of Operation.einsum, Operation.einsum1, etc. You can find examples of +* and ++ behavior in the test suite einsum_trivia.ml and in nn_blocks.ml. A frequent use-case for ++ is to sum out all axes of a tensor:

  let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
  ...

where (!..) converts an integer into a constant tensor.
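Numerically, summing out all axes and dividing by the batch size amounts to the following. This is a minimal sketch on nested float arrays, assuming rows are batch entries and columns are output cells; sum_all and this scalar_loss function are hypothetical stand-ins, not OCANNL code.

```ocaml
(* Hypothetical model: rows are batch entries, columns are output cells. *)
let sum_all (t : float array array) =
  Array.fold_left (fun acc row -> Array.fold_left ( +. ) acc row) 0. t

(* Sum every cell, then divide by the batch size. *)
let scalar_loss ~batch_size margin_loss =
  sum_all margin_loss /. Float.of_int batch_size

let () =
  let margin_loss = [| [| 0.5 |]; [| 1.5 |]; [| 0. |]; [| 2. |] |] in
  assert (scalar_loss ~batch_size:4 margin_loss = 1.)
```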

Syntax of the generalized einsum notation

The specification syntax has two modes:

The syntax of a generalized einsum spec has two variants:

Recall that a tensor shape is composed of three rows, i.e. sequences of axes: batch, input and output axes. Correspondingly, a shape spec in the notation can be:

The notation for a row is composed of sequences of row specs, and an optional row variable spec. A row variable tracks broadcasting. The syntax of a row:

The syntax of a row variable:

The syntax of an axis spec:

Examples:

Affine indexing for convolutions and pooling

The affine axis syntax enables convolution and pooling operations directly in einsum notation. The semantics:

Important constraint for valid convolution: The formula must hold exactly. For example, with stride=2, kernel_size=2, dilation=1:
- effective_kernel_span = 1 + (2-1) * 1 = 2
- A 4x4 input gives output_size: 4 = 2 * (output - 1) + 2, so output = 2
- A 5x5 input would fail: 5 = 2 * (output - 1) + 2, so output = 2.5 (not an integer)
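The arithmetic in this example can be checked mechanically. The helper below is a sketch, not an OCANNL function: conv_output_size is a hypothetical name, and the relation it solves, input_size = stride * (output_size - 1) + effective_kernel_span, is read off the worked numbers above.

```ocaml
(* Hypothetical helper, not an OCANNL function. It solves
   input_size = stride * (output_size - 1) + effective_kernel_span
   for output_size, returning None when no exact integer solution exists. *)
let conv_output_size ~input_size ~stride ~kernel_size ~dilation =
  let effective_kernel_span = 1 + ((kernel_size - 1) * dilation) in
  let rest = input_size - effective_kernel_span in
  if rest < 0 || rest mod stride <> 0 then None
  else Some ((rest / stride) + 1)

let () =
  (* The 4x4 case from the text: the constraint holds exactly. *)
  assert (
    conv_output_size ~input_size:4 ~stride:2 ~kernel_size:2 ~dilation:1
    = Some 2);
  (* The 5x5 case from the text: no integer output size exists. *)
  assert (
    conv_output_size ~input_size:5 ~stride:2 ~kernel_size:2 ~dilation:1 = None)
```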

Examples:

Capturing the dimensions of selected axes for further computation or to add shape constraints

The syntaxes +* and ++ accept an optional list of strings argument after the specification string. When passed, the strings should be some of the identifiers used in the specification. Both dimension variable and row variable labels are supported. This will introduce bindings for Indexing.variable_ref objects at the same point as the inline parameter definition bindings, and will pass these objects with the ~capture_dims argument to einsum resp. einsum1. The bound objects can later be used with Operation.embed_dim or its alias Operation.TDSL.O.dim to embed the solved dimension of the corresponding variable (as a number) into a tensor expression. For a row variable, the number will be the product of the dimensions it resolved into.

Further features of the syntax extension %cd

Referencing arrays: tensor value, tensor gradient, merge buffer of a tensor node

The %cd syntax uses record-style notation to point to:

The accessor .value can (almost?) always be dropped: by default, tensors in the %cd syntax refer to their value nodes. The forward and backprop code accesses manage roots (via the Tensor.consume_forward_code and Tensor.consume_backprop_code functions).

For example, in a data-parallel computation, gradients of the same param p can be merged across devices using the code p.grad =+ p.grad.merge, combined with an explicit device-to-device transfer.

Block comments

The %cd syntax uses the prefix operator (~~) in a semicolon sequence to introduce block comments:

type Assignments.t =
  ...
  | Block_comment of string * t
  ...

Schematic example: ~~("space" "separated" "comment" "tensor p debug_name:" p; <scope of the comment>). The content of the comment uses application syntax, must be composed of strings, <tensor>, <tensor>.value (equivalent to <tensor>), <tensor>.grad components, where <tensor> is any tensor expression or tensor identifier.

This syntax used to be very important, because comments in assignments are used to derive file names for generated code. Now, the %cd syntax automatically introduces block comments for code at let-binding points, using the identifier. Currently the comment does not yet incorporate any tensor node labels – and for that reason we are not yet adding comments around function bodies if a function is annotated with %cd. Moreover, we only automatically add comments for code, not for tensors – so the ~~ syntax is still helpful when the comment needs to be more precise for debugging or naming purposes, or when %cd is not used with a let binding, or when we want to pass a forward code directly instead of let-binding it. If an explicit comment is provided at the let-binding level, the automatic one is omitted.

Further features of the syntax extension %op

Name from binding

When an extension point is applied to a let-binding, e.g. let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] }), it uses the name of the binding (mlp_layer in the example) for the label of the primary tensor created by the extension, if any. This is why the resulting layer tensor in the example has its label starting with "mlp_layer". If the extension is over a semicolon-separated sequence of expressions, the primary tensor can only be in the last component of the sequence; other syntax constructs are handled analogously.

The example let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] }) also illustrates providing additional string list to populate the label of the tensor: label must be of type string list.

Label from function argument

The resulting (primary) tensor’s label will also have incorporated the label of the input argument, if any. In our example, the resulting mlp_layer tensor will also include the label of the actually applied x. If the function has a unit parameter (), like mlp_layer above, only parameters to the right of () are considered for label extraction.

When there is the unit parameter, and a ~label parameter (specifically a parameter with label label), this label is also incorporated.

Configuring inline declarations: inline output dimensions, initial values

In the %op syntax, inline declarations use record syntax with additional fields to configure the tensor:

A very simple example from micrograd_demo: Micrograd README basic example:

  let%op c = { a = [ -4 ] } + { b = [ 2 ] } in
  ...

How does it relate to let%op c = { a = -4 } + { b = 2 } in ...? Without brackets, the number is used to initialize all cells of the tensor value, and shape inference decides the shape of the tensor. With brackets, the bracketing specifies both all the cells and the exact shape of the tensor.

Need to lift the applications of configuration arguments (up to the unit parameter)

If you recall, inline declared param tensors get lifted out of functions to be defined at the point of a unit () parameter. Our example let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] }) translates as:

let mlp_layer ~label ~hid_dim () =
  let w = TDSL.param ~more_label:label "w" () 
  and b = TDSL.param ~more_label:label ~output_dims:[ hid_dim ] "b" () in
  fun x -> TDSL.O.(relu (w * x + b))

For this to work properly, when employing such network blocks, their params also need to be introduced at the right moment. At one point, we tried to do this automatically by the %op syntax, but that was confusing to use. So you need to ensure scoping manually. Consider:

(* FIXME: this is wrong! Doesn't bind the parameters at the right place. *)
let%op three_layer_perceptron ~label ~dim1 ~dim2 ~dim3 () x =
  mlp_layer ~label:[ "L3" ] ~hid_dim:dim3 ()
    (mlp_layer ~label:[ "L2" ] ~hid_dim:dim2 ()
       (mlp_layer ~label:[ "L1" ] ~hid_dim:dim1 () x))

This example would work if we used direct inline definitions, but it does not work when the definitions are indirectly in the functions called. We need to write instead:

let three_layer_perceptron ~label ~dim1 ~dim2 ~dim3 () =
  let layer3 = mlp_layer ~label:[ "L3" ] ~hid_dim:dim3 ()
  and layer2 = mlp_layer ~label:[ "L2" ] ~hid_dim:dim2 ()
  and layer1 = mlp_layer ~label:[ "L1" ] ~hid_dim:dim1 () in
  fun x -> layer3 (layer2 (layer1 x))

The manual approach naturally extends to programmatic network architectures:

let mlp ~label ~hid_dims () =
  let layers =
    List.mapi hid_dims ~f:(fun i hid_dim ->
        mlp_layer ~label:[ "L" ^ Int.to_string i ] ~hid_dim ())
  in
  fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
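One way to picture the scoping issue is with a pure OCaml model in which parameter creation is replaced by a counter. Everything here is hypothetical (make_layer, wrong_net, right_net, params_created) and no OCANNL API is involved; the sketch only shows when the layer-building code runs in each style.

```ocaml
(* Hypothetical stand-in for parameter creation: it only counts calls. *)
let params_created = ref 0

let make_layer () =
  incr params_created;
  (* The "parameter" is created once, when () is applied. *)
  let w = 2. in
  fun x -> w *. x

(* Mirrors the incorrect version: layers are built inside `fun x`. *)
let wrong_net () = fun x -> make_layer () (make_layer () x)

(* Mirrors the corrected version: layers are built when () is applied. *)
let right_net () =
  let layer2 = make_layer () and layer1 = make_layer () in
  fun x -> layer2 (layer1 x)

let () =
  let net = right_net () in
  ignore (net 1.);
  ignore (net 1.);
  (* Two layers, created once, no matter how often the net is applied. *)
  assert (!params_created = 2);
  params_created := 0;
  let net = wrong_net () in
  ignore (net 1.);
  ignore (net 1.);
  (* Each application rebuilds both layers. *)
  assert (!params_created = 4)
```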

The syntax extension %%extend_dsls

This syntax extension creates a module DSL_modules with the same submodules as Operation.DSL_modules. It removes the boilerplate associated with introducing new operators into the modules TDSL, NTDSL, PDSL and their O submodules. The payload (i.e. content) of %%extend_dsls must be non-recursive let-bindings. They are parsed using a slight variant of the %op syntax, and are inserted into the DSL modules. The identifiers of the root operator functions of the definitions, if unqualified, are prefixed with the appropriate module, similarly to the behavior of inline definitions. Another unique feature of %%extend_dsls parsing is that inline tensor definitions, like in %cd, do not introduce gradients for the tensors, but, like %op, they do introduce initialization for the inline-defined tensors.

The DSL modules expose the value grad_spec that can be useful for defining operators via a “scheme” function. See the example using the box_muller helper at the beginning of lib/nn_blocks.ml. The definitions there use the %oc escape extension to avoid the prefixing mentioned above.

Implementation details

The hard-coded to-the-power-of operator

OCANNL has a built-in numerical binary operation to-power-of: Ops.ToPowOf. As part of assignments, the corresponding operator is **. Here is the full definition of the to-power-of tensor operation from Operation:

let rec pointpow ?(label : string list = []) ~grad_spec p t1 : Tensor.t =
  let module NTDSL = struct
    include Initial_NTDSL

    module O = struct
      include NDO_without_pow

      let ( **. ) ?label base exp = pointpow ?label ~grad_spec:Tensor.Prohibit_grad exp base
    end
  end in
  let p_t = NTDSL.number p in
  let%cd op_asn ~t ~t1 ~t2 ~projections = v =: v1 ** v2 ~projections in
  let%cd grad_asn =
    if Tensor.is_prohibit_grad grad_spec then fun ~v:_ ~g:_ ~t1:_ ~t2:_ ~projections:_ -> Asgns.Noop
    else if Float.equal p 2.0 then fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ p_t *. t1 * g
    else if Float.equal p 1.0 then fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ g
    else fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ p_t *. (t1 **. (p -. 1.)) * g
  in
  Tensor.binop ~label:("**." :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn ~grad_spec t1 p_t

On the Tensor level, this is implemented as a binary tensor operation, but it is exposed as a unary tensor operation! To avoid the complexities of propagating gradient into the exponent, Operation.pointpow is implemented as a function of only one tensor, the exponent is a number. We hard-code the pointwise-power-of operator NTDSL.O.( **. ), resp. TDSL.O.( **. ), in the %cd and %op syntaxes, to pass the numeric value to pointpow (the second argument of **.) without converting it to a tensor first.
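The three non-trivial branches of grad_asn all encode the derivative of t1 ** p with respect to t1, which is p * t1 ** (p - 1), scaled by the incoming gradient g. The scalar sketch below checks this against a finite difference; pow_grad and numeric_grad are hypothetical names and work on plain floats rather than tensors.

```ocaml
(* Hypothetical scalar model of the grad_asn branches in pointpow:
   the local derivative of t1 ** p with respect to t1, times incoming g. *)
let pow_grad ~p ~t1 ~g =
  if Float.equal p 2.0 then p *. t1 *. g
  else if Float.equal p 1.0 then g
  else p *. (t1 ** (p -. 1.)) *. g

(* Central finite difference, used only to sanity-check the formula. *)
let numeric_grad ~p ~t1 =
  let h = 1e-5 in
  (((t1 +. h) ** p) -. ((t1 -. h) ** p)) /. (2. *. h)

let () =
  List.iter
    (fun p ->
      let t1 = 1.5 in
      let analytic = pow_grad ~p ~t1 ~g:1. in
      let numeric = numeric_grad ~p ~t1 in
      assert (Float.abs (analytic -. numeric) < 1e-6))
    [ 1.0; 2.0; 3.0; 0.5 ]
```

The special cases for p = 2 and p = 1 are algebraically the same formula with the power simplified away.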

Intricacies of the syntax extension %cd

The syntax %cd translator needs to accomplish more than a context-free conversion of a concise notation to an Assignments.comp data-type. In particular:

Embedded nodes

In fact, the syntax %cd produces Assignments.comp values:

type comp = {
  asgns : t;
  embedded_nodes : Set.M(Tnode).t;
}

The tensor nodes that are in asgns but not in embedded_nodes, and are on-device, must already be present in contexts with which the computation is linked. Such non-embedded nodes can be seen as inputs to the computation – except that for backprop code of a tensor, they are actually the outputs! Embedded nodes are closely related to rootness – when a node has not been used in the code of another tensor, it is a root (a forward root for value nodes and a backprop root for grad nodes). embedded_nodes were roots the first time they were used in asgns. Parameters, as created by Tensor.param, are not embedded in the code that uses them and thus will not be in embedded_nodes of the forward and backprop code over the parameters; however, they will constitute the embedded_nodes of the Tensor.init_params code.