;;;; decons/decons.lisp
;;;; Common Lisp source (listed as 91 lines, 2.6 KiB)

;;;; decons
;;; DECONS: recursive arithmetic over nested lists (RAPPLY and friends), an
;;; L2-style loss built from a parameterized target function, and a small
;;; scratch area for recording parameter trials.
(defpackage :decons
  (:use :common-lisp)
  ;; Project packages referenced through short local nicknames.
  (:local-nicknames (:shape :scopes/shape)
                    (:util :scopes/util))
  (:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
           #:default-deviation #:l2-loss #:trial #:gradient
           #:line
           #:*obj* #:*trials* #:try))
(in-package :decons)
;;;; rapply, rreduce - recursive application of operations
(defun rapply (op arg1 &optional (arg2 nil arg2-p))
  "Apply OP through nested lists, keeping the nesting structure.

With only ARG1, call OP on every non-list leaf of ARG1.  With ARG2 as well,
call (funcall OP leaf-of-ARG1 leaf-of-ARG2): a non-list operand is paired
with every leaf of the other operand.  When both operands are lists, every
leaf of ARG1 is paired with every leaf of ARG2, and the result is nested by
ARG2 on the outside, e.g.
  (rapply #'+ '(1 2) '(10 20)) => ((11 12) (21 22))
NOTE(review): that is an outer product, not element-wise combination --
confirm this is intended.

ARG2 is recognised by being supplied, not by being non-NIL.  An explicit NIL
(the empty list) is therefore an empty operand and yields NIL, instead of
silently turning the call into the unary form."
  (if arg2-p
      (rcall (rcurry op arg1) arg2)
      (rcall op arg1)))

(defgeneric rcall (op arg)
  (:documentation
   "Call OP on every non-list leaf of ARG, preserving the list structure.
A non-list ARG is passed to OP directly.  A list (including NIL) is mapped
element by element, so NIL yields NIL.")
  (:method (op arg)
    (funcall op arg))
  (:method (op (arg list))
    (mapcar (lambda (element) (rcall op element)) arg)))

(defgeneric rcurry (op arg)
  (:documentation
   "Return a one-argument function of J that combines ARG with J using OP.
For a non-list ARG the function computes (funcall OP ARG J).  For a list ARG
it returns the list of (rapply OP element J) over the elements of ARG.")
  (:method (op arg)
    (lambda (j) (funcall op arg j)))
  (:method (op (arg list))
    (lambda (j)
      (mapcar (lambda (element) (rapply op element j)) arg))))
(defun rreduce (op arg &key (initial-value 0))
  "Fold OP over the possibly nested list ARG, starting from INITIAL-VALUE.
Any element that is itself a list is first folded recursively (same OP and
INITIAL-VALUE), and that result takes its place.  Returns INITIAL-VALUE when
ARG is empty.  The default INITIAL-VALUE is 0, so pass 1 when OP is #'*."
  (flet ((collapse (element)
           (relement op element :initial-value initial-value)))
    (reduce op arg :initial-value initial-value :key #'collapse)))

(defgeneric relement (op v &key initial-value)
  (:documentation
   "Turn one element V into a value RREDUCE can combine: lists are folded
with RREDUCE, anything else is returned unchanged.")
  (:method ((op t) (v list) &key (initial-value 0))
    (rreduce op v :initial-value initial-value))
  (:method (op v &key (initial-value 0))
    ;; Leaves are returned as-is; OP and INITIAL-VALUE only matter for lists.
    (declare (ignorable op initial-value))
    v))
;;; Binary arithmetic on numbers or nested lists of numbers.  Each one
;;; forwards to RAPPLY with the matching CL operator, so the broadcasting
;;; rules are exactly those of RAPPLY.
(defun radd (a b)
  "Combine A and B with + through nested lists (see RAPPLY)."
  (rapply #'+ a b))

(defun rmul (a b)
  "Combine A and B with * through nested lists (see RAPPLY)."
  (rapply #'* a b))

(defun rsub (a b)
  "Combine A and B with - (A minus B) through nested lists (see RAPPLY)."
  (rapply #'- a b))

(defun rdiv (a b)
  "Combine A and B with / (A divided by B) through nested lists (see RAPPLY)."
  (rapply #'/ a b))
;;;; loss calculation, collect trial data (parameters, resulting loss)
(defun sqr (x)
  "Return X multiplied by itself."
  (* x x))

(defun sum (data)
  "Add up the elements of DATA; an empty DATA gives 0."
  (reduce #'+ data))

(defun default-deviation (observed calculated &key (norm #'sqr))
  "Apply NORM to each pairwise difference (OBSERVED_i - CALCULATED_i) and add
up the results.  NORM defaults to SQR, which gives the sum of squared
differences.  Pairs run to the end of the shorter list; two empty lists
give 0."
  (let ((penalties (mapcar (lambda (obs calc)
                             (funcall norm (- obs calc)))
                           observed
                           calculated)))
    (sum penalties)))
(defun l2-loss (target &key (deviation #'default-deviation))
  "Build a loss in two stages (expectant function -> objective function).

TARGET is a function of the inputs XS that returns a function of the
parameters THETA (LINE is an example).  DEVIATION is called as
(funcall DEVIATION observed calculated).  It defaults to DEFAULT-DEVIATION,
the sum of squared differences.

Returns an expectant function.  Given DATASET, a list whose first element
is the inputs XS and whose second element is the observed outputs YS, it
returns an objective function.  The objective maps THETA to
(funcall DEVIATION YS (funcall (funcall TARGET XS) THETA))."
  (lambda (dataset)                     ; expectant function
    ;; (funcall TARGET XS) does not depend on THETA.  Build the predictor
    ;; once per dataset, not on every objective call (GRADIENT evaluates
    ;; the objective twice per step).
    (let ((predictor (funcall target (first dataset)))
          (observed (second dataset)))
      (lambda (theta)                   ; objective function
        (funcall deviation observed (funcall predictor theta))))))
(defclass trial ()
  ((theta :initarg :theta
          :reader theta
          :documentation "The parameter values that were evaluated.")
   (loss :initarg :loss
         :reader loss
         :documentation "The value the objective returned for THETA."))
  (:documentation
   "Record of one evaluation: the parameters tried and the resulting loss."))
(defmethod print-object ((object trial) out)
  "Print a TRIAL by handing its THETA and LOSS slot names to SHAPE:PRINT-SLOTS."
  (shape:print-slots object out 'theta 'loss))
(defun gradient (target dataset)
  "Return a function of THETA giving how much the loss of TARGET on DATASET
drops when THETA is replaced by (VARY THETA).

The loss is the objective built by (L2-LOSS TARGET) with its default
DEVIATION, so the result is LOSS(THETA) - LOSS((VARY THETA)).
NOTE(review): despite the name, this is a single raw loss difference along
one fixed direction.  It is not a per-parameter gradient and is not divided
by the step size -- confirm this is the intended quantity."
  (let ((objective (funcall (l2-loss target) dataset)))
    (lambda (theta)
      (let ((current (funcall objective theta))
            (shifted (funcall objective (vary theta))))
        (- current shifted)))))
(defun vary (theta &optional (delta 0.01))
  "Return a fresh copy of THETA with DELTA subtracted from every number.
DELTA defaults to 0.01.  THETA may be a single number, a flat list of
numbers, or a nested list; the nesting is preserved."
  (labels ((nudge (x)
             (if (listp x)
                 (mapcar #'nudge x)
                 (- x delta))))
    (nudge theta)))
;;;; parameterized target functions
(defun line (x)
  "Return a function of THETA for the straight line W*X + B, where W is the
first and B the second element of THETA.  The arithmetic goes through RMUL
and RADD, so X may be a number or a nested list of numbers."
  (lambda (theta)
    (let ((b (second theta))
          (w (first theta)))
      (radd b (rmul w x)))))
;;;; working area
;; Scratch slot for interactive experiments.
;; NOTE(review): not read or written by any code visible here; presumably set
;; by hand at the REPL (e.g. to an objective passed to TRY) -- confirm.
(defvar *obj* nil)
;; TRIAL instances recorded by TRY, most recent first (TRY pushes onto it).
(defvar *trials* nil)
(defun try (obj theta)
  "Evaluate the objective OBJ at THETA and push a TRIAL holding THETA and
the resulting loss onto the front of *TRIALS*.  Returns the updated
*TRIALS* list."
  (let ((record (make-instance 'trial
                               :theta theta
                               :loss (funcall obj theta))))
    (push record *trials*)))