decons/decons.lisp

106 lines
3.2 KiB
Common Lisp

;;;; decons
;;; Package for recursive (nested-list) arithmetic, loss construction and
;;; finite-difference gradient estimation.
(defpackage #:decons
  (:use #:common-lisp)
  (:local-nicknames (#:shape #:scopes/shape)
                    (#:util #:scopes/util))
  (:export #:rapply #:rreduce
           #:radd #:rmul #:rsub #:rdiv
           #:default-deviation #:l2-loss #:trial
           #:gradient #:nabla-xp #:diff
           #:line
           #:*obj* #:*trials* #:try))

(in-package #:decons)
;;;; rapply, rreduce - recursive application of operations
(defun rapply (op arg1 &optional (arg2 nil arg2-p))
  "Apply OP recursively to the leaves of nested-list arguments.
With only ARG1, the unary OP is applied to every leaf of ARG1 and the
nesting structure is preserved.  With ARG2, OP is called as
(OP leaf-of-ARG1 leaf-of-ARG2) for every pair of leaves.  A scalar
broadcasts over a list.  Two lists produce one sublist per element of ARG2,
i.e. an outer-product-like result, not an element-wise one.
ARG2-P separates an omitted ARG2 from an explicitly passed NIL.  NIL is the
empty list and therefore a legitimate (empty) second operand.  Earlier
versions treated it as a missing argument and silently applied OP as a
unary function."
  (if arg2-p
      (rcall (rcurry op arg1) arg2)
      (rcall op arg1)))
(defgeneric rcall (op arg)
  (:documentation "Call the unary function OP on ARG.  When ARG is a list,
descend into it recursively and return a list with the same nesting whose
leaves are the results of OP.")
  (:method ((op t) (arg t))
    (funcall op arg))
  (:method ((op t) (arg list))
    (loop for element in arg
          collect (rcall op element))))
(defgeneric rcurry (op arg)
  (:documentation "Return a one-argument closure that calls OP with ARG as
its first argument.  When ARG is a list, the closure maps over ARG and
combines each element with its own argument via RAPPLY.")
  (:method ((op t) (arg t))
    (lambda (other)
      (funcall op arg other)))
  (:method ((op t) (arg list))
    (lambda (other)
      (loop for element in arg
            collect (rapply op element other)))))
(defun rreduce (op arg &key (initial-value 0))
  "Reduce the nested list ARG with the binary function OP.
Each element is first collapsed to a scalar by RELEMENT, which reduces
sublists recursively.  The resulting scalars are then folded with OP,
starting from INITIAL-VALUE.  Because the recursion passes INITIAL-VALUE
down, it is folded in once per (sub)list.  Choose an identity of OP
(0 for +, 1 for *) unless that is intended."
  (flet ((collapse (element)
           (relement op element :initial-value initial-value)))
    (reduce op arg :initial-value initial-value :key #'collapse)))
(defgeneric relement (op v &key initial-value)
  (:documentation "Collapse V to a single value for RREDUCE.  A non-list V
is returned unchanged.  A list V is reduced with OP via RREDUCE, and
INITIAL-VALUE is passed along.")
  (:method (op v &key initial-value)
    ;; Leaves need no reduction; OP and INITIAL-VALUE are accepted only for
    ;; lambda-list congruence, so silence the unused-variable warnings.
    (declare (ignore op initial-value))
    v)
  (:method (op (v list) &key (initial-value 0))
    (rreduce op v :initial-value initial-value)))
(defun radd (a b)
  "Recursive addition of A and B (see RAPPLY for the broadcasting rules)."
  (rapply (function +) a b))

(defun rmul (a b)
  "Recursive multiplication of A and B (see RAPPLY)."
  (rapply (function *) a b))

(defun rsub (a b)
  "Recursive subtraction B from A (see RAPPLY)."
  (rapply (function -) a b))

(defun rdiv (a b)
  "Recursive division of A by B (see RAPPLY)."
  (rapply (function /) a b))
;;;; loss calculation, collect trial data (parameters, resulting loss)
(defun sqr (x)
  "Return the square of the number X."
  (* x x))
(defun sum (data)
  "Add up the numbers in the sequence DATA; an empty sequence gives 0."
  (reduce (function +) data))
(defun default-deviation (observed calculated &key (norm #'sqr))
  "Sum of NORM applied to each difference (observed - calculated).
OBSERVED and CALCULATED are walked in parallel with MAPCAR, so surplus
elements of the longer list are ignored.  With the default NORM (SQR)
this is the sum of squared errors."
  (let ((per-point (mapcar (lambda (obs calc)
                             (funcall norm (- obs calc)))
                           observed
                           calculated)))
    (sum per-point)))
(defun l2-loss (target &key (deviation #'default-deviation))
  "Build a loss constructor for the parameterized model TARGET.
The result takes a DATASET, a list whose first element is the input given
to TARGET and whose second element is the observed output.  It returns the
objective, a function of THETA.  The objective calls DEVIATION with the
observed output and the model output at THETA."
  (lambda (dataset)                     ; expectant function
    (lambda (theta)                     ; objective function
      (let* ((model (funcall target (first dataset)))
             (predicted (funcall model theta)))
        (funcall deviation (second dataset) predicted)))))
(defclass trial ()
  ((theta :initarg :theta
          :reader theta
          :documentation "Parameter values that were evaluated.")
   (loss :initarg :loss
         :reader loss
         :documentation "Value of the objective at THETA."))
  (:documentation "One recorded evaluation: parameters THETA and resulting LOSS."))
(defmethod print-object ((record trial) stream)
  "Print RECORD by showing its THETA and LOSS slots via SHAPE:PRINT-SLOTS."
  (shape:print-slots record stream 'theta 'loss))
(defun gradient (target dataset)
  "Return a function of THETA that estimates the gradient of the L2 loss of
TARGET over DATASET by finite differences (NABLA-XP).  The objective is
built once, when GRADIENT is called."
  (let ((objective (funcall (l2-loss target) dataset)))
    (lambda (theta)
      (nabla-xp objective theta))))
(defun nabla-xp (fn args)
  "Determine gradients by experiment: vary args and record changes.
FN is called with the list ARGS to obtain the base value.  DIFF then
perturbs each argument in turn.  The result is the list of per-argument
slope estimates, in the same order as ARGS."
  (let ((base (funcall fn args))
        ;; COERCE instead of (APPLY #'VECTOR ARGS): APPLY is bounded by
        ;; CALL-ARGUMENTS-LIMIT, which may be as small as 50.
        (vargs (coerce args 'simple-vector)))
    (loop for ix below (length vargs)
          collect (diff fn vargs ix base))))
(defun diff (fn vargs ix base &optional (step 0.01))
  "Estimate the partial derivative of FN with respect to argument IX.
VARGS is a vector of argument values and is not modified.  FN is called
with a list of arguments.  BASE is FN's value at VARGS.
The forward quotient (f(x+STEP) - BASE)/STEP and the backward quotient
(BASE - f(x-STEP))/STEP are averaged.  Mathematically this equals the
central difference.  STEP defaults to 0.01, the previously hard-coded value."
  (let ((shifted (copy-seq vargs))
        (x (aref vargs ix)))
    (flet ((value-at (v)
             ;; Only position IX of the private copy is changed.
             (setf (aref shifted ix) v)
             (funcall fn (coerce shifted 'list))))
      (let ((forward (/ (- (value-at (+ x step)) base) step))
            (backward (/ (- base (value-at (- x step))) step)))
        (/ (+ forward backward) 2)))))
;;;; parameterized target functions
(defun line (x)
  "Linear model parameterized by THETA = (slope intercept).
Returns a function of THETA computing intercept + slope * X with RADD and
RMUL, so X may be a number or a (nested) list."
  (lambda (theta)
    (let ((slope (first theta))
          (intercept (second theta)))
      (radd intercept (rmul slope x)))))
;;;; working area
(defvar *obj* nil
  "Working-area variable, initially NIL.  It is not read by the code in this
file; presumably it holds an objective to pass to TRY.")
(defvar *trials* nil
  "List of TRIAL records pushed by TRY, most recent first.")
(defun try (obj theta)
  "Evaluate the objective OBJ at THETA, push the result as a TRIAL onto
*TRIALS*, and return the updated list."
  (let ((record (make-instance 'trial
                               :theta theta
                               :loss (funcall obj theta))))
    (push record *trials*)))