;;;; File: decons/mlx.lisp
;;;; (source-viewer metadata: 79 lines, 2.3 KiB, Common Lisp)

;;;; decons/mlx - machine learning experiments
(defpackage :decons/mlx
  (:use :common-lisp)
  (:local-nicknames (:r :decons/recurse)
                    (:util :scopes/util))
  (:documentation "Machine learning experiments: loss calculation,
optimization by gradient descent, experimental (numerical)
differentiation and parameterized target functions.")
  ;; Every exported name is written as an uninterned symbol (#:).  A plain
  ;; symbol such as *REVISIONS* would be interned into whatever package is
  ;; current while this form is read, and would later clash with the
  ;; exported symbol when that package does (USE-PACKAGE :DECONS/MLX).
  (:export #:default-deviation #:l2-loss
           #:*revisions* #:*alpha* #:gradient-descent
           #:nabla-xp
           #:line #:quad
           #:*obj*))

(in-package :decons/mlx)
;;;; loss calculation
(defun default-deviation (observed calculated &key (norm (lambda (x) (* x x))))
  "Return the sum of NORM applied to each pairwise difference OBSERVED - CALCULATED.

With the default NORM (squaring) this is the sum of squared deviations.
Elements are paired in parallel, so only as many pairs as the shorter
list provides contribute.  Empty input yields 0."
  (let ((terms (loop for obs in observed
                     for calc in calculated
                     collect (funcall norm (- obs calc)))))
    (reduce #'+ terms)))
(defun l2-loss (target &key (deviation #'default-deviation))
  "Build a curried loss function from TARGET.

The result is an expectant function: give it a DATASET, a list whose
first element is the input handed to TARGET and whose second element
holds the observed values.  It returns an objective function of THETA.
The objective applies (funcall TARGET input) to THETA to get the
predicted values, then returns (funcall DEVIATION observed predicted)."
  (lambda (dataset)
    (lambda (theta)
      ;; DATASET is re-read on every call of the objective.
      (let ((predicted (funcall (funcall target (first dataset)) theta)))
        (funcall deviation (second dataset) predicted)))))
;;;; optimization by revision (= gradient descent)
(defparameter *revisions* 1000
  "Default number of update steps performed by GRADIENT-DESCENT.")
(defparameter *alpha* 0.01
  "Default learning rate (step size) used by GRADIENT-DESCENT.")

(defun gradient-descent (obj theta &key (revisions *revisions*)
                                        (alpha *alpha*)
                                        (gradient #'nabla-xp))
  "Minimize OBJ by plain gradient descent starting from the parameter list THETA.

OBJ is a function of one argument, a list of parameters.  In each of
REVISIONS rounds, every parameter P is replaced by P - ALPHA * G.  G is
the corresponding element of (funcall GRADIENT OBJ theta).  REVISIONS and
ALPHA default to the current values of *REVISIONS* and *ALPHA*.
GRADIENT defaults to NABLA-XP, the experimental (numerical) gradient.
Returns the final parameter list."
  (flet ((revise (params)
           (mapcar (lambda (p g) (- p (* alpha g)))
                   params
                   (funcall gradient obj params))))
    (loop repeat revisions
          do (setf theta (revise theta)))
    theta))
;;;; experimental differentiation
(defvar *diff-variation* 0.01
  "Default step size by which NABLA-XP / DIFF perturb each argument.")

(defun nabla-xp (fn args &key (variation *diff-variation*))
  "Determine gradients by experiment: vary args and record changing results.

FN is called with a list of numbers and must return a number.  Returns a
list with one partial-derivative estimate per element of ARGS, in order,
each computed by DIFF with step size VARIATION (default *DIFF-VARIATION*)."
  (let ((base (funcall fn args))
        ;; COERCE instead of (APPLY #'VECTOR ARGS): APPLY is bounded by
        ;; CALL-ARGUMENTS-LIMIT, which ANSI permits to be as small as 50.
        (vargs (coerce args 'simple-vector)))
    (loop for ix below (length vargs)
          collect (diff fn vargs ix base variation))))

(defun diff (fn vargs ix base &optional (vdiff *diff-variation*))
  "Estimate the partial derivative of FN with respect to element IX of VARGS.

VARGS is a simple-vector of arguments; it is copied, not modified.  BASE
is expected to be FN's value at the unmodified arguments.  The result is
the mean of the forward quotient (f(x+h) - BASE)/h and the backward
quotient (BASE - f(x-h))/h, with h = VDIFF."
  (let ((val (svref vargs ix))
        (argsx (copy-seq vargs)))
    (flet ((probe (x)
             ;; Call FN with element IX of the copy replaced by X.
             (setf (svref argsx ix) x)
             (funcall fn (coerce argsx 'list))))
      ;; LET evaluates left to right: forward probe first, then backward.
      (let ((r+ (/ (- (probe (+ val vdiff)) base) vdiff))
            (r- (/ (- base (probe (- val vdiff))) vdiff)))
        ;(util:lgi base r+ r-)
        (/ (+ r+ r-) 2)))))
;;;; parameterized target functions
(defun line (x)
  "Linear target function: return a function of THETA that computes
THETA[0] * X + THETA[1] using the R: (decons/recurse) arithmetic.
NOTE(review): presumably X may be a number or a list of numbers,
depending on what R:ADD / R:MUL accept -- confirm in decons/recurse."
  (lambda (theta)
    (let ((slope (first theta))
          (offset (second theta)))
      (r:add (r:mul slope x) offset))))
(defun quad (x)
  "Quadratic target function: return a function of THETA that computes
THETA[0] * X^2 + (THETA[1] * X + THETA[2]) using the R: (decons/recurse)
arithmetic.  NOTE(review): the accepted shapes of X are whatever
R:SQR / R:MUL / R:ADD support -- confirm in decons/recurse."
  (lambda (theta)
    (let ((a (first theta))
          (b (second theta))
          (c (third theta)))
      (r:add (r:mul a (r:sqr x))
             (r:add (r:mul b x) c)))))
;;;; working area
(defvar *obj* nil
  "Exported scratch variable of the working area; initially NIL.
NOTE(review): presumably meant to hold an objective function during
interactive experiments -- confirm with its users.")