start experimenting with quadratic functions

This commit is contained in:
Helmut Merz 2025-05-28 11:34:12 +02:00
parent 0768949897
commit 0617177d72
2 changed files with 42 additions and 11 deletions

View file

@ -4,20 +4,27 @@
(:use :common-lisp)
(:local-nicknames (:shape :scopes/shape)
(:util :scopes/util))
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv
(:export #:rapply #:rreduce #:radd #:rmul #:rsub #:rdiv #:rsqr
#:default-deviation #:l2-loss
*revisions* *alpha* #:gradient-descent
#:nabla-xp
#:line
#:line #:quad
#:*obj*))
(in-package :decons)
;;;; common (basic) stuff
(defun sqr (x) (* x x))
(defun sum (data) (reduce #'+ data))
;;;; rapply, rreduce - recursive application of operations
(defun rapply (op arg1 &optional arg2)
(if arg2
(rcall (rcurry op arg1) arg2)
;(rcall (rcurry op arg1) arg2)
(rcall2 op arg1 arg2)
(rcall op arg1)))
(defgeneric rcall (op arg)
@ -25,6 +32,15 @@
(:method (op (arg list))
(mapcar (lambda (i) (rcall op i)) arg)))
(defgeneric rcall2 (op a1 a2)
(:method (op a1 a2) (funcall op a1 a2))
(:method (op (a1 list) a2)
(mapcar (lambda (i) (rapply op i a2)) a1))
(:method (op a1 (a2 list))
(mapcar (lambda (j) (rapply op a1 j)) a2))
(:method (op (a1 list) (a2 list))
(mapcar (lambda (i j) (rapply op i j)) a1 a2)))
(defgeneric rcurry (op arg)
(:method (op arg) (lambda (j) (funcall op arg j)))
(:method (op (arg list))
@ -43,13 +59,10 @@
(defun rmul (a b) (rapply #'* a b))
(defun rsub (a b) (rapply #'- a b))
(defun rdiv (a b) (rapply #'/ a b))
(defun rsqr (a) (rapply #'sqr a))
;;;; loss calculation, collect trial data (parameters, resulting loss)
(defun sqr (x) (* x x))
(defun sum (data) (reduce #'+ data))
(defun default-deviation (observed calculated &key (norm #'sqr))
(sum (mapcar (lambda (a b) (funcall norm (- a b)))
observed calculated)))
@ -101,7 +114,14 @@
;;;; parameterized target functions
(defun line (x)
#'(lambda (theta) (radd (cadr theta) (rmul (car theta) x))))
(lambda (theta)
(radd (rmul (first theta) x) (second theta))))
(defun quad (x)
(lambda (theta)
(radd (rmul (first theta) (rsqr x))
(radd (rmul (second theta) x)
(third theta)))))
;;;; working area

View file

@ -16,6 +16,7 @@
(test-rapply)
(test-rreduce)
(test-line)
(test-quad)
(t:show-result)))
(deftest test-basic ()
@ -43,8 +44,9 @@
(== (decons:rapply #'1+ '(2 3)) '(3 4))
(== (decons:radd 2 3) 5)
(== (decons:radd 3 '(4 5)) '(7 8))
(== (decons:radd '(2 3) '(4 5)) '((6 7) (7 8))) ; not '(6 8)
(== (decons:rsub '(6 7) '(4 5)) '((2 3) (1 2)))
(== (decons:radd '(2 3) '(4 5)) '(6 8)) ; not '((6 7) (7 8))
(== (decons:rsub '(6 7) '(4 5)) '(2 2)) ; not '((2 3) (1 2))
(== (decons:rsqr '(2 3 4)) '(4 9 16))
)
(deftest test-rreduce ()
@ -66,8 +68,17 @@
(== (decons:default-deviation (cadr *ds1*) (funcall ps1 '(1.0 0.0)))
0.20999993) ;0.899999861)
(setf objective (funcall (decons:l2-loss #'decons:line) *ds1*))
(setf decons:*obj* objective)
(setf decons:*obj* objective) ; for interactive experiments
(== (funcall objective '(0.0 0.0)) 33.21)
(== (decons:nabla-xp objective '(0.0 0.0)) '(-62.999725 -21.0001))
(== (decons:gradient-descent objective '(0.0 0.0)) '(1.0499986 3.6358833e-6))
))
(defvar *ds2* '((-1.0 0.0 1.0 2.0 3.0)
(2.55 2.1 4.35 10.2 18.25)))
(deftest test-quad ()
(let (ps2 objective)
(setf ps2 (decons:quad (car *ds2*)))
(== (funcall ps2 '(1.0 0.0 0.0)) '(1.0 0.0 1.0 4.0 9.0))
))