provide a trial class (together with a list of trials) to collect loss values for varied parameter sets

This commit is contained in:
Helmut Merz 2025-05-25 17:49:06 +02:00
parent 52d1381aba
commit cff6311103
2 changed files with 27 additions and 9 deletions

View file

@ -2,12 +2,14 @@
(defpackage :decons (defpackage :decons
(:use :common-lisp) (:use :common-lisp)
(:local-nicknames (:util :scopes/util)) (:local-nicknames (:shape :scopes/shape)
(:util :scopes/util))
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv (:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
#:combine #:default-deviation #:l2-loss #:combine #:default-deviation #:l2-loss
#:*trials* #:trial #:try
#:line #:line
#:lgx #:lgx
#:obj #:*obj*
)) ))
(in-package :decons) (in-package :decons)
@ -69,6 +71,18 @@
(calculated (funcall objective theta))) (calculated (funcall objective theta)))
(funcall deviation (cadr dataset) calculated))))) (funcall deviation (cadr dataset) calculated)))))
(defvar *trials* nil)
(defclass trial ()
((theta :reader theta :initarg :theta)
(loss :reader loss :initarg :loss)))
(defmethod print-object ((tr trial) stream)
(shape:print-slots tr stream 'theta 'loss))
(defun try (obj theta)
(push (make-instance 'trial :theta theta :loss (funcall obj theta)) *trials*))
;;;; logging (printing) functions for debugging purposes ;;;; logging (printing) functions for debugging purposes
(defun lgx (op) (defun lgx (op)
@ -84,4 +98,4 @@
;;;; working area ;;;; working area
(defparameter obj nil) (defvar *obj* nil)

View file

@ -51,20 +51,24 @@
(== (decons:rreduce #'+ '(1 2 (3 4))) 10) (== (decons:rreduce #'+ '(1 2 (3 4))) 10)
) )
(defvar *ds1* '((2.0 1.0 4.0 3.0)
(1.8 1.2 4.2 3.3)))
(deftest test-line () (deftest test-line ()
(let ((p1 (decons:line 0.0)) (let ((p1 (decons:line 0.0))
(p2 (decons:line 1.0)) (p2 (decons:line 1.0))
(ds1 '((2.0 1.0 4.0 3.0)
(1.8 1.2 4.2 3.3)))
ps1 objective) ps1 objective)
(== (funcall p1 '(0.5 2.0)) 2.0) (== (funcall p1 '(0.5 2.0)) 2.0)
(== (funcall p2 '(0.5 2.0)) 2.5) (== (funcall p2 '(0.5 2.0)) 2.5)
(setf ps1 (decons:line (car ds1))) (setf ps1 (decons:line (car *ds1*)))
(== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5)) (== (funcall ps1 '(0.5 2.0)) '(3.0 2.5 4.0 3.5))
(== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0)) (== (funcall ps1 '(1.0 0.0)) '(2.0 1.0 4.0 3.0))
(== (decons:default-deviation (cadr ds1) (funcall ps1 '(1.0 0.0))) (== (decons:default-deviation (cadr *ds1*) (funcall ps1 '(1.0 0.0)))
0.20999993) ;0.899999861) 0.20999993) ;0.899999861)
(setf objective (funcall (decons:l2-loss #'decons:line) ds1)) (setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
(setf decons:obj objective) (setf decons:*obj* objective)
(== (funcall objective '(0.0 0.0)) 33.21) (== (funcall objective '(0.0 0.0)) 33.21)
(setf decons:*trials* nil)
(decons:try objective '(0.0 0.0))
(decons:try objective '(0.0099 0.0))
)) ))