; Copyright (C) 2006 Will M. Farr
;
; This program is free software; you can redistribute it and/or modify
; it under the terms of the GNU General Public License as published by
; the Free Software Foundation; either version 2 of the License, or
; (at your option) any later version.
;
; This program is distributed in the hope that it will be useful,
; but WITHOUT ANY WARRANTY; without even the implied warranty of
; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
; GNU General Public License for more details.
;
; You should have received a copy of the GNU General Public License along
; with this program; if not, write to the Free Software Foundation, Inc.,
; 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

#| This program was recently (24/10/2006) modified as per the discussion
in the three blog posts here:

* http://wmfarr.blogspot.com/2006/10/automatic-differentiation-in-ocaml.html
* http://wmfarr.blogspot.com/2006/10/correction-to-automatic.html
* http://wmfarr.blogspot.com/2006/10/better-jacobians-in-computing.html

It is also long overdue for a few comments explaining the method used.

In order to take derivatives, we define primitive mathematical
operations on "differentials".  Each differential is a three-component
structure: a tag (which, as we shall see, identifies the "variable"
with respect to which the derivative is taken), an "x" value, and a
"dx" value.  The "x" and "dx" values may themselves be differentials.

A particular differential represents an infinitesimal increment in its
variable: d = x + dx.  We operate on this value as if dx*dx = 0.  For a
function f, we have f(x+dx) = f(x) + f'(x)*dx, where the RHS is itself
a differential, with the indicated "x" and "dx" components.

When a differential has differential components, we maintain a heap
property for the resulting tree: for each differential, all contained
differentials have larger tags than it (tags are just integer values).
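
As a small worked example of the dx*dx = 0 rule (the numbers are
illustrative only): take f(x) = x*x at x = 3, starting from the
differential 3 + 1*dx.  Then

  (3 + 1*dx) * (3 + 1*dx) = 9 + 3*dx + 3*dx + 1*dx*dx
                          = 9 + 6*dx,

so the "x" component of the result is f(3) = 9 and the "dx" component
is f'(3) = 6.
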
With multiple differential tags, we are talking about multiple
variables being differentiated, so each differential represents a
particular set of coefficients representing an incremental value of

  a0 + a1*dx + a2*dx*dy + a3*dy + ... = a0 + a3*dy + (a1 + a2*dy)*dx + ...

assuming that (tag dx) < (tag dy).  For a set of n differential
variables, there are 2^n possible terms which contain at most one
multiplicative factor of each increment dx, dy, ...  Not
coincidentally, there are exactly 2^n slots in the tree of
differential structures:

          tag-x
         /     \
     tag-y     tag-y
     /   \     /   \
   a0    a3   a1    a2

and so on.

For binary operations, we compare tags, and compute first the
derivative wrt the argument that has the smaller tag, passing the
argument which has the larger tag recursively to the code for further
derivative computation.
|#

(module derivatives "generics.ss"
  (require "tuples.ss")

  (provide D partial define-derivative define-binary-derivative)

  (defclass <deriv> () id x dx :auto #t :print #t)

  (define (deriv-equals? d1 d2)
    (equals? (deriv-x d1) (deriv-x d2)))

  (add-equals?-method <deriv> deriv-equals?)

  (defgeneric (->deriv x id))

  (defmethod (->deriv x id)
    (make-deriv id x 1))

  (defmethod (->deriv (d <deriv>) id)
    (with-slots d ((idd 'id) x dx)
      (if (id-deriv x id) dx))))

  (defgeneric (extract-derivative id d))

  ; Anything that is not a differential has zero derivative.
  (defmethod (extract-derivative id d)
    (* 0 d))

  (defmethod (extract-derivative id (d <deriv>))
    (with-slots d ((idd 'id) x dx)
      (if (= idd id)
          dx
          (make <deriv>
            :id idd
            :x (extract-derivative id x)
            :dx (extract-derivative id dx)))))

  (defmethod (extract-derivative id (t <tuple>))
    (tuple-map (lambda (i d) (extract-derivative id d)) t))

  (defgeneric (strip-derivative id d))

  (defmethod (strip-derivative id d) d)

  (defmethod (strip-derivative id (d <deriv>))
    (with-slots d ((idd 'id) x dx)
      (if (= idd id)
          x
          (make <deriv>
            :id idd
            :x (strip-derivative id x)
            :dx (strip-derivative id dx)))))

  (defmethod (strip-derivative id (t <tuple>))
    (tuple-map (lambda (i x) (strip-derivative id x)) t))

  ; Each call returns a fresh, strictly larger integer tag.
  (define new-id
    (let ((counter 0))
      (lambda ()
        (set!
         counter (+ counter 1))
        counter)))

  (define (id-deriv x id))))))

  ; Return a copy of list with element i replaced by obj.
  (define (replace list i obj)
    (if (= i 0)
        (cons obj (cdr list))
        (cons (car list) (replace (cdr list) (- i 1) obj))))

  ; ((partial i) f) is the derivative of f with respect to argument i.
  (define ((partial i) f)
    (lambda args
      ((D (lambda (x) (apply f (replace args i x))))
       (list-ref args i))))

  (define-syntax define-derivative
    (syntax-rules ()
      ((_ f df)
       (let ((local-df df))
         (add-method f
           (method ((d <deriv>))
             (with-slots d (id x dx)
               (make-deriv id (f x) (* (local-df x) dx)))))))))

  (define-syntax define-binary-derivative
    (syntax-rules ()
      ((_ f df/dx df/dy)
       (let ((local-df/dx df/dx)
             (local-df/dy df/dy))
         (add-method f
           (method ((d1 <deriv>) (d2 <deriv>))
             (with-slots d1 ((id1 'id) (x 'x) (dx 'dx))
               (with-slots d2 ((id2 'id) (y 'x) (dy 'dx))
                 (cond ((id-) y)
             (with-slots d (id x dx)
               (make-deriv id (f x y) (* (local-df/dx x y) dx)))))
         (add-method f
           (method (x (d <deriv>))
             (with-slots d (id (y 'x) (dy 'dx))
               (make-deriv id (f x y) (* (local-df/dy x y) dy)))))))))

  ; Predicates and comparisons look only at the "x" component.
  (defmethod (zero? (d <deriv>)) (zero? (slot-ref d 'x)))

  (defmethod (greater-than (d <deriv>) x) (greater-than (deriv-x d) x))
  (defmethod (greater-than x (d <deriv>)) (greater-than x (deriv-x d)))

  (defmethod (greater-than-or-equal (d <deriv>) x)
    (greater-than-or-equal (deriv-x d) x))
  (defmethod (greater-than-or-equal x (d <deriv>))
    (greater-than-or-equal x (deriv-x d)))

  (defmethod (less-than (d <deriv>) x) (less-than (deriv-x d) x))
  (defmethod (less-than x (d <deriv>)) (less-than x (deriv-x d)))

  (defmethod (less-than-or-equal (d <deriv>) x)
    (less-than-or-equal (deriv-x d) x))
  (defmethod (less-than-or-equal x (d <deriv>))
    (less-than-or-equal x (deriv-x d)))

  (define-derivative plus (lambda (x) 1))
  (define-binary-derivative add (lambda (x y) 1) (lambda (x y) 1))
  (defmethod (add (d <deriv>) (t <tuple>))
    (tuple-map (lambda (i x) (+ d x)) t))
  (defmethod (add (t <tuple>) (d <deriv>))
    (tuple-map (lambda (i x) (+ x d)) t))

  (define-derivative minus (lambda (x) -1))
  (define-binary-derivative sub (lambda (x y) 1) (lambda (x y) -1))
  (defmethod (sub (t <tuple>) (d <deriv>))
    (tuple-map (lambda (i x) (- x d)) t))
  (defmethod (sub (d <deriv>) (t <tuple>))
    (tuple-map (lambda (i x) (- d x))
               t))

  (define-derivative times (lambda (x) 1))
  (define-binary-derivative mul (lambda (x y) y) (lambda (x y) x))
  (defmethod (mul (d <deriv>) (t <tuple>))
    (tuple-map (lambda (i x) (* d x)) t))
  (defmethod (mul (t <tuple>) (d <deriv>))
    (tuple-map (lambda (i x) (* x d)) t))

  (define-derivative invert (lambda (x) (minus (invert (mul x x)))))
  (define-binary-derivative div
    (lambda (x y) (/ y))
    (lambda (x y) (- (/ x (* y y)))))
  (defmethod (div (d <deriv>) (t <tuple>))
    (tuple-map (lambda (i x) (/ d x)) t))
  (defmethod (div (t <tuple>) (d <deriv>))
    (tuple-map (lambda (i x) (/ x d)) t))

  (define-derivative exp exp)
  (define-derivative log /)

  (define-derivative sin cos)
  (define-derivative cos (lambda (x) (- (sin x))))
  (define-derivative tan (lambda (x) (/ (* (cos x) (cos x)))))

  (define-derivative asin (lambda (x) (/ (sqrt (- 1 (* x x))))))
  (define-derivative acos (lambda (x) (- (/ (sqrt (- 1 (* x x)))))))
  (define-derivative atan1 (lambda (x) (/ (+ 1 (* x x)))))

  (define-derivative sqrt (lambda (x) (/ 1/2 (sqrt x))))

  (define-binary-derivative atan2
    (lambda (x y) (- (/ y (+ (* x x) (* y y)))))
    (lambda (x y) (/ x (+ (* x x) (* y y)))))

  (define-binary-derivative expt
    (lambda (x y) (* (expt x (- y 1)) y))
    (lambda (x y) (* (expt x y) (log x))))

  ; sinc'(x) = (cos x - sinc x) / x, with the limit sinc'(0) = 0.
  (defmethod (sinc (d <deriv>))
    (with-slots d (id x dx)
      (if (= x 0)
          (make-deriv id 1 (* 0 dx))
          (make-deriv id (sinc x) (* (/ (- (cos x) (sinc x)) x) dx))))))