diff --git a/src/helper.jl b/src/helper.jl index c97dc36..82be791 100644 --- a/src/helper.jl +++ b/src/helper.jl @@ -150,23 +150,47 @@ macro timing(cond,code) end function SparseMatrix.sparse(I,J,V, M, N;keepzeros=false) + @assert length(I) == length(J) == length(V) if(!keepzeros) return sparse(I,J,V,M,N) else - full = sparse(I,J,ones(Float64,length(I)),M,N) - actual = sparse(I,J,V,M,N) - fill!(full.nzval,0.0) + mergednnz = [0] + mergedmap = zeros(Int,length(I)) + idxmap = zeros(Int,length(I)) + mergedindices = zeros(Int,length(I)) + function combine(idx1,idx2) + # @show idx1, idx2 + idx1 = round(Int,idx1) + idx2 = round(Int,idx2) + @inbounds @assert mergedmap[idx2] == 0 && (mergedmap[idx1] == idx1 || mergedmap[idx1] == 0) + @inbounds mergednnz[1] += 1 + @inbounds mergedmap[idx1] = idx1 + @inbounds mergedmap[idx2] = idx1 + @inbounds mergedindices[mergednnz[1]] = idx2 + return idx1 + end + + full = sparse(I,J,[float(i) for i in 1:length(I)],M,N,combine) + for col in 1:N + @inbounds for pos in full.colptr[col]:(full.colptr[col+1]-1) + @inbounds row = full.rowval[pos] + @inbounds origidx = round(Int,full.nzval[pos]) # this is the original index (on JJ) of this element + @inbounds idxmap[origidx] = pos + end + end - for c = 1:N - for i=nzrange(actual,c) - r = actual.rowval[i] - v = actual.nzval[i] - if(v!=0) - full[r,c] = v - end - end - # full.nzval[crange] = actual.nzval[crange] - end + @inbounds for k in 1:mergednnz[1] + @inbounds origidx = mergedindices[k] + @inbounds mergedwith = mergedmap[origidx] + @inbounds @assert idxmap[origidx] == 0 + @inbounds @assert idxmap[mergedwith] != 0 + @inbounds idxmap[origidx] = idxmap[mergedwith] + end + + fill!(full.nzval,0.0) + for i in 1:length(I) + @inbounds full.nzval[idxmap[i]] += V[i] + end return full end end