;;;; decons/mlx - machine learning experiments

(defpackage :decons/mlx
  (:documentation
   "Machine learning experiments: loss functions, gradient descent and
numerically estimated gradients over simple parameterized target functions.")
  (:use :common-lisp)
  (:local-nicknames (:r :decons/recurse)
                    (:util :scopes/util))
  ;; All exports are uninterned (#:) so that reading this form does not
  ;; intern stray symbols into whatever package happens to be current.
  (:export #:default-deviation #:l2-loss
           #:*revisions* #:*alpha* #:gradient-descent
           #:nabla-xp
           #:line #:quad
           #:*obj*))

(in-package #:decons/mlx)

;;;; loss calculation

(defun default-deviation (observed calculated &key (norm (lambda (x) (* x x))))
  "Sum NORM over the element-wise differences OBSERVED[i] - CALCULATED[i].

OBSERVED and CALCULATED may be any sequences (lists or vectors) of numbers.
Elements beyond the end of the shorter one are ignored.  NORM defaults to
squaring, which makes the result the sum of squared errors.  Returns 0 when
either sequence is empty."
  ;; MAP 'LIST (rather than MAPCAR) accepts vectors as well as lists.
  (reduce #'+ (map 'list (lambda (a b) (funcall norm (- a b)))
                   observed calculated)))

(defun l2-loss (target &key (deviation #'default-deviation))
  "Build a loss function in two curried stages.

TARGET maps input data to a function of the parameter list THETA (as LINE and
QUAD do).  The returned function takes DATASET, a list (INPUTS OBSERVED), and
yields the objective function.  Given THETA, the objective returns
(funcall DEVIATION OBSERVED predicted), where predicted is the value of
TARGET's model for INPUTS evaluated at THETA."
  (lambda (dataset)                     ; expectant function
    (lambda (theta)                     ; objective function
      (let* ((model (funcall target (first dataset)))
             (predicted (funcall model theta))
             (observed (second dataset)))
        (funcall deviation observed predicted)))))

;;;; optimization by revision (= gradient descent)

(defparameter *revisions* 1000
  "Number of revision (update) steps used by GRADIENT-DESCENT.")

(defparameter *alpha* 0.01
  "Step size (learning rate) by which GRADIENT-DESCENT scales the gradient.")

(defun gradient-descent (obj theta &key (revisions *revisions*)
                                        (alpha *alpha*)
                                        (gradient #'nabla-xp))
  "Minimize the objective function OBJ by plain gradient descent from THETA.

OBJ takes a parameter list and returns a number.  THETA is the initial
parameter list.  Each of REVISIONS steps replaces every parameter p with
p - ALPHA * g, where g is the matching element of the gradient list returned
by (funcall GRADIENT OBJ theta).  Returns the final parameter list.  The
caller's THETA list is never modified.

REVISIONS and ALPHA default to the current values of *REVISIONS* and *ALPHA*.
GRADIENT defaults to NABLA-XP (numerical estimation); any function with the
same (fn args) -> gradient-list contract may be supplied instead."
  (flet ((revise (th)
           ;; One descent step: move each parameter against its gradient.
           (mapcar (lambda (p g) (- p (* alpha g)))
                   th (funcall gradient obj th))))
    ;; LOOP REPEAT avoids the unused index variable of the old DOTIMES.
    (loop repeat revisions
          do (setf theta (revise theta)))
    theta))

;;;; experimental differentiation

(defvar *diff-variation* 0.01
  "Offset h by which DIFF shifts one argument up and down to estimate a
partial derivative.")

(defun nabla-xp (fn args)
  "Determine gradients by experiment: vary args and record changing results.

FN is called with a list of numbers and must return a number.  ARGS is the
point (a list) at which the gradient is estimated.  Returns a list with one
entry per element of ARGS: DIFF's partial-derivative estimate for that
position, using FN's value at ARGS as the common baseline."
  (let ((base (funcall fn args))
        ;; COERCE instead of (APPLY #'VECTOR ARGS): APPLY is limited by
        ;; CALL-ARGUMENTS-LIMIT, which may be as low as 50.  DIFF uses SVREF,
        ;; so the vector must be a SIMPLE-VECTOR.
        (vargs (coerce args 'simple-vector)))
    (loop for ix below (length vargs)
          collect (diff fn vargs ix base))))

(defun diff (fn vargs ix base)
  "Estimate the partial derivative of FN with respect to argument number IX.

VARGS is a simple-vector holding the evaluation point.  It is copied, never
modified, and FN is called on fresh list copies of it.  BASE must be FN's
value at VARGS.  With h = *DIFF-VARIATION*, the result is the mean of the
forward quotient (f(x+h) - BASE)/h and the backward quotient
(BASE - f(x-h))/h."
  (let* ((h *diff-variation*)
         (x (svref vargs ix))
         (probe (copy-seq vargs)))
    (flet ((value-at (shifted)
             ;; Evaluate FN with coordinate IX replaced by SHIFTED.
             (setf (svref probe ix) shifted)
             (funcall fn (coerce probe 'list))))
      (let* ((forward (/ (- (value-at (+ x h)) base) h))
             (backward (/ (- base (value-at (- x h))) h)))
        ;(util:lgi base forward backward)
        (/ (+ forward backward) 2)))))

;;;; parameterized target functions

(defun line (x)
  "Return a function of THETA = (slope intercept) that combines X with those
parameters as slope*X + intercept, using R:MUL and R:ADD from decons/recurse.
X may be anything those operators accept (presumably a number or a list of
numbers - confirm in decons/recurse)."
  (lambda (theta)
    (let ((slope (first theta))
          (intercept (second theta)))
      (r:add (r:mul slope x) intercept))))

(defun quad (x)
  "Return a function of THETA = (a b c) computing a*(sqr X) + (b*X + c)
with R:MUL, R:SQR and R:ADD from decons/recurse.  X may be anything those
operators accept (presumably a number or a list of numbers - confirm in
decons/recurse)."
  (lambda (theta)
    (let* ((a (first theta))
           (b (second theta))
           (c (third theta))
           (square-term (r:mul a (r:sqr x)))
           (lower-terms (r:add (r:mul b x) c)))
      (r:add square-term lower-terms))))

;;;; working area

(defvar *obj* nil
  "Scratch variable for interactive experiments.  Presumably meant to hold an
objective function (cf. GRADIENT-DESCENT's OBJ); nothing in this file sets it.")