decons/decons.lisp

135 lines
4 KiB
Common Lisp

;;;; decons
;;; Package definition: recursive list arithmetic, loss construction and
;;; experimental (finite-difference) gradient descent.
;;; All exported names use uninterned #: designators so that reading this
;;; form does not intern stray symbols into the package current at load time.
(defpackage :decons
  (:use :common-lisp)
  (:local-nicknames (:shape :scopes/shape)
                    (:util :scopes/util))
  (:export #:rapply #:rreduce #:rreduce-1
           #:radd #:rmul #:rsub #:rdiv #:rsqr
           #:default-deviation #:l2-loss
           #:*revisions* #:*alpha* #:gradient-descent
           #:nabla-xp
           #:line #:quad
           #:*obj*))
(in-package :decons)
;;;; common (basic) stuff
(defun sqr (n)
  "Return N multiplied by itself."
  (* n n))
(defun sum (seq)
  "Add up the elements of the sequence SEQ; an empty sequence gives 0."
  (reduce '+ seq))
;;;; rapply, rreduce - recursive application of operations
(defun rapply (op arg1 &optional (arg2 nil arg2-supplied-p))
  "Apply OP recursively over the (possibly nested) list structure of its arguments.
With only ARG1, OP is called on every leaf of ARG1 (see RCALL).
With ARG2 as well, OP is called on paired leaves of ARG1 and ARG2, a non-list
argument being broadcast against a list (see RCALL2).
Whether ARG2 was given is decided by its supplied-p flag, not by its value, so an
explicit NIL (the empty list) is treated as a binary operand instead of silently
turning the call into a unary one (e.g. RSUB no longer negates its first argument
when the second is NIL)."
  ;; An earlier variant computed the binary case as (rcall (rcurry op arg1) arg2);
  ;; RCALL2 broadcasts in both directions directly.
  (if arg2-supplied-p
      (rcall2 op arg1 arg2)
      (rcall op arg1)))
(defgeneric rcall (op arg)
  (:documentation "Apply the unary function OP to every leaf of ARG.
A non-list ARG is passed to OP directly; a list is rebuilt element by element,
so nested list structure is preserved.")
  (:method (op arg)
    ;; Leaf: hand it straight to OP.
    (funcall op arg))
  (:method (op (arg list))
    ;; List: recurse into each element and collect the results in order.
    (loop for element in arg
          collect (rcall op element))))
(defgeneric rcall2 (op a1 a2)
  (:documentation "Apply the binary function OP to paired leaves of A1 and A2.
Two non-lists are passed to OP directly. A list paired with a non-list is mapped
with the non-list broadcast to every element. Two lists are walked in parallel
and stop at the end of the shorter one.")
  (:method (op a1 a2)
    (funcall op a1 a2))
  (:method (op (a1 list) a2)
    ;; Broadcast A2 over each element of A1.
    (loop for left in a1
          collect (rcall2 op left a2)))
  (:method (op a1 (a2 list))
    ;; Broadcast A1 over each element of A2.
    (loop for right in a2
          collect (rcall2 op a1 right)))
  (:method (op (a1 list) (a2 list))
    ;; Pair elements positionally; the shorter list ends the walk.
    (loop for left in a1
          for right in a2
          collect (rcall2 op left right))))
(defgeneric rcurry (op arg)
  (:documentation "Return a one-argument function that applies OP with ARG fixed
as its first operand. For a list ARG, the returned function maps over ARG,
combining each element with its argument via RAPPLY.")
  (:method (op arg)
    #'(lambda (other) (funcall op arg other)))
  (:method (op (arg list))
    #'(lambda (other)
        (loop for element in arg
              collect (rapply op element other)))))
(defun rreduce (op arg &key (initial-value 0))
  "Fold OP over the sequence ARG, starting from INITIAL-VALUE (default 0).
Every element that is itself a list is first collapsed to a single value by
folding it the same way, so arbitrarily nested lists are fully reduced.
INITIAL-VALUE seeds each of those nested folds too, so it should be an identity
of OP (e.g. 1 for #'*)."
  (flet ((collapse (element)
           (relement op element :initial-value initial-value)))
    (reduce op arg
            :key #'collapse
            :initial-value initial-value)))
(defgeneric relement (op v &key initial-value)
  (:documentation "Collapse V to a single value, as used for each element by RREDUCE.
A non-list V is returned unchanged. A list V is folded with RREDUCE using OP and
INITIAL-VALUE (default 0).")
  (:method (op v &key (initial-value 0))
    ;; Leaf: nothing to fold. OP and INITIAL-VALUE belong to the protocol but are
    ;; not needed here; declare them ignorable to avoid unused-variable warnings.
    (declare (ignorable op initial-value))
    v)
  (:method (op (v list) &key (initial-value 0))
    (rreduce op v :initial-value initial-value)))
(defgeneric rreduce-1 (op arg &key initial-value)
  (:documentation "Fold only the innermost lists of ARG with OP, keeping the outer structure.
A list containing at least one cons is rebuilt element by element (recursively).
A list with no cons elements, including NIL, is folded with REDUCE starting from
INITIAL-VALUE (default 0). A non-list ARG is returned unchanged.")
  (:method (op arg &key (initial-value 0))
    ;; Leaf: returned as-is. OP and INITIAL-VALUE are part of the protocol but are
    ;; unused here; declare them ignorable to avoid unused-variable warnings.
    (declare (ignorable op initial-value))
    arg)
  (:method (op (arg list) &key (initial-value 0))
    (if (some #'consp arg)
        ;; Still nested somewhere: descend so only the innermost level is folded.
        (mapcar (lambda (x) (rreduce-1 op x :initial-value initial-value)) arg)
        ;; Flat list: fold it.
        (reduce op arg :initial-value initial-value))))
;;; Elementwise arithmetic over nested lists; all delegate to RAPPLY.
(defun radd (x y)
  "Add X and Y leaf by leaf (see RAPPLY)."
  (rapply #'+ x y))

(defun rmul (x y)
  "Multiply X and Y leaf by leaf (see RAPPLY)."
  (rapply #'* x y))

(defun rsub (x y)
  "Subtract Y from X leaf by leaf (see RAPPLY)."
  (rapply #'- x y))

(defun rdiv (x y)
  "Divide X by Y leaf by leaf (see RAPPLY)."
  (rapply #'/ x y))

(defun rsqr (x)
  "Square every leaf of X (see RAPPLY)."
  (rapply #'sqr x))
;;;; loss calculation, collect trial data (parameters, resulting loss)
(defun default-deviation (observed calculated &key (norm #'sqr))
  "Return the total deviation between OBSERVED and CALCULATED.
NORM (default SQR) is applied to each leafwise difference (observed - calculated),
and all resulting values are summed.
Nested lists are traversed recursively and a non-list operand is broadcast
against a list, exactly as in RCALL2. For flat lists of numbers the result equals
the sum of (funcall NORM (- o c)) over paired elements.
NOTE(review): as with MAPCAR, paired lists of different length are silently
truncated to the shorter one; no mismatch is reported."
  (relement #'+
            (rcall2 (lambda (obs calc) (funcall norm (- obs calc)))
                    observed calculated)
            :initial-value 0))
(defun l2-loss (target &key (deviation #'default-deviation))
  "Build an expectant function for TARGET.
TARGET is called with input data and must return an objective function of THETA.
The result of L2-LOSS takes a DATASET of the form (inputs expected-outputs) and
returns an objective function of THETA. That function yields
(funcall DEVIATION expected-outputs (funcall (funcall TARGET inputs) THETA)).
TARGET is applied to the inputs once, when the DATASET is supplied, rather than
on every THETA evaluation: the optimizer calls the objective many times per step
and this work does not depend on THETA."
  (lambda (dataset)                     ; expectant function
    (let ((objective (funcall target (first dataset)))
          (expected (second dataset)))
      (lambda (theta)                   ; objective function
        (funcall deviation expected (funcall objective theta))))))
;;;; optimization by revision (= gradient descent)
(defparameter *revisions* 1000
  "Number of update steps performed by GRADIENT-DESCENT.")
(defparameter *alpha* 0.01
  "Step size: GRADIENT-DESCENT moves each parameter by *ALPHA* times its gradient component.")
(defun gradient-descent (obj theta)
  "Minimize OBJ by repeated gradient steps starting from the parameter list THETA.
Performs *REVISIONS* steps. Each step replaces every parameter p by
p - *ALPHA* * g, where g is the matching component of (NABLA-XP OBJ params).
Returns the final parameter list."
  (flet ((revise (params)
           (let ((gradient (nabla-xp obj params)))
             (mapcar (lambda (p g) (- p (* *alpha* g)))
                     params gradient))))
    (let ((current theta))
      (loop repeat *revisions*
            do (setf current (revise current)))
      current)))
;;;; experimental differentiation
(defvar *diff-variation* 0.01
  "Offset that DIFF adds to and subtracts from one parameter at a time when estimating slopes.")
(defun nabla-xp (fn args)
  "Determine gradients by experiment: vary args and record changes.
FN is called with a list of parameter values. Each element of ARGS is varied in
turn (see DIFF), and the estimated slopes are returned as a list in the same
order as ARGS. An empty ARGS yields NIL."
  (let ((base (funcall fn args))
        ;; COERCE rather than (APPLY #'VECTOR ARGS): APPLY is bounded by
        ;; CALL-ARGUMENTS-LIMIT, which the standard allows to be as small as 50.
        ;; DIFF uses SVREF, so a SIMPLE-VECTOR is required.
        (vargs (coerce args 'simple-vector)))
    (loop for ix below (length vargs)
          collect (diff fn vargs ix base))))
(defun diff (fn vargs ix base)
  "Estimate the slope of FN with respect to parameter IX of the simple vector VARGS.
BASE must be FN's value at VARGS. The parameter is moved up and down by
*DIFF-VARIATION* on a private copy of VARGS. The forward quotient
(f(x+h) - base)/h and the backward quotient (base - f(x-h))/h are averaged,
which amounts to a central difference. VARGS itself is not modified."
  (let* ((step *diff-variation*)
         (origin (svref vargs ix))
         (probe (copy-seq vargs)))
    (flet ((value-at (x)
             ;; Evaluate FN with parameter IX set to X, passing a fresh list.
             (setf (svref probe ix) x)
             (funcall fn (coerce probe 'list))))
      (let* ((forward (/ (- (value-at (+ origin step)) base) step))
             (backward (/ (- base (value-at (- origin step))) step)))
        ;; (util:lgi base forward backward) ; debug trace hook
        (/ (+ forward backward) 2)))))
;;;; parameterized target functions
(defun line (x)
  "Return a function of THETA = (slope intercept) computing slope * X + intercept.
X may be a number or a nested list; RMUL and RADD work leaf by leaf."
  (lambda (theta)
    (let ((slope (first theta))
          (intercept (second theta)))
      (radd (rmul slope x) intercept))))
(defun quad (x)
  "Return a function of THETA = (a b c) computing a * X^2 + b * X + c.
X may be a number or a nested list; the R- operations work leaf by leaf."
  (lambda (theta)
    (let* ((quadratic-part (rmul (first theta) (rsqr x)))
           (linear-part (radd (rmul (second theta) x)
                              (third theta))))
      (radd quadratic-part linear-part))))
;;;; working area
(defvar *obj* nil
  "Scratch variable for the interactive working area (presumably holds an objective function while experimenting).")