now with function l2-loss

This commit is contained in:
Helmut Merz 2025-05-24 19:29:45 +02:00
parent 0172a2fd46
commit 25197a0520
2 changed files with 18 additions and 7 deletions

View file

@ -7,7 +7,7 @@
#:absv #:double #:remainder #:absv #:double #:remainder
#:scalar-p #:tensor #:at #:scalar-p #:tensor #:at
#:rapply #:rreduce #:rapply #:rreduce
#:combine #:cost #:combine #:cost #:l2-loss
#:line #:line
#:lgx #:lgx
)) ))
@ -97,12 +97,21 @@
(:method (a b) (list a b)) (:method (a b) (list a b))
(:method (a (b list)) (cons a b))) (:method (a (b list)) (cons a b)))
;;;; cost calculation ;;;; cost / loss calculation
(defun cost (have want) (defun cost (measured calculated)
(reduce #'+ (reduce #'+
(mapcar (lambda (p) (abs (apply #'- p))) (mapcar (lambda (p) (sqr (apply #'- p)))
(combine have want)))) (combine measured calculated))))
(defun l2-loss (target &key (norm #'sqr))
(lambda (dataset)
(lambda (theta)
(let ((calculated (funcall (funcall target (car dataset)) theta)))
(reduce #'+ (mapcar (lambda (p) (funcall norm (apply #'- p)))
(combine (cadr dataset) calculated)))))))
(defun sqr (x) (* x x))
;;;; logging (printing) functions for debugging purposes ;;;; logging (printing) functions for debugging purposes

View file

@ -55,11 +55,13 @@
(p2 (decons:line 1.0)) (p2 (decons:line 1.0))
(ds1 '((2.0 1.0 4.0 3.0) (ds1 '((2.0 1.0 4.0 3.0)
(1.8 1.2 4.2 3.3))) (1.8 1.2 4.2 3.3)))
ps1) ps1 objective)
(== (funcall p1 '(0.5 2.0)) 2.0) (== (funcall p1 '(0.5 2.0)) 2.0)
(== (funcall p2 '(0.5 2.0)) 2.5) (== (funcall p2 '(0.5 2.0)) 2.5)
(setf ps1 (decons:line (car ds1))) (setf ps1 (decons:line (car ds1)))
(== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5)) (== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5))
(== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0)) (== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0))
(== (decons:cost (funcall ps1 '(1.0 0.0)) (cadr ds1)) 0.899999861) (== (decons:cost (cadr ds1) (funcall ps1 '(1.0 0.0))) 0.20999993) ;0.899999861)
(setf objective (funcall (decons:l2-loss #'decons:line) ds1))
(== (funcall objective '(0.0 0.0)) 33.21)
)) ))