Here is my attempt using an object variant. It also fails, with the error
"'Vector32' is not a concrete type", probably because Vector32 (like Matrix32)
is parameterized by a static[int] size.
I feel like I would have to recode all the base types as a single unified
float/vector/matrix (and why not n-dimensional array) type to get it to work:
import linalg
type
  BackPropKind = enum
    ## Which backward-pass signature a `BackProp` holds.
    ## The two letters after `bp` name the gradient's input shape and the
    ## returned shape: F = float32 scalar, V = vector, M = matrix
    ## (e.g. `bpVM` holds a proc taking a vector gradient and returning a
    ## matrix).
    bpFF,
    bpFV,
    bpVF,
    bpVV,
    bpFM,
    bpVM,
    bpMF,
    bpMV,
    bpMM

  BackProp = ref BackPropObj ## Heap handle to one backward-pass proc.

  BackPropObj = object
    ## Object variant holding exactly one gradient-transformation proc;
    ## only the field matching `kind` may be accessed.
    # `Vector32` / `Matrix32` from linalg are generic over their size
    # (`static[int]`), so they are not concrete types and cannot be used
    # un-instantiated in these fields ("is not a concrete type" error).
    # linalg's dynamically sized `DVector32` / `DMatrix32` are concrete.
    # NOTE(review): confirm these names against the linalg version in use.
    case kind: BackPropKind
    of bpFF:
      ff: proc (gradient: float32): float32 {.noSideEffect.}
    of bpFV:
      fv: proc (gradient: float32): DVector32 {.noSideEffect.}
    of bpVF:
      vf: proc (gradient: DVector32): float32 {.noSideEffect.}
    of bpVV:
      vv: proc (gradient: DVector32): DVector32 {.noSideEffect.}
    of bpFM:
      fm: proc (gradient: float32): DMatrix32 {.noSideEffect.}
    of bpVM:
      vm: proc (gradient: DVector32): DMatrix32 {.noSideEffect.}
    of bpMF:
      mf: proc (gradient: DMatrix32): float32 {.noSideEffect.}
    of bpMV:
      mv: proc (gradient: DMatrix32): DVector32 {.noSideEffect.}
    of bpMM:
      mm: proc (gradient: DMatrix32): DMatrix32 {.noSideEffect.}

  Node = object
    ## Represents one recorded operation.
    ## `weights` stores the gradient transformations used during backprop;
    ## `parents` stores the indices of the parent operations.
    ## NOTE(review): presumably `weights[i]` pairs with `parents[i]` —
    ## confirm against the backward pass.
    weights: array[2, BackProp]
    parents: array[2, int] # indices into the tape of the parent nodes

  Context* = object
    ## Tape / Wengert list: contains the list of applied operations.
    nodes: ref seq[Node] # ref: copies of a Context share the same list
proc newContext*: Context {.noSideEffect.} =
  ## Creates a fresh tape (Wengert list) with no recorded operations.
  var tape: ref seq[Node]
  new(tape)
  # Explicitly store an empty seq: on older Nim versions a freshly
  # allocated seq could still be nil.
  tape[] = @[]
  Context(nodes: tape)
# An empty tape, created once at module initialization.
let ctx: Context = newContext()