135 lines
4 KiB
Common Lisp
135 lines
4 KiB
Common Lisp
;;;; decons
|
|
|
|
;;; Package definition: uses only COMMON-LISP, with local nicknames for two
;;; project packages. Every exported name is given as an uninterned symbol
;;; (#:) so that reading this form does not intern stray symbols into
;;; whatever package happens to be current.
(defpackage :decons
  (:use :common-lisp)
  (:local-nicknames (:shape :scopes/shape)
                    (:util :scopes/util))
  (:export #:rapply #:rreduce #:rreduce-1 #:radd #:rmul #:rsub #:rdiv #:rsqr
           #:default-deviation #:l2-loss
           #:*revisions* #:*alpha* #:gradient-descent
           #:nabla-xp
           #:line #:quad
           #:*obj*))
|
|
|
|
(in-package :decons)
|
|
|
|
;;;; common (basic) stuff
|
|
|
|
(defun sqr (x)
  "Return the square of the number X."
  (expt x 2))
|
|
|
|
(defun sum (data)
  "Add up the elements of the sequence DATA with +; an empty sequence yields 0."
  (reduce '+ data))
|
|
|
|
;;;; rapply, rreduce - recursive application of operations
|
|
|
|
(defun rapply (op arg1 &optional (arg2 nil arg2-p))
  "Apply OP recursively over the (possibly nested) list structure of its arguments.
With only ARG1, OP is called as a unary function on every leaf (via RCALL).
With ARG2 supplied, OP is called as a binary function (via RCALL2): two lists
are paired element by element, and an atom is combined with every element of
a list."
  ;; Dispatch on whether ARG2 was *supplied*, not on its value: NIL is also
  ;; the empty list and must not be mistaken for "no second argument".
  (if arg2-p
      (rcall2 op arg1 arg2)
      (rcall op arg1)))
|
|
|
|
(defgeneric rcall (op arg)
  (:documentation
   "Call the unary function OP on ARG. If ARG is a list, recurse into each
element and return a fresh list of the same nested shape.")
  (:method (op arg)
    (funcall op arg))
  (:method (op (arg list))
    (loop for element in arg
          collect (rcall op element))))
|
|
|
|
(defgeneric rcall2 (op a1 a2)
  (:documentation
   "Call the binary function OP on A1 and A2, recursing into list structure.
Two lists are walked in parallel (stopping at the shorter one); a list paired
with an atom combines the atom with each element of the list; two atoms are
passed to OP directly.")
  (:method (op a1 a2)
    (funcall op a1 a2))
  (:method (op (a1 list) a2)
    (loop for x in a1
          collect (rcall2 op x a2)))
  (:method (op a1 (a2 list))
    (loop for y in a2
          collect (rcall2 op a1 y)))
  (:method (op (a1 list) (a2 list))
    (loop for x in a1
          for y in a2
          collect (rcall2 op x y))))
|
|
|
|
(defgeneric rcurry (op arg)
  (:documentation
   "Return a one-argument function that applies OP with ARG as its first
argument. If ARG is a list, the returned function maps RAPPLY over its
elements instead.")
  (:method (op arg)
    (lambda (y) (funcall op arg y)))
  (:method (op (arg list))
    (lambda (y)
      (loop for item in arg
            collect (rapply op item y)))))
|
|
|
|
(defun rreduce (op arg &key (initial-value 0))
  "Fold the sequence ARG with OP starting from INITIAL-VALUE, where each element
is first passed through RELEMENT (so nested lists are reduced recursively).
Note that INITIAL-VALUE is also used for every nested list; the default 0 is
only a neutral element for addition."
  (flet ((leaf-value (element)
           (relement op element :initial-value initial-value)))
    (reduce op arg :key #'leaf-value :initial-value initial-value)))
|
|
|
|
(defgeneric relement (op v &key initial-value)
  (:documentation
   "Key function used by RREDUCE: an atom V is returned unchanged; a list V is
first reduced recursively by RREDUCE with OP and INITIAL-VALUE.")
  (:method (op v &key (initial-value 0))
    ;; OP and INITIAL-VALUE are only needed by the list method; declare them
    ;; ignorable to silence unused-variable style warnings.
    (declare (ignorable op initial-value))
    v)
  (:method (op (v list) &key (initial-value 0))
    (rreduce op v :initial-value initial-value)))
|
|
|
|
(defgeneric rreduce-1 (op arg &key initial-value)
  (:documentation
   "Reduce only the innermost flat lists of ARG with OP.
A list without any cons elements is folded with REDUCE from INITIAL-VALUE;
a list containing at least one cons is mapped element-wise, recursing into
each element; an atom is returned unchanged.")
  (:method (op arg &key (initial-value 0))
    ;; Atoms are returned as is; OP and INITIAL-VALUE are only needed by the
    ;; list method, so declare them ignorable to avoid style warnings.
    (declare (ignorable op initial-value))
    arg)
  (:method (op (arg list) &key (initial-value 0))
    (if (some #'consp arg)
        (mapcar (lambda (x) (rreduce-1 op x :initial-value initial-value)) arg)
        (reduce op arg :initial-value initial-value))))
|
|
|
|
(defun radd (a b)
  "Add A and B with RAPPLY, recursing into list structure."
  (rapply '+ a b))

(defun rmul (a b)
  "Multiply A and B with RAPPLY, recursing into list structure."
  (rapply '* a b))

(defun rsub (a b)
  "Subtract B from A with RAPPLY, recursing into list structure."
  (rapply '- a b))

(defun rdiv (a b)
  "Divide A by B with RAPPLY, recursing into list structure."
  (rapply '/ a b))

(defun rsqr (a)
  "Square every leaf of A with RAPPLY."
  (rapply 'sqr a))
|
|
|
|
;;;; loss calculation, collect trial data (parameters, resulting loss)
|
|
|
|
(defun default-deviation (observed calculated &key (norm #'sqr))
  "Sum of NORM applied to each difference OBSERVED[i] - CALCULATED[i].
Both arguments are lists, walked in parallel up to the shorter one. With the
default NORM (SQR) this is the sum of squared residuals."
  (reduce #'+
          (loop for o in observed
                for c in calculated
                collect (funcall norm (- o c)))))
|
|
|
|
(defun l2-loss (target &key (deviation #'default-deviation))
  "Build a loss constructor for the parameterized model TARGET.
TARGET is called with the inputs of a dataset and must return a function of
THETA giving the calculated outputs. The result of L2-LOSS takes a DATASET,
whose CAR holds the inputs and whose CADR holds the observed outputs, and
returns the objective function: THETA -> (funcall DEVIATION observed
calculated). The default DEVIATION is DEFAULT-DEVIATION."
  (lambda (dataset) ; expectant function
    ;; The model for this dataset's inputs does not depend on THETA, so build
    ;; it once here rather than on every evaluation of the objective, which
    ;; happens many times per gradient-descent step.
    (let ((objective (funcall target (car dataset))))
      (lambda (theta) ; objective function
        (funcall deviation (cadr dataset) (funcall objective theta))))))
|
|
|
|
;;;; optimization by revision (= gradient descent)
|
|
|
|
(defparameter *revisions* 1000
  "Number of revision steps performed by GRADIENT-DESCENT.")

(defparameter *alpha* 0.01
  "Learning rate: factor by which GRADIENT-DESCENT scales each gradient.")
|
|
|
|
(defun gradient-descent (obj theta &key (revisions *revisions*) (alpha *alpha*))
  "Minimize the objective function OBJ starting from the parameter list THETA.
Performs REVISIONS steps (default *REVISIONS*). In each step every parameter p
is replaced by p - ALPHA * g, where g is its gradient as estimated by NABLA-XP
and ALPHA defaults to *ALPHA*. Returns the final parameter list; THETA itself
is not modified."
  (flet ((revise (th)
           (mapcar (lambda (p g) (- p (* alpha g)))
                   th (nabla-xp obj th))))
    (let ((current theta))
      (dotimes (i revisions current)
        (declare (ignorable i))
        (setf current (revise current))))))
|
|
|
|
;;;; experimental differentiation
|
|
|
|
(defvar *diff-variation* 0.01
  "Step by which DIFF moves a parameter up and down to estimate a derivative.")
|
|
|
|
(defun nabla-xp (fn args)
  "Determine gradients by experiment: vary args and record changes.
FN is called with a list of parameter values; ARGS is the parameter list at
which the gradient is estimated. Returns a list with one entry per element of
ARGS: the finite-difference estimate computed by DIFF for that position."
  (let ((base (funcall fn args))
        ;; COERCE instead of (APPLY #'VECTOR ARGS): APPLY is bounded by
        ;; CALL-ARGUMENTS-LIMIT, which may be as small as 50.
        (vargs (coerce args 'simple-vector)))
    (loop for ix below (length vargs)
          collect (diff fn vargs ix base))))
|
|
|
|
(defun diff (fn vargs ix base)
  "Estimate the partial derivative of FN with respect to parameter IX.
VARGS is a simple-vector of parameter values and BASE is FN's value at VARGS.
FN is evaluated at VARGS with element IX moved up and down by
*DIFF-VARIATION*. The forward and backward difference quotients are then
averaged, which gives a central difference. VARGS is not modified; a copy is
varied."
  (let* ((h *diff-variation*)
         (x (svref vargs ix))
         (probe (copy-seq vargs)))
    (flet ((value-at (xi)
             ;; Evaluate FN with parameter IX set to XI, others unchanged.
             (setf (svref probe ix) xi)
             (funcall fn (coerce probe 'list))))
      (let ((forward (/ (- (value-at (+ x h)) base) h))
            (backward (/ (- base (value-at (- x h))) h)))
        ;(util:lgi base forward backward)
        (/ (+ forward backward) 2)))))
|
|
|
|
;;;; parameterized target functions
|
|
|
|
(defun line (x)
  "Linear model over the input X (a number or a nested list of numbers).
Returns a function of THETA that computes (first THETA) * X + (second THETA)
with the recursive arithmetic RMUL and RADD."
  (lambda (theta)
    (let ((slope (first theta))
          (intercept (second theta)))
      (radd (rmul slope x) intercept))))
|
|
|
|
(defun quad (x)
  "Quadratic model over the input X (a number or a nested list of numbers).
Returns a function of THETA = (a b c) that computes a * X^2 + b * X + c
with the recursive arithmetic RMUL, RSQR and RADD."
  (lambda (theta)
    (let ((a (first theta))
          (b (second theta))
          (c (third theta)))
      (radd (rmul a (rsqr x))
            (radd (rmul b x) c)))))
|
|
|
|
;;;; working area
|
|
|
|
(defvar *obj* nil
  "Working-area variable, exported but not referenced elsewhere in this file;
presumably a holder for an objective function during interactive use.")
|