#
# a tiny backprop example using autograd:
# build an expression graph out of + and * nodes, run it forward to get the loss,
# run it backward to get gradients, and fit y = w1 + w2*x by gradient descent.
#

use strict;
use warnings;
use Data::Dumper;


# define "a + b" expression.
sub mysum {
  return { 
    name => "(".join("+",map {$_->{name}} @_).")",
    prnt => [@_],
    forward => sub {   # forward: out = sum of all parents
      my $o = shift;
      my $s=0; map { $s+=$_->{data} } @{$o->{prnt}};
      $o->{data} = $s;
    },
    backward => sub {
      my ($o) = @_;    # backward: a.grad += out.grad, b.grad += out.grad (d(a+b)/da = 1)
      map { $_->{grad} += 1.0*$o->{grad} } @{$o->{prnt}};
    }
  }
}
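
# for illustration (values made up, not part of the demo below): with a = {data=>2,
# name=>"a"} and b = {data=>3, name=>"b"}, mysum(a,b) builds a node named "(a+b)";
# its forward sets data to 5, and its backward with grad=1 adds 1 to both a->{grad}
# and b->{grad}, since d(a+b)/da = d(a+b)/db = 1.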

# define an "a * b" (n-ary product) expression node.
sub myprod {
  return { 
    name => "(".join("*",map {$_->{name}} @_).")",
    prnt => [@_],
    forward => sub {   # forward: out = product of all parents
      my $o = shift;
      my $p=1; map { $p*=$_->{data} } @{$o->{prnt}};
      $o->{data} = $p;
    },
    backward => sub {  # backward: a.grad += b * out.grad,
      my ($o) = @_;    #           b.grad += a * out.grad
      # b is recovered as out/a (the product of the other factors); a tiny
      # epsilon stands in for a zero factor to avoid dividing by zero (an approximation).
      map { $_->{grad} += ($o->{data} / ( $_->{data} ? $_->{data} : 0.000001 ) )*$o->{grad} } @{$o->{prnt}};
    }
  }
}
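
# for illustration: if out = a*b*c with a=2, b=3, c=4, then out = 24 and
# d(out)/da = b*c = 12 = out/a; that ratio is exactly what the division in
# backward computes for each factor.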

# topological ordering via a breadth-first traversal from the root node.
# forward() evaluates this order reversed (dependencies before the node itself);
# backward() uses it as-is, so gradients flow from the output down to the inputs.
sub topo {
  my @q = @_;
  my @out = ();
  my $v = ($q[0]{v} // 0) + 1;  # fresh visit stamp for this traversal
  while(@q){
    my $o = shift @q;
    next if ($o->{v} // 0) == $v; # already seen it.
    $o->{v} = $v;
    push @out,$o;
    push @q,@{$o->{prnt}} if $o->{prnt};
  }
  return @out;
}
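
# for illustration: for f = (a+b)*c, topo(f) returns the root-first order
# f, (a+b), c, a, b; forward() below runs the reverse (leaves first, f last),
# and backward() runs it as-is, so f's gradient reaches (a+b) before (a+b)
# passes it on to a and b.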


sub forward {
  my ($o) = @_;
  unless($o->{topo}){   # compute the orderings once and cache them on the root
    $o->{topo}  = [topo($o)];               # root-first (backward order)
    $o->{ftopo} = [reverse @{$o->{topo}}];  # leaves-first (forward order)
  }
  map { $_->{forward}($_) } grep { $_->{prnt} } @{$o->{ftopo}}; # evaluate every non-leaf node, dependencies first
}

sub backward {
  my ($o) = @_;    # assumes forward($o) already ran, so $o->{topo} exists
  map { $_->{grad} = 0 } @{$o->{topo}};  # zero all gradients before accumulating
  $o->{grad} = 1;                        # seed: d(out)/d(out) = 1
  map { $_->{backward}($_) } grep { $_->{backward} } @{$o->{topo}}; # root-first: each node pushes its grad to its parents
}
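
# a quick sanity check of the machinery (commented out so it does not add to the
# demo output below; $u, $v and $p are names used only for this illustration):
#   my $u = {data=>2, name=>"u"};
#   my $v = {data=>3, name=>"v"};
#   my $p = myprod($u, $v);
#   forward($p);    # $p->{data} is now 6
#   backward($p);   # $u->{grad} is 3, $v->{grad} is 2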

my $n1 = {data=>-1, name=>"-1"};   # constant -1, used below to build subtraction as y + (-1)*yhat


# generate "test" data (a "known" model, Y = 23 - 50*X).
my @X;
my @Y;
for(1..100){
  my $x = (rand()-0.5)*10;
  my $y = 23 + -50 * $x;
  print "data: $x,$y\n";
  push @X,$x;
  push @Y,$y;
}

# these are the weights.
my $w1 = {data=>rand(), name=>"w1"};
my $w2 = {data=>rand(), name=>"w2"};

# build the expression for the loss function
my @sum;
for(my $i=0;$i<=$#X;$i++){
  # prediction: yhat = w1 + w2*x
  my $y = mysum($w1, myprod($w2,{data=>$X[$i], name=>"x$i"}));
  # error: e = y - yhat, built as y + (-1)*yhat
  my $e = mysum({data=>$Y[$i], name=>"y$i"}, myprod($y,$n1));
  # squared error for this sample
  push @sum, myprod($e,$e);
}
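
# the quantity being minimized is the sum of squared errors,
#   L(w1,w2) = sum_i ( y_i - (w1 + w2*x_i) )^2,
# expressed entirely with the + and * nodes defined above.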
my $a = mysum(@sum);
print "$a->{name}\n";   # this is the long expression.

forward($a);   # run it forward.

print "sum of squares: $a->{data}\n";  # this is the output of loss function.

# do gradient descent.
my $last = 999999;   # loss from the previous iteration, for the stopping test
for(;;){
  forward($a);   # run model forward (calculate outputs, intermediate values).
  print "y = $w1->{data} + X*$w2->{data};   error: $a->{data}\n";
  last if ($last - $a->{data})**2 < 0.001;   # exit if stopped improving
  $last = $a->{data};

  backward($a);  # run model backwards, calculate gradient for all params.

  for($w1,$w2){  # for each parameter, adjust in direction of negative gradient
    $_->{data} += 0.001 * -$_->{grad};   # fixed learning rate of 0.001
  }
}
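
# with data generated from Y = 23 - 50*X, the fit should drift toward w1 ~ 23 and
# w2 ~ -50; the fixed step size and the crude stopping test make no guarantees,
# this is only meant as a minimal demonstration of reverse-mode autodiff.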


