Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions itensor/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ SOURCES+= decomp.cc
SOURCES+= svd.cc
SOURCES+= hermitian.cc
SOURCES+= global.cc
SOURCES+= linsystem.cc
SOURCES+= mps/mps.cc
SOURCES+= mps/mpsalgs.cc
SOURCES+= mps/mpo.cc
Expand Down
145 changes: 145 additions & 0 deletions itensor/decomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,21 @@ eigen(ITensorT<I> const& T,
// ITensorT<I> & Rinv,
// Args const& args = Args::global());

//
// Linear system of equations
//
// The routine solves for X the system of linear equations A*X = B,
// where A is an n-by-n matrix, the columns of matrix B are individual right-hand sides,
// and the columns of X are the corresponding solutions.
//
// Resulting X is returned in B
//
template<class I>
void
linsystem(ITensorT<I> & A,
ITensorT<I> & B,
ITensorT<I> & X,
Args args = Args::global());

///////////////////////////
//
Expand Down Expand Up @@ -402,7 +417,137 @@ diagHermitian(ITensorT<I> const& M,
return spec;
} //diagHermitian

template<typename I>
void
lin_system(ITensorT<I> A,
ITensorT<I> & B,
ITensorT<I> & X,
Args const& args);

template<class I>
void
linsystem(ITensorT<I> & A,
ITensorT<I> & B,
ITensorT<I> & X,
Args args)
{
if(!args.defined("IndexName")) args.add("IndexName","ind_x");
auto dbg = args.getBool("dbg",false);

// TODO MORE SOPHISTICATED INDEX ANALYSIS
//
// Pick an arbitrary index and do some analysis
// on its prime level spacing
//
auto k = A.inds().front();
auto kps = stdx::reserve_vector<int>(rank(A));
for(auto& i : A.inds()) if(i.noprimeEquals(k)) kps.push_back(i.primeLevel());
if (dbg) {
std::cout<<"[linsystem] primeLevels: ";
for(unsigned int i=0; i<kps.size(); i++) std::cout<< kps[i] <<" ";
std::cout<< std::endl;
}
if (kps.size() <= 1ul || kps.size()%2 != 0ul) {
Error("Input tensor A (n x n Matrix of coeffs) to linsystem should \
have pairs of indices with equally spaced prime levels. Odd number of \
same indices (modulo primeLevel)");
}

int pdiff = -1;
if(!args.defined("plDiff")) {
auto nk = kps.size();
std::sort(kps.begin(),kps.end());
//idiff == "inner" difference between cluster of low-prime-level copies
// of k, if more than one
auto idiff = kps.at(nk/2-1)-kps.front();
//mdiff == max prime-level difference of copies of k
auto mdiff = kps.back()-kps.front();
//pdiff == spacing between lower and higher prime level index pairs
pdiff = mdiff-idiff;
if (dbg) {
std::cout<<"[linsystem] plDiff NOT provided"<< std::endl;
std::cout<<"[linsystem] idiff: "<< idiff <<" mdiff: "<< mdiff <<" pdiff: "
<< pdiff << std::endl;
}
} else {
pdiff = args.getInt("plDiff",-1);
}
if (pdiff == -1) { Error("[linsystem] Invalid plDiff value"); }

auto inds = stdx::reserve_vector<I>(rank(A)/2);
for(auto& i : A.inds())
for(auto& j : A.inds())
{
if(i.noprimeEquals(j) && i.primeLevel()+pdiff == j.primeLevel())
{
inds.push_back(i);
}
}
if(inds.empty() || rank(A)/2 != (long)inds.size())
{
Error("Input tensor to linsystem should have pairs of indices with equally spaced prime levels");
}
if (dbg) {
std::cout<<"[linsystem] inds: ";
for (unsigned int i=0; i<inds.size(); i++) std::cout << inds[i] <<" ";
std::cout<< std::endl;
}

// compare against indices of B
// B.inds() must match either inds or prime(inds,pdiff) exactly
if (rank(B) != rank(A)/2) { Error("[linsystem] rank(B) != rank(A)/2"); }
int plB_check = 0;
auto bInds = stdx::reserve_vector<I>(rank(A)/2);
for (auto& i : inds) {
for (auto& j : B.inds()) {
if (i==j)
bInds.push_back(j);
else if (prime(i,pdiff)==j) {
bInds.push_back(j);
plB_check += pdiff;
}
}
}
if (dbg) {
std::cout<<"[linsystem] bInds: ";
for (unsigned int i=0; i<bInds.size(); i++) std::cout << bInds[i] <<" ";
std::cout<< std::endl;
}
if ((bInds.size() != rank(A)/2) || !(plB_check==0 || plB_check==rank(B)*pdiff)) {
std::cout << "bInds.size() != rank(A)/2 : "<< (bInds.size() != rank(A)/2) << std::endl;
std::cout << "plB_check==0 : "<< (plB_check==0) << std::endl;
std::cout << "plB_check==rank(B)*pdiff : "<< (plB_check==rank(B)*pdiff) << std::endl;
Error("[linsystem] B.inds() are incompatible with indices of matrix A");
}

auto combA = combiner(std::move(inds), args);
auto combB = combiner(std::move(bInds),args);
auto Ac = A*combA;
auto Bc = B*combB;

auto combAP = dag(prime(combA,pdiff));
try {
Ac *= combAP;
}
catch(ITError const& e)
{
println("Diagonalize expects opposite arrow directions for primed and unprimed indices.");
throw e;
}

if (dbg) {
std::cout<<"[linsystem] ";
PrintData(Ac);
std::cout<<"[linsystem] ";
PrintData(Bc);
}
lin_system(Ac,Bc,X,args);

// restore original indices of B on X
X = combB * X;

} //linsystem

//Return value is: (trunc_error,docut)
std::tuple<Real,Real>
truncate(Vector & P,
Expand Down
92 changes: 92 additions & 0 deletions itensor/linsystem.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include <algorithm>
#include <tuple>
#include "itensor/util/stdx.h"
#include "itensor/tensor/algs.h"
#include "itensor/decomp.h"
#include "itensor/util/print_macro.h"
#include "itensor/itdata/qutil.h"

namespace itensor {

template<typename T>
void
linsystemImpl(ITensor A,
ITensor& B,
ITensor& X,
Args const& args)
{
auto dbg = args.getBool("dbg",false);

if(A.r() != 2)
{
Print(A.r());
Print(A);
Error("Rank greater than 2 in lin_system");
}

auto i1 = A.inds().front();
auto i2 = A.inds().back();
auto active = (i1.primeLevel() < i2.primeLevel()) ? i1 : i2;
auto pdiff = std::abs(i1.primeLevel()-i2.primeLevel());

auto ib = B.inds().front();
auto dummy = Index("dummy",1);

if (dbg) {
std::cout<<"[linsystemImpl] primeLevel difference on A: "<< pdiff << std::endl;
std::cout<<"[linsystemImpl] indices of A: "<< std::endl;
std::cout<<"[linsystemImpl] active: "<< active <<" other: "
<< prime(active) << std::endl;
std::cout<<"[linsystemImpl] indices of B: "<< ib << std::endl;
}

Mat<T> XX;
auto RA = toMatRefc<T>(A,active,prime(active));
auto RB = toMatRefc<T>(B,ib,dummy);

linSystem(RA,RB,XX,args);

X = ITensor({ib,dummy},Dense<T>{move(XX.storage())});
// X = ITensor({ib,dummy},Dense<T>{move(XX.storage())},A.scale());
// if(not A.scale().isTooBigForReal()) {
X *= (B.scale().real0()/A.scale().real0());
// } else {
// println("lin_systemImpl: scale too big for Real");
// }

// absorb dummy index (TODO if necessary)
auto combX = combiner(ib,dummy);
X = X*combX;
X = X*delta(commonIndex(combX,X),ib);
if (dbg) {
std::cout<<"[linsystemImpl] combX: "<< combX;
PrintData(X);
}
}

template<typename I>
void
lin_system(ITensorT<I> A,
ITensorT<I> & B,
ITensorT<I> & X,
Args const& args)
{
if(isComplex(A))
{
return linsystemImpl<Cplx>(A,B,X,args);
}
return linsystemImpl<Real>(A,B,X,args);
}
template
void
lin_system(ITensor A,
ITensor & B,
ITensor & X,
Args const& args);
// template
// void
// lin_system(IQTensor A,
// IQTensor & B,
// Args const& args);

} //namespace itensor
37 changes: 36 additions & 1 deletion itensor/tensor/algs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "itensor/util/range.h"
#include "itensor/global.h"


using std::move;
using std::sqrt;
using std::tuple;
Expand All @@ -31,6 +32,40 @@ namespace detail {
{
return zheev_wrapper(N,Udata,ddata);
}

int
dtl_linSystem(int N, int Nb, Real const * Adata, Real * Bdata, Args const& args)
{
auto arg_method = args.getString("method","LU");
auto arg_dbg = args.getBool("dbg",false);
LAPACK_INT info = 0;
if (arg_method == "CHOLESKY") {
if(arg_dbg) std::cout<<"[dtl_linSystem] calling dposv"<< std::endl;
info = dposv_wrapper(N,Nb,Adata,Bdata);
} else if (arg_method == "LU") {
if(arg_dbg) std::cout<<"[dtl_linSystem] calling dgesv"<< std::endl;
info = dgesv_wrapper(N,Nb,Adata,Bdata);
} else {
std::cout<<"[dtl_linSystem] Unsupported method: "<< arg_method << std::endl;
exit(EXIT_FAILURE);
}
return info;
}
int
dtl_linSystem(int N, int Nb, Cplx const * Adata, Cplx * Bdata, Args const& args)
{
auto arg_method = args.getString("method","LU");
LAPACK_INT info = 0;
if (arg_method == "CHOLESKY") {
info = zposv_wrapper(N,Nb,Adata,Bdata);
} else if (arg_method == "LU") {
info = zgesv_wrapper(N,Nb,Adata,Bdata);
} else {
std::cout<<"[dtl_linSystem] Unsupported method: "<< arg_method << std::endl;
exit(EXIT_FAILURE);
}
return info;
}
} //namespace detail

//void
Expand Down Expand Up @@ -574,6 +609,7 @@ SVDRefImpl(MatRefc<T> const& M,
return;
}


template<typename T>
void
SVDRef(MatRefc<T> const& M,
Expand All @@ -588,7 +624,6 @@ template void SVDRef(MatRefc<Real> const&,MatRef<Real> const&, VectorRef const&,
template void SVDRef(MatRefc<Cplx> const&,MatRef<Cplx> const&, VectorRef const&, MatRef<Cplx> const&,Real);



//void
//SVDRef(MatrixRefc const& Mre,
// MatrixRefc const& Mim,
Expand Down
Loading