如何从Julia调用Numerical Recipes svdcmp

时间:2018-02-26 18:14:27

标签: fortran julia gfortran

首先,我知道 Julia 有内置的 svd 函数,但它并不完全符合我的需要;而 Numerical Recipes 中的 svdcmp 恰好能满足需求。

所以,子程序是这样的:

  module nrtype
    ! Kind parameters, mathematical constants and sparse-matrix container
    ! types shared by the Numerical Recipes (Fortran 90) routines.
    implicit none

    ! Integer kinds with a decimal range of at least 9, 4 and 2 digits.
    integer, parameter :: i4b = selected_int_kind(9)
    integer, parameter :: i2b = selected_int_kind(4)
    integer, parameter :: i1b = selected_int_kind(2)

    ! Real and complex kinds of default-real and double-precision literals.
    integer, parameter :: sp  = kind(1.0)
    integer, parameter :: dp  = kind(1.0d0)
    integer, parameter :: spc = kind((1.0, 1.0))
    integer, parameter :: dpc = kind((1.0d0, 1.0d0))

    ! Kind of the default logical type.
    integer, parameter :: lgt = kind(.true.)

    ! Constants in single precision (rounded from the long literals below).
    real(sp), parameter :: pi    = 3.141592653589793238462643383279502884197_sp
    real(sp), parameter :: pio2  = 1.57079632679489661923132169163975144209858_sp
    real(sp), parameter :: twopi = 6.283185307179586476925286766559005768394_sp
    real(sp), parameter :: sqrt2 = 1.41421356237309504880168872420969807856967_sp
    real(sp), parameter :: euler = 0.5772156649015328606065120900824024310422_sp

    ! Double-precision counterparts (only pi, pi/2 and 2*pi are provided).
    real(dp), parameter :: pi_d    = 3.141592653589793238462643383279502884197_dp
    real(dp), parameter :: pio2_d  = 1.57079632679489661923132169163975144209858_dp
    real(dp), parameter :: twopi_d = 6.283185307179586476925286766559005768394_dp

    ! Sparse matrix held as parallel value / row-index / column-index arrays.
    ! NOTE(review): presumably n is the matrix order and len the number of
    ! stored entries — confirm against the routines that fill these types.
    type sprs2_sp
      integer(i4b) :: n, len
      real(sp), pointer :: val(:)
      integer(i4b), pointer :: irow(:), jcol(:)
    end type sprs2_sp

    ! Same layout as sprs2_sp with double-precision values.
    type sprs2_dp
      integer(i4b) :: n, len
      real(dp), pointer :: val(:)
      integer(i4b), pointer :: irow(:), jcol(:)
    end type sprs2_dp
end module nrtype

MODULE nrutil
    ! Subset of the Numerical Recipes utility module needed by svdcmp_dp:
    ! size-agreement assertions (assert_eq), a fatal-error helper (nrerror)
    ! and the outer product of two vectors (outerprod).
    !
    ! Failure paths use STOP, which terminates the whole process — including a
    ! host program (e.g. Julia) that loaded this code from a shared library.
    USE nrtype
    IMPLICIT NONE
    ! Parameters carried over from the full nrutil module; none of the
    ! procedures in this subset use them.
    INTEGER(I4B), PARAMETER :: NPAR_ARTH=16,NPAR2_ARTH=8
    INTEGER(I4B), PARAMETER :: NPAR_GEOP=4,NPAR2_GEOP=2
    INTEGER(I4B), PARAMETER :: NPAR_CUMSUM=16
    INTEGER(I4B), PARAMETER :: NPAR_CUMPROD=8
    INTEGER(I4B), PARAMETER :: NPAR_POLY=8
    INTEGER(I4B), PARAMETER :: NPAR_POLYTERM=8

    ! assert_eq(n1,n2[,n3[,n4]],tag) or assert_eq(list,tag): returns the
    ! common value when all integers agree; otherwise prints tag and stops.
    INTERFACE assert_eq
        MODULE PROCEDURE assert_eq2,assert_eq3,assert_eq4,assert_eqn
    END INTERFACE

    ! outerprod(a,b) returns the matrix whose (i,j) element is a(i)*b(j).
    INTERFACE outerprod
        MODULE PROCEDURE outerprod_r,outerprod_d
    END INTERFACE

CONTAINS
    ! Two-value assertion: returns n1 when n1 == n2.
    FUNCTION assert_eq2(n1,n2,string)
    CHARACTER(LEN=*), INTENT(IN) :: string
    INTEGER, INTENT(IN) :: n1,n2
    INTEGER :: assert_eq2
    if (n1 == n2) then
        assert_eq2=n1
    else
        write (*,*) 'nrerror: an assert_eq failed with this tag:', &
            string
        STOP 'program terminated by assert_eq2'
    end if
    END FUNCTION assert_eq2
!BL
    ! Three-value assertion: returns n1 when n1 == n2 == n3.
    FUNCTION assert_eq3(n1,n2,n3,string)
    CHARACTER(LEN=*), INTENT(IN) :: string
    INTEGER, INTENT(IN) :: n1,n2,n3
    INTEGER :: assert_eq3
    if (n1 == n2 .and. n2 == n3) then
        assert_eq3=n1
    else
        write (*,*) 'nrerror: an assert_eq failed with this tag:', &
            string
        STOP 'program terminated by assert_eq3'
    end if
    END FUNCTION assert_eq3
!BL
    ! Four-value assertion: returns n1 when n1 == n2 == n3 == n4.
    FUNCTION assert_eq4(n1,n2,n3,n4,string)
    CHARACTER(LEN=*), INTENT(IN) :: string
    INTEGER, INTENT(IN) :: n1,n2,n3,n4
    INTEGER :: assert_eq4
    if (n1 == n2 .and. n2 == n3 .and. n3 == n4) then
        assert_eq4=n1
    else
        write (*,*) 'nrerror: an assert_eq failed with this tag:', &
            string
        STOP 'program terminated by assert_eq4'
    end if
    END FUNCTION assert_eq4
!BL
    ! List assertion: returns nn(1) when every element of nn equals it.
    FUNCTION assert_eqn(nn,string)
    CHARACTER(LEN=*), INTENT(IN) :: string
    INTEGER, DIMENSION(:), INTENT(IN) :: nn
    INTEGER :: assert_eqn
    ! Guard the empty list first: .and. does not short-circuit in Fortran,
    ! and nn(1) would be out of bounds. An empty list has no common value to
    ! return, so it is reported as a failed assertion.
    if (size(nn) > 0) then
        if (all(nn(2:) == nn(1))) then
            assert_eqn=nn(1)
            return
        end if
    end if
    write (*,*) 'nrerror: an assert_eq failed with this tag:', &
        string
    STOP 'program terminated by assert_eqn'
    END FUNCTION assert_eqn
!BL
    ! Prints a message and terminates the program.
    SUBROUTINE nrerror(string)
    CHARACTER(LEN=*), INTENT(IN) :: string
    write (*,*) 'nrerror: ',string
    STOP 'program terminated by nrerror'
    END SUBROUTINE nrerror
!BL
    ! Single-precision outer product, shape (size(a), size(b)).
    PURE FUNCTION outerprod_r(a,b)
    REAL(SP), DIMENSION(:), INTENT(IN) :: a,b
    REAL(SP), DIMENSION(size(a),size(b)) :: outerprod_r
    outerprod_r = spread(a,dim=2,ncopies=size(b)) * &
        spread(b,dim=1,ncopies=size(a))
    END FUNCTION outerprod_r
!BL
    ! Double-precision outer product, shape (size(a), size(b)).
    PURE FUNCTION outerprod_d(a,b)
    REAL(DP), DIMENSION(:), INTENT(IN) :: a,b
    REAL(DP), DIMENSION(size(a),size(b)) :: outerprod_d
    outerprod_d = spread(a,dim=2,ncopies=size(b)) * &
        spread(b,dim=1,ncopies=size(a))
    END FUNCTION outerprod_d
!BL
END MODULE nrutil

MODULE nr
    ! Generic pythag(a,b) = sqrt(a**2 + b**2) for single (SP) and double (DP)
    ! precision, as used by svdcmp_dp.
    !
    ! The specific functions are implemented here as module procedures rather
    ! than only declared through interface bodies. This means a shared library
    ! built from this file has no unresolved references to external
    ! pythag_dp/pythag_sp.
    USE nrtype, ONLY : SP, DP
    IMPLICIT NONE
    PRIVATE
    PUBLIC :: pythag, pythag_dp, pythag_sp

    INTERFACE pythag
        MODULE PROCEDURE pythag_dp, pythag_sp
    END INTERFACE

CONTAINS

    ! Double precision. The larger magnitude is factored out, so the squared
    ! ratio is <= 1 and the intermediate squares cannot overflow. Two zero
    ! arguments give zero. Elemental: also accepts conformable arrays.
    ELEMENTAL FUNCTION pythag_dp(a,b)
        REAL(DP), INTENT(IN) :: a,b
        REAL(DP) :: pythag_dp
        REAL(DP) :: absa,absb
        absa=abs(a)
        absb=abs(b)
        if (absa > absb) then
            pythag_dp=absa*sqrt(1.0_dp+(absb/absa)**2)
        else if (absb == 0.0_dp) then
            pythag_dp=0.0_dp
        else
            pythag_dp=absb*sqrt(1.0_dp+(absa/absb)**2)
        end if
    END FUNCTION pythag_dp
!BL
    ! Single-precision counterpart of pythag_dp.
    ELEMENTAL FUNCTION pythag_sp(a,b)
        REAL(SP), INTENT(IN) :: a,b
        REAL(SP) :: pythag_sp
        REAL(SP) :: absa,absb
        absa=abs(a)
        absb=abs(b)
        if (absa > absb) then
            pythag_sp=absa*sqrt(1.0_sp+(absb/absa)**2)
        else if (absb == 0.0_sp) then
            pythag_sp=0.0_sp
        else
            pythag_sp=absb*sqrt(1.0_sp+(absa/absb)**2)
        end if
    END FUNCTION pythag_sp
END MODULE nr


SUBROUTINE svdcmp_dp(a,w,v)
    ! Singular value decomposition (Numerical Recipes svdcmp, double
    ! precision). On exit A = U * diag(w) * transpose(V), where:
    !   a : on entry the m x n matrix A; on exit overwritten by U (m x n).
    !   w : the n singular values. They are NOT sorted. Negative values
    !       produced by the iteration are made positive, and the matching
    !       column of v is negated.
    !   v : the n x n matrix V itself (not its transpose).
    ! Stops via assert_eq if size(a,2), size(v,1), size(v,2) and size(w)
    ! disagree. Stops via nrerror if a singular value fails to converge
    ! within 30 iterations.
    !
    ! The dummies are ASSUMED-SHAPE. Callers must therefore have an explicit
    ! interface (a module or interface block) so the compiler passes array
    ! descriptors. Foreign callers such as C or Julia's ccall pass only a raw
    ! pointer, and SIZE() then returns meaningless values. Such callers
    ! should use svdcmp_dp_c (defined below) instead.
    USE nrtype; USE nrutil, ONLY : assert_eq,nrerror,outerprod
    USE nr, ONLY : pythag
    IMPLICIT NONE
    REAL(DP), DIMENSION(:,:), INTENT(INOUT) :: a
    REAL(DP), DIMENSION(:), INTENT(OUT) :: w
    REAL(DP), DIMENSION(:,:), INTENT(OUT) :: v
    INTEGER(I4B) :: i,its,j,k,l,m,n,nm
    REAL(DP) :: anorm,c,f,g,h,s,scale,x,y,z
    ! Work vectors sized from the (descriptor-supplied) shape of a.
    REAL(DP), DIMENSION(size(a,1)) :: tempm
    REAL(DP), DIMENSION(size(a,2)) :: rv1,tempn
    m=size(a,1)
    ! Diagnostic output of the shapes actually received. This is useful for
    ! spotting a missing array descriptor, but noisy in library use.
    write(*,*)"size(a,1)= ",size(a,1)
    write(*,*)"size(a,2)= ",size(a,2)
    write(*,*)"size(v,1)= ",size(v,1)
    write(*,*)"size(v,2)= ",size(v,2)
    write(*,*)"size(w)  = ",size(w)
    n=assert_eq(size(a,2),size(v,1),size(v,2),size(w),'svdcmp_dp')
    ! Phase 1: Householder reduction to bidiagonal form. The diagonal is
    ! stored in w and the superdiagonal in rv1 (rv1(1) stays zero).
    g=0.0
    scale=0.0
    do i=1,n
        l=i+1
        rv1(i)=scale*g
        g=0.0
        scale=0.0
        ! Left (column) reflection, rows i:m of column i.
        if (i <= m) then
            scale=sum(abs(a(i:m,i)))
            if (scale /= 0.0) then
                a(i:m,i)=a(i:m,i)/scale
                s=dot_product(a(i:m,i),a(i:m,i))
                f=a(i,i)
                g=-sign(sqrt(s),f)
                h=f*g-s
                a(i,i)=f-g
                tempn(l:n)=matmul(a(i:m,i),a(i:m,l:n))/h
                a(i:m,l:n)=a(i:m,l:n)+outerprod(a(i:m,i),tempn(l:n))
                a(i:m,i)=scale*a(i:m,i)
            end if
        end if
        w(i)=scale*g
        g=0.0
        scale=0.0
        ! Right (row) reflection, columns l:n of row i.
        if ((i <= m) .and. (i /= n)) then
            scale=sum(abs(a(i,l:n)))
            if (scale /= 0.0) then
                a(i,l:n)=a(i,l:n)/scale
                s=dot_product(a(i,l:n),a(i,l:n))
                f=a(i,l)
                g=-sign(sqrt(s),f)
                h=f*g-s
                a(i,l)=f-g
                rv1(l:n)=a(i,l:n)/h
                tempm(l:m)=matmul(a(l:m,l:n),a(i,l:n))
                a(l:m,l:n)=a(l:m,l:n)+outerprod(tempm(l:m),rv1(l:n))
                a(i,l:n)=scale*a(i,l:n)
            end if
        end if
    end do
    ! Scale used by the negligibility tests in phase 4.
    anorm=maxval(abs(w)+abs(rv1))
    ! Phase 2: accumulate the right-hand transformations into v.
    do i=n,1,-1
        if (i < n) then
            if (g /= 0.0) then
                v(l:n,i)=(a(i,l:n)/a(i,l))/g
                tempn(l:n)=matmul(a(i,l:n),v(l:n,l:n))
                v(l:n,l:n)=v(l:n,l:n)+outerprod(v(l:n,i),tempn(l:n))
            end if
            v(i,l:n)=0.0
            v(l:n,i)=0.0
        end if
        v(i,i)=1.0
        g=rv1(i)
        l=i
    end do
    ! Phase 3: accumulate the left-hand transformations, overwriting a with U.
    do i=min(m,n),1,-1
        l=i+1
        g=w(i)
        a(i,l:n)=0.0
        if (g /= 0.0) then
            g=1.0_dp/g
            tempn(l:n)=(matmul(a(l:m,i),a(l:m,l:n))/a(i,i))*g
            a(i:m,l:n)=a(i:m,l:n)+outerprod(a(i:m,i),tempn(l:n))
            a(i:m,i)=a(i:m,i)*g
        else
            a(i:m,i)=0.0
        end if
        a(i,i)=a(i,i)+1.0_dp
    end do
    ! Phase 4: diagonalize the bidiagonal form, one singular value k at a
    ! time from the bottom. Each singular value gets at most 30 shifted QR
    ! sweeps built from plane rotations.
    ! NOTE(review): negligibility is tested as abs(x)+anorm == anorm. That
    ! relies on exact double rounding; extended-precision intermediates or
    ! aggressive FP optimization may change it — confirm build flags.
    do k=n,1,-1
        do its=1,30
            ! Look for a split point l. rv1(1) is zero (set in phase 1 and
            ! reset by rv1(l)=0 after each sweep), so this loop exits at
            ! l=1 at the latest and w(nm)=w(0) is never read.
            do l=k,1,-1
                nm=l-1
                if ((abs(rv1(l))+anorm) == anorm) exit
                if ((abs(w(nm))+anorm) == anorm) then
                    ! w(nm) is negligible: cancel rv1(l:k) by rotations.
                    c=0.0
                    s=1.0
                    do i=l,k
                        f=s*rv1(i)
                        rv1(i)=c*rv1(i)
                        if ((abs(f)+anorm) == anorm) exit
                        g=w(i)
                        h=pythag(f,g)
                        w(i)=h
                        h=1.0_dp/h
                        c= (g*h)
                        s=-(f*h)
                        tempm(1:m)=a(1:m,nm)
                        a(1:m,nm)=a(1:m,nm)*c+a(1:m,i)*s
                        a(1:m,i)=-tempm(1:m)*s+a(1:m,i)*c
                    end do
                    exit
                end if
            end do
            z=w(k)
            ! Converged: make the singular value non-negative.
            if (l == k) then
                if (z < 0.0) then
                    w(k)=-z
                    v(1:n,k)=-v(1:n,k)
                end if
                exit
            end if
            if (its == 30) call nrerror('svdcmp_dp: no convergence in svdcmp')
            ! Shift computed from the trailing 2 x 2 block.
            x=w(l)
            nm=k-1
            y=w(nm)
            g=rv1(nm)
            h=rv1(k)
            f=((y-z)*(y+z)+(g-h)*(g+h))/(2.0_dp*h*y)
            g=pythag(f,1.0_dp)
            f=((x-z)*(x+z)+h*((y/(f+sign(g,f)))-h))/x
            ! One sweep of rotations applied to v (right) and a (left).
            c=1.0
            s=1.0
            do j=l,nm
                i=j+1
                g=rv1(i)
                y=w(i)
                h=s*g
                g=c*g
                z=pythag(f,h)
                rv1(j)=z
                c=f/z
                s=h/z
                f= (x*c)+(g*s)
                g=-(x*s)+(g*c)
                h=y*s
                y=y*c
                tempn(1:n)=v(1:n,j)
                v(1:n,j)=v(1:n,j)*c+v(1:n,i)*s
                v(1:n,i)=-tempn(1:n)*s+v(1:n,i)*c
                z=pythag(f,h)
                w(j)=z
                if (z /= 0.0) then
                    z=1.0_dp/z
                    c=f*z
                    s=h*z
                end if
                f= (c*g)+(s*y)
                x=-(s*g)+(c*y)
                tempm(1:m)=a(1:m,j)
                a(1:m,j)=a(1:m,j)*c+a(1:m,i)*s
                a(1:m,i)=-tempm(1:m)*s+a(1:m,i)*c
            end do
            rv1(l)=0.0
            rv1(k)=f
            w(k)=x
        end do
    end do
END SUBROUTINE svdcmp_dp

! C-interoperable entry point for svdcmp_dp, e.g. for Julia's ccall.
! Foreign callers cannot build Fortran array descriptors, so the shape is
! passed explicitly:
!   m, n    : by value as C ints.
!   a, w, v : pointers to contiguous column-major storage of m*n, n and
!             n*n doubles respectively.
! Same contract as svdcmp_dp otherwise.
! Example from Julia (>= 0.7):
!   ccall((:svdcmp_dp_c, "./svdcmp_dp.so"), Cvoid,
!         (Cint, Cint, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
!         size(a,1), size(a,2), a, w, v)
SUBROUTINE svdcmp_dp_c(m,n,a,w,v) BIND(C, NAME='svdcmp_dp_c')
    USE, INTRINSIC :: iso_c_binding, ONLY : c_int, c_double
    IMPLICIT NONE
    INTEGER(c_int), VALUE, INTENT(IN) :: m,n
    REAL(c_double), DIMENSION(m,n), INTENT(INOUT) :: a
    REAL(c_double), DIMENSION(n), INTENT(OUT) :: w
    REAL(c_double), DIMENSION(n,n), INTENT(OUT) :: v
    ! Explicit interface so the call below passes proper descriptors to the
    ! assumed-shape dummies of svdcmp_dp. This requires DP == c_double (true
    ! for gfortran); a mismatch is rejected at compile time.
    INTERFACE
        SUBROUTINE svdcmp_dp(a,w,v)
            USE nrtype, ONLY : DP
            IMPLICIT NONE
            REAL(DP), DIMENSION(:,:), INTENT(INOUT) :: a
            REAL(DP), DIMENSION(:), INTENT(OUT) :: w
            REAL(DP), DIMENSION(:,:), INTENT(OUT) :: v
        END SUBROUTINE svdcmp_dp
    END INTERFACE
    CALL svdcmp_dp(a,w,v)
END SUBROUTINE svdcmp_dp_c

请注意,我只保留了这些模块中本例需要的部分。然后,我将其编译为共享库:

gfortran -shared -fPIC svdcmp_dp.f90 -o svdcmp_dp.so
到目前为止一切正常。

接下来,我在 Julia 中这样调用:

julia> M=5
julia> a=rand(M,M) #just to see if it works
julia> v=zeros(M,M)
julia> w=zeros(M)
julia> t=ccall((:svdcmp_dp_, "./svdcmp_dp.so")
       , Void
       , ( Ref{Float64} # array a(mp,np)
        ,  Ref{Float64} # array w
        ,  Ref{Float64} # array v
       ) 
       ,a,w,v)

我得到了:

julia> t=ccall((:svdcmp_dp_, "./svdcmp_dp.so")
       , Void
       , ( Ref{Float64} # array a(mp,np)
        ,  Ref{Float64} # array w
        ,  Ref{Float64} # array v
       ) 
       ,a,w,v)
 size(a,1)=            0
 size(a,2)=            0
 size(v,1)=            1
 size(v,2)=            1
 size(w)  =            1
 nrerror: an assert_eq failed with this tag:svdcmp_dp
STOP program terminated by assert_eq4

所以,调用本身确实执行了,但显然 Fortran 90 的 SIZE 内置函数并没有返回我期望的结果。我之所以这样说,是因为 svdcmp_dp.f90 中最先执行的 assert_eq4 调用判定尺寸不一致。这本不应该发生,因为我传入的是 a [5×5]、w [5]、v [5×5],对吗?

我查阅了 Fortran 中 SIZE 函数的文档,找到了如下说明:

说明

Determine the extent of ARRAY along a specified dimension DIM, or the total number of elements in ARRAY if DIM is absent.

    Standard:
    Fortran 95 and later, with KIND argument Fortran 2003 and later

    Class:
    Inquiry function

    Syntax:
    RESULT = SIZE(ARRAY[, DIM [, KIND]])

    Arguments:
    ARRAY   Shall be an array of any type. If ARRAY is a pointer 
    it must be associated and allocatable arrays must be allocated.
    DIM (Optional) shall be a scalar of type INTEGER and its value shall 
    be in the range from 1 to n, where n equals the rank of ARRAY.
    KIND    (Optional) An INTEGER initialization expression indicating the 
    kind parameter of the result.

所以,我猜这个问题与 a、v 和 w 的 allocatable 属性有关,或者是 pointer 的问题(我完全没有使用指针的经验!)。

1 个答案:

答案 0(得分:1)

我实际上通过替换声明解决了这个问题:

SUBROUTINE svdcmp_dp(a,w,v)
    USE nrtype; USE nrutil, ONLY : assert_eq,nrerror,outerprod
    USE nr, ONLY : pythag
    IMPLICIT NONE
    REAL(DP), DIMENSION(:,:), INTENT(INOUT) :: a
    REAL(DP), DIMENSION(:), INTENT(OUT) :: w
    REAL(DP), DIMENSION(:,:), INTENT(OUT) :: v
    INTEGER(I4B) :: i,its,j,k,l,m,n,nm
    REAL(DP) :: anorm,c,f,g,h,s,scale,x,y,z
    REAL(DP), DIMENSION(size(a,1)) :: tempm
    REAL(DP), DIMENSION(size(a,2)) :: rv1,tempn
    m=size(a,1)

替换为:

SUBROUTINE svdcmp_dp(Ma,Na,a,w,v)
            USE nrtype; USE nrutil, ONLY : assert_eq,nrerror,outerprod
            USE nr, ONLY : pythag
            IMPLICIT NONE
            INTEGER(I4B) :: i,its,j,k,l,Ma,Na,m,n,nm

            REAL(DP), DIMENSION(Ma,Na), INTENT(INOUT) :: a
            REAL(DP), DIMENSION(Na), INTENT(INOUT) :: w
            REAL(DP), DIMENSION(Na,Na), INTENT(INOUT) :: v

            REAL(DP) :: anorm,c,f,g,h,s,scale,x,y,z
            REAL(DP), DIMENSION(size(a,1)) :: tempm
            REAL(DP), DIMENSION(size(a,2)) :: rv1,tempn

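请注意,新版本把输入数组的维数(Ma、Na)也作为参数传入,并把 a、w、v 改成了显式形状(explicit-shape)数组。原因在于:原来的假定形状数组(DIMENSION(:,:))要求调用方通过显式接口传递 Fortran 数组描述符,而 Julia 的 ccall 只传递指向首元素的裸指针,所以 SIZE 读到的是无意义的值。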

附注:此外,代码还需要下面这个模块(问题中给出的 nr 模块不完整,只有接口声明,没有函数实现):

MODULE nr
    ! Generic pythag(a,b) = sqrt(a**2 + b**2), computed without destructive
    ! overflow or underflow, for default-real (SP) and double (DP) arguments.
    IMPLICIT NONE
    INTERFACE pythag
          MODULE PROCEDURE pythag_dp, pythag_sp
    END INTERFACE
CONTAINS

    ! Double-precision specific. The larger magnitude is factored out so the
    ! squared ratio never exceeds one; two zero arguments give zero.
    FUNCTION pythag_dp(a,b) RESULT(hyp)
        USE nrtype
        IMPLICIT NONE
        REAL(DP), INTENT(IN) :: a,b
        REAL(DP) :: hyp
        REAL(DP) :: mag_a, mag_b
        mag_a = abs(a)
        mag_b = abs(b)
        if (mag_a > mag_b) then
            hyp = mag_a*sqrt(1.0_dp+(mag_b/mag_a)**2)
        else if (mag_b == 0.0) then
            hyp = 0.0
        else
            hyp = mag_b*sqrt(1.0_dp+(mag_a/mag_b)**2)
        end if
    END FUNCTION pythag_dp
!BL
    ! Single-precision specific; same algorithm as pythag_dp.
    FUNCTION pythag_sp(a,b) RESULT(hyp)
        USE nrtype
        IMPLICIT NONE
        REAL(SP), INTENT(IN) :: a,b
        REAL(SP) :: hyp
        REAL(SP) :: mag_a, mag_b
        mag_a = abs(a)
        mag_b = abs(b)
        if (mag_a > mag_b) then
            hyp = mag_a*sqrt(1.0_sp+(mag_b/mag_a)**2)
        else if (mag_b == 0.0) then
            hyp = 0.0
        else
            hyp = mag_b*sqrt(1.0_sp+(mag_a/mag_b)**2)
        end if
    END FUNCTION pythag_sp

END MODULE nr

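运行它(先将其编译为共享库)。注意:Fortran 端的 Ma、Na 声明为 INTEGER(I4B),在 gfortran 上是 4 字节整数,因此 Julia 端对应的类型应为 Ref{Int32}(或 Ref{Cint})。下面的示例使用 Ref{Int64} 仍能得到正确结果,只是因为在小端机器上 Fortran 恰好只读取了 8 字节整数的低 4 字节。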

julia> Na = 10;
julia> Ma = 10;
julia> w = zeros(Na);
julia> v = zeros(Na,Na);
julia> a = rand(Ma,Na);
julia> t = ccall((:svdcmp_dp_, "./svdcmp_dp.so")
              , Void
              , ( Ref{Int64}   # dim Ma
              ,   Ref{Int64}   # dim Na
              ,   Ref{Float64} # array a(Ma,Na)
               ,  Ref{Float64} # array w(Na)
               ,  Ref{Float64} # array v(Na,Na)
              ) 
              ,Ma,Na,a,w,v)
 size(a,1)=           10
 size(a,2)=           10
 size(v,1)=           10
 size(v,2)=           10
 size(w)  =           10

julia> a
10×10 Array{Float64,2}:
 -0.345725  -0.152634   -0.308378    0.16358    -0.0320809  …  -0.47387     0.429124    -0.45121   
 -0.262689   0.337605   -0.0870571   0.409442   -0.160302      -0.0551756   0.16718      0.612903  
 -0.269915   0.410518   -0.0546271  -0.251295   -0.465747       0.328763   -0.109375    -0.476041  
 -0.33862   -0.238028    0.3538     -0.110374    0.294611       0.052966    0.44796     -0.0296113 
 -0.327258  -0.432601   -0.250865    0.478916   -0.0284979      0.0839667  -0.557761    -0.0956028 
 -0.265429  -0.199584   -0.178273   -0.300575   -0.578186   …  -0.0561654   0.164844     0.35431   
 -0.333577   0.588873   -0.0587738   0.213815    0.349599       0.0573156   0.00210332  -0.0764212 
 -0.358586  -0.246824    0.211746    0.0193308   0.0844788      0.64333     0.105043     0.0645999 
 -0.340235   0.0145761  -0.344321   -0.602982    0.422866      -0.15449    -0.309766     0.220315  
 -0.301303   0.051581    0.712463   -0.0297202  -0.162096      -0.458565   -0.360566    -0.00623828

julia> w
10-element Array{Float64,1}:
 4.71084 
 1.47765 
 1.06096 
 0.911895
 0.123196
 0.235218
 0.418629
 0.611456
 0.722386
 0.688394

julia> v
10×10 Array{Float64,2}:
 -0.252394   0.128972   -0.0839656   0.6905     …   0.357651    0.0759095  -0.0858018  -0.111576 
 -0.222082  -0.202181   -0.0485353  -0.217066       0.11651    -0.223779    0.780065   -0.288588 
 -0.237793   0.109989    0.473947    0.155364       0.0821913  -0.61879     0.119753    0.33927  
 -0.343341  -0.439985   -0.459649   -0.233768       0.0948844  -0.155143   -0.233945    0.53929  
 -0.24665    0.0670331  -0.108927    0.119793      -0.520865    0.454486    0.375191    0.226854 
 -0.194316   0.301428    0.236947   -0.118114   …  -0.579563   -0.183961   -0.19942     0.0545692
 -0.349481  -0.61546     0.475366    0.227209      -0.0975147   0.274104   -0.0994582  -0.0834197
 -0.457956   0.349558    0.263727   -0.506634       0.418154    0.378996   -0.113577   -0.0262257
 -0.451763   0.0283005  -0.328583   -0.0121005     -0.219985   -0.276867   -0.269783   -0.604697 
 -0.27929    0.373724   -0.288427    0.246083       0.0529508   0.0369404   0.197368    0.265678 

祝好!