Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions deps/src/fileio.f90
Original file line number Diff line number Diff line change
@@ -1,36 +1,62 @@
subroutine openfiles(printnum, sumnum, printerr, sumerr, printfile, sumfile)
module SNOPT_Julia_c

! inputs
integer, intent(in) :: printnum, sumnum
character*250, intent(in) :: printfile, sumfile
use iso_c_binding
use iso_fortran_env

! outputs
integer, intent(out) :: printerr, sumerr
implicit none

open(printnum, file=printfile, action='write', status='replace', iostat=printerr)
open(sumnum, file=sumfile, action='write', status='replace', iostat=sumerr)
private :: copy_a2s

contains

end subroutine
subroutine openfile(funit, ferror, fname, len_fname) bind(C, name = "SNOPT_openfile")

! inputs
integer(kind = c_int), intent(in) :: funit
character(kind = c_char), dimension(*), intent(in) :: fname
integer(kind = c_int), intent(in) :: len_fname
! character*250, intent(in) :: fname

subroutine closefiles(printnum, sumnum)
! outputs
integer(kind = c_int), intent(out) :: ferror

! inputs
integer, intent(in) :: printnum, sumnum
open(funit, file = copy_a2s(fname(1:len_fname)), action = 'write', status = 'replace', iostat = ferror)

close(printnum)
close(sumnum)
end subroutine

end subroutine
function get_stdout() bind(C, name = "SNOPT_get_stdout")
! output
integer(kind = c_int) :: get_stdout
get_stdout = output_unit
end function

subroutine closefile(funit) bind(C, name = "SNOPT_closefile")

subroutine flushfiles(printnum, sumnum)
! inputs
integer(kind = c_int), intent(in) :: funit

! inputs
integer, intent(in) :: printnum, sumnum
close(funit)

flush(printnum)
flush(sumnum)
end subroutine

end subroutine
subroutine flushfile(funit) bind(C, name = "SNOPT_flushfile")

! inputs
integer(kind = c_int), intent(in) :: funit

flush(funit)

end subroutine

pure function copy_a2s(a) result(s)
character(kind = c_char), intent(in) :: a(:)
character(size(a)) :: s
integer :: i
s = ''
do i = 1, size(a)
if (a(i) == c_null_char) exit
s(i:i) = a(i)
end do
end function

end module
160 changes: 120 additions & 40 deletions src/Snopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,44 @@ const codes = Dict(
142 => "System error: error in basis package"
)

const PRINTNUM = 18
const SUMNUM = 19
# wrapper for get_stdout. this returns the output unit to allow printing to the terminal
function get_stdout()
ccall((:SNOPT_get_stdout, snoptlib), Cint, ())
end

const PRINTNUM = Cint(18)
const SUMNUM = Cint(19)
const STDOUTNUM = get_stdout()

"""
haskey_caseless(dict, key)

Caseless implementation of haskey
"""

function haskey_caseless(dict::Dict, key::AbstractString)
for (k, _) in dict
if lowercase(k) == lowercase(key)
return true
end
end
return false
end

"""
get_caseless(dict, key, default = nothing)
"""

function get_caseless(dict::Dict, key::AbstractString, default = nothing)
key_lower = lowercase(key)
for (k, v) in dict
if lowercase(k) == key_lower
return v
end
end
return default
end

"""
Names(prob, xnames, fnames)

Expand Down Expand Up @@ -192,7 +226,7 @@ end


# wrapper for snInit
function sninit(nx, nf)
function sninit(nx, nf, printunit, sumunit)

# temporary working arrays
minlen = 500
Expand All @@ -204,61 +238,106 @@ function sninit(nx, nf)
ccall( (:sninit_, snoptlib), Nothing,
(Ref{Cint}, Ref{Cint}, Ptr{Cuchar}, Ref{Cint}, Ptr{Cint},
Ref{Cint}, Ptr{Cdouble}, Ref{Cint}),
PRINTNUM, SUMNUM, w.cw, w.lencw, w.iw,
printunit, sumunit, w.cw, w.lencw, w.iw,
w.leniw, w.rw, w.lenrw)

return w
end



# wrapper for openfiles. not defined with snopt, fortran file supplied in repo (from pyoptsparse)
function openfiles(printfile, sumfile)

# open files for printing (not part of snopt distribution)
printerr = Cint[0]
sumerr = Cint[0]
ccall( (:openfiles_, snoptlib), Nothing,
(Ref{Cint}, Ref{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cuchar}, Ptr{Cuchar}),
PRINTNUM, SUMNUM, printerr, sumerr, printfile, sumfile)

if printerr[1] != 0
@warn "failed to open print file"
end
if sumerr[1] != 0
@warn "failed to open summary file"
if isnothing(printfile)
printunit = Cint(0)
else
if isempty(printfile)
printunit = STDOUTNUM
else
len_printfile = Cint(length(printfile))
ccall( (:SNOPT_openfile, snoptlib), Nothing,
(Ref{Cint}, Ptr{Cint}, Ptr{Cuchar}, Ref{Cint}),
PRINTNUM, printerr, printfile, len_printfile)
if printerr[1] != Cint(0)
@warn "failed to open print file"
printunit = Cint(0)
else
printunit = PRINTNUM
end
end
end

return nothing
if isnothing(sumfile)
sumunit = Cint(0)
else

if isempty(sumfile)
sumunit = STDOUTNUM
else
len_sumfile = Cint(length(sumfile))
ccall( (:SNOPT_openfile, snoptlib), Nothing,
(Ref{Cint}, Ptr{Cint}, Ptr{Cuchar}, Ref{Cint}),
SUMNUM, sumerr, sumfile, len_sumfile)
if sumerr[1] != Cint(0)
@warn "failed to open summary file"
sumunit = Cint(0)
else
sumunit = SUMNUM
end
end
end

return (printunit, sumunit)
end

# wrapper for closefiles. not defined with snopt, fortran file supplied in repo (from pyoptsparse)
function closefiles()
function closefiles(printunit, sumunit)
# close output files
ccall( (:closefiles_, snoptlib), Nothing,
(Ref{Cint}, Ref{Cint}),
PRINTNUM, SUMNUM)
if printunit > 0
ccall( (:SNOPT_closefile, snoptlib), Nothing,
(Ref{Cint},), printunit)
end

if sumunit > 0
ccall( (:SNOPT_closefile, snoptlib), Nothing,
(Ref{Cint},), sumunit)
end

return nothing
end

# wrapper for flushfiles. not defined with snopt, fortran file supplied in repo (from pyoptsparse)
function flushfiles()
function flushfiles(printunit, sumunit)
# flush output files to see progress
ccall( (:flushfiles_, snoptlib), Nothing,
(Ref{Cint}, Ref{Cint}),
PRINTNUM, SUMNUM)
if printunit > 0
ccall( (:SNOPT_flushfile, snoptlib), Nothing,
(Ref{Cint},), printunit)
end

if sumunit > 0
ccall( (:SNOPT_flushfile, snoptlib), Nothing,
(Ref{Cint},), sumunit)
end

return nothing
end

# wrapper for snSet, snSeti, snSetr
function setoptions(options, work)
function setoptions(options, work, printunit, sumunit)

# --- set options ----
errors = Cint[0]

for key in keys(options)
value = options[key]

lower_key = lowercase(key)
value = get_caseless(options, key)

isnothing(value) && continue

buffer = string(key, repeat(" ", 55-length(key))) # buffer length is 55 so pad with space.

if length(key) > 55
Expand All @@ -274,28 +353,28 @@ function setoptions(options, work)
# for more information please see
# https://web.stanford.edu/group/SOL/guides/sndoc7.pdf page 66 (sec 7.5).

if key != "Print file" && key != "Summary file"
if lower_key != "print file" && lower_key != "summary file"
value = string(value, repeat(" ", 72-length(value)))
ccall( (:snset_, snoptlib), Nothing,
(Ptr{Cuchar}, Ref{Cint}, Ref{Cint}, Ptr{Cint},
Ptr{Cuchar}, Ref{Cint}, Ptr{Cint}, Ref{Cint}, Ptr{Cdouble}, Ref{Cint}),
value, PRINTNUM, SUMNUM, errors,
value, printunit, sumunit, errors,
work.cw, work.lencw, work.iw, work.leniw, work.rw, work.lenrw)
end
elseif isinteger(value)

ccall( (:snseti_, snoptlib), Nothing,
(Ptr{Cuchar}, Ref{Cint}, Ref{Cint}, Ref{Cint}, Ptr{Cint},
Ptr{Cuchar}, Ref{Cint}, Ptr{Cint}, Ref{Cint}, Ptr{Cdouble}, Ref{Cint}),
buffer, value, PRINTNUM, SUMNUM, errors,
buffer, value, printunit, sumunit, errors,
work.cw, work.lencw, work.iw, work.leniw, work.rw, work.lenrw)

elseif isreal(value)

ccall( (:snsetr_, snoptlib), Nothing,
(Ptr{Cuchar}, Ref{Cdouble}, Ref{Cint}, Ref{Cint}, Ptr{Cint},
Ptr{Cuchar}, Ref{Cint}, Ptr{Cint}, Ref{Cint}, Ptr{Cdouble}, Ref{Cint}),
buffer, value, PRINTNUM, SUMNUM, errors,
buffer, value, printunit, sumunit, errors,
work.cw, work.lencw, work.iw, work.leniw, work.rw, work.lenrw)
end

Expand All @@ -310,7 +389,7 @@ end


# wrapper for snMemA
function setmemory(INFO, nf, nx, nxname, nfname, neA, neG, work)
function setmemory(INFO, nf, nx, nxname, nfname, neA, neG, work, printunit, sumunit)

mincw = Cint[0]
miniw = Cint[0]
Expand Down Expand Up @@ -353,7 +432,7 @@ function setmemory(INFO, nf, nx, nxname, nfname, neA, neG, work)
ccall( (:snseti_, snoptlib), Nothing,
(Ptr{Cuchar}, Ref{Cint}, Ref{Cint}, Ref{Cint}, Ptr{Cint},
Ptr{Cuchar}, Ref{Cint}, Ptr{Cint}, Ref{Cint}, Ptr{Cdouble}, Ref{Cint}),
buffer, value, PRINTNUM, SUMNUM, errors,
buffer, value, printunit, sumunit, errors,
work.cw, work.lencw, work.iw, work.leniw, work.rw, work.lenrw)
if errors[1] > 0
@warn errors[1], " error encountered while lengths in options from memory sizing"
Expand Down Expand Up @@ -554,23 +633,24 @@ function snopta(func!, start::Start, lx, ux, lg, ug, rows, cols,
printfile = "snopt-print.out"
sumfile = "snopt-summary.out"

if haskey(options, "Print file")
printfile = options["Print file"]
if haskey_caseless(options, "print file")
printfile = get_caseless(options, "print file")
end
if haskey(options, "Summary file")
sumfile = options["Summary file"]
if haskey_caseless(options, "summary file")
sumfile = get_caseless(options, "summary file")
end
openfiles(printfile, sumfile)

printunit, sumunit = openfiles(printfile, sumfile)

# ----- initialize -------
work = sninit(nx, nf)
work = sninit(nx, nf, printunit, sumunit)

# --- set options ----
setoptions(options, work)
setoptions(options, work, printunit, sumunit)

# ---- set memory requirements ------
INFO = Cint[0]
setmemory(INFO, nf, nx, nxname, nfname, neA, neG, work)
setmemory(INFO, nf, nx, nxname, nfname, neA, neG, work, printunit, sumunit)

# --- call snopta ----
mincw = Cint[0]
Expand Down Expand Up @@ -610,7 +690,7 @@ function snopta(func!, start::Start, lx, ux, lg, ug, rows, cols,


# close output files
closefiles()
closefiles(printunit, sumunit)

# pack outputs
warm = WarmStart(ns[1], start.xstate, start.fstate, start.x, start.f,
Expand Down