improvements and refinements - provide separate functions for deviation and normalization
This commit is contained in:
parent
25197a0520
commit
3c3539b34b
2 changed files with 29 additions and 22 deletions
40
decons.lisp
40
decons.lisp
|
@ -6,8 +6,8 @@
|
|||
(:export #:+pi+ #:area #:circle
|
||||
#:absv #:double #:remainder
|
||||
#:scalar-p #:tensor #:at
|
||||
#:rapply #:rreduce
|
||||
#:combine #:cost #:l2-loss
|
||||
#:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
|
||||
#:combine #:default-deviation #:l2-loss
|
||||
#:line
|
||||
#:lgx
|
||||
))
|
||||
|
@ -88,6 +88,11 @@
|
|||
(:method (op (v list) &key (initial-value 0))
|
||||
(rreduce op v :initial-value initial-value)))
|
||||
|
||||
(defun radd (a b) (rapply #'+ a b))
|
||||
(defun rmul (a b) (rapply #'* a b))
|
||||
(defun rsub (a b) (rapply #'- a b))
|
||||
(defun rdiv (a b) (rapply #'/ a b))
|
||||
|
||||
;;;; combine
|
||||
|
||||
(defun combine (l1 l2)
|
||||
|
@ -97,22 +102,23 @@
|
|||
(:method (a b) (list a b))
|
||||
(:method (a (b list)) (cons a b)))
|
||||
|
||||
;;;; cost / loss calculation
|
||||
|
||||
(defun cost (measured calculated)
|
||||
(reduce #'+
|
||||
(mapcar (lambda (p) (sqr (apply #'- p)))
|
||||
(combine measured calculated))))
|
||||
|
||||
(defun l2-loss (target &key (norm #'sqr))
|
||||
(lambda (dataset)
|
||||
(lambda (theta)
|
||||
(let ((calculated (funcall (funcall target (car dataset)) theta)))
|
||||
(reduce #'+ (mapcar (lambda (p) (funcall norm (apply #'- p)))
|
||||
(combine (cadr dataset) calculated)))))))
|
||||
;;;; loss calculation
|
||||
|
||||
(defun sqr (x) (* x x))
|
||||
|
||||
(defun sum (data) (reduce #'+ data))
|
||||
|
||||
(defun default-deviation (observed calculated &key (norm #'sqr))
|
||||
(sum (mapcar (lambda (p) (funcall norm (apply #'- p)))
|
||||
(combine observed calculated))))
|
||||
|
||||
(defun l2-loss (target &key (deviation #'default-deviation))
|
||||
(lambda (dataset) ; expectant function
|
||||
(lambda (theta) ; objective function
|
||||
(let* ((objective (funcall target (car dataset)))
|
||||
(calculated (funcall objective theta)))
|
||||
(funcall deviation (cadr dataset) calculated)))))
|
||||
|
||||
;;;; logging (printing) functions for debugging purposes
|
||||
|
||||
(defun lgx (op)
|
||||
|
@ -121,7 +127,7 @@
|
|||
(format t "~&(~a ~a) = ~a~%" op args r)
|
||||
r)))
|
||||
|
||||
;;;; parameterized functions
|
||||
;;;; parameterized target functions
|
||||
|
||||
(defun line (x)
|
||||
#'(lambda (theta) (rapply #'+ (cadr theta) (rapply #'* (car theta) x))))
|
||||
#'(lambda (theta) (radd (cadr theta) (rmul (car theta) x))))
|
||||
|
|
|
@ -40,10 +40,10 @@
|
|||
(deftest test-rapply ()
|
||||
(== (decons:rapply #'1+ 7) 8)
|
||||
(== (decons:rapply #'1+ '(2 3)) '(3 4))
|
||||
(== (decons:rapply #'+ 2 3) 5)
|
||||
(== (decons:rapply #'+ 3 '(4 5)) '(7 8))
|
||||
(== (decons:rapply #'+ '(2 3) '(4 5)) '((6 7) (7 8))) ; not '(6 8)
|
||||
(== (decons:rapply #'- '(6 7) '(4 5)) '((2 3) (1 2)))
|
||||
(== (decons:radd 2 3) 5)
|
||||
(== (decons:radd 3 '(4 5)) '(7 8))
|
||||
(== (decons:radd '(2 3) '(4 5)) '((6 7) (7 8))) ; not '(6 8)
|
||||
(== (decons:rsub '(6 7) '(4 5)) '((2 3) (1 2)))
|
||||
)
|
||||
|
||||
(deftest test-rreduce ()
|
||||
|
@ -61,7 +61,8 @@
|
|||
(setf ps1 (decons:line (car ds1)))
|
||||
(== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5))
|
||||
(== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0))
|
||||
(== (decons:cost (cadr ds1) (funcall ps1 '(1.0 0.0))) 0.20999993) ;0.899999861)
|
||||
(== (decons:default-deviation (cadr ds1) (funcall ps1 '(1.0 0.0)))
|
||||
0.20999993) ;0.899999861)
|
||||
(setf objective (funcall (decons:l2-loss #'decons:line) ds1))
|
||||
(== (funcall objective '(0.0 0.0)) 33.21)
|
||||
))
|
||||
|
|
Loading…
Add table
Reference in a new issue