/*
-----------------------------------------------------------------------------
This source file is part of OpenSpace3D
For the latest info, see http://www.openspace3d.com

Copyright (c) 2010 I-maginer

This program is free software; you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License as published by the Free Software
Foundation; either version 2 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along with
this program; if not, write to the Free Software Foundation, Inc., 59 Temple
Place - Suite 330, Boston, MA 02111-1307, USA, or go to
http://www.gnu.org/copyleft/lesser.txt

You may alternatively use this source under the terms of a specific version of
the OpenSpace3D Unrestricted License provided you have obtained such a license from
I-maginer.
-----------------------------------------------------------------------------
*/

/*
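 Security based on the Crypto++ library (namespace CryptoPP, used throughout this file);
 the original header referenced the beeCrypt library : beecrypt.sourceforge.net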
 First version : may 2009
 Author : Bastien BOURINEAU
*/

/*! @defgroup grpsecurity Scol functions definition
 *  Scol functions definition
 *  @{
 */
/** @} */

#include <scol.h>
#include "windows.h"
#include <vector>

#include "rsa.h"
using CryptoPP::RSA;
using CryptoPP::InvertibleRSAFunction;
using CryptoPP::RSAES_OAEP_SHA_Encryptor;
using CryptoPP::RSAES_OAEP_SHA_Decryptor;

using CryptoPP::RSAES_PKCS1v15_Encryptor;
using CryptoPP::RSAES_PKCS1v15_Decryptor;

using CryptoPP::InvertibleRSAFunction;
typedef CryptoPP::RSAFunction RSAPublicKey;
typedef CryptoPP::InvertibleRSAFunction RSAPrivateKey;

#include "sha.h"
using CryptoPP::SHA1;

#include "filters.h"
using CryptoPP::StringSink;
using CryptoPP::StringSource;
using CryptoPP::PK_EncryptorFilter;
using CryptoPP::PK_DecryptorFilter;
using CryptoPP::StreamTransformationFilter;

#include "aes.h"
using CryptoPP::AES;

#include "ccm.h"
using CryptoPP::CBC_Mode;

#include "files.h"
using CryptoPP::FileSink;
using CryptoPP::FileSource;

#include "osrng.h"
using CryptoPP::AutoSeededRandomPool;

#include "SecBlock.h"
using CryptoPP::SecByteBlock;

#include "cryptlib.h"
using CryptoPP::Exception;
using CryptoPP::DecodingResult;

#include <hex.h>
using CryptoPP::HexEncoder;
using CryptoPP::HexDecoder;

#include <string>
using std::string;

#include <exception>
using std::exception;

#include <iostream>
using std::cout;
using std::cerr;
using std::endl;

#include <assert.h>

// Uncomment to compile in the MMechostr() trace messages that every Scol
// function of this file wraps in #ifdef _SCOL_DEBUG_.
//#define		_SCOL_DEBUG_

//! Scol machine callback table used by the MM* macros.
//! Presumably filled by SCOLinitplugin() in ScolLoadPlugin (assumption - confirm in scol.h).
cbmachine	ww;

/*! \mainpage Security Scol Plugin
 *
 * \section intro_sec Introduction
 * This plugin allows the Scol Virtual Machine to use RSA and AES cryptography.
 * 
 */


/*! @ingroup grpsecurity
* \brief _RSAgetKeyPair : This function generates an RSA private/public key pair with the size passed in parameter
*
* <b>Prototype:</b> fun [I] [S S]
* \param I : size in bits of the RSA modulus to generate
*
* \return [S S] : the hex-encoded private key and public key,
*                 or nil if the size is nil/non-positive, generation fails or the keys do not validate
*/
int _RSAgetKeyPair(mmachine m)
{
#ifdef _SCOL_DEBUG_
	MMechostr(0,"_RSAgetKeyPair\n");
#endif

	int keysize = MMget(m, 0);
	// A non-positive size would wrap to a huge unsigned bit count in GenerateRandomWithKeySize.
	if (keysize == NIL || MTOI(keysize) <= 0)
	{
		MMset(m, 0, NIL);
		return 0;
	}

	string sPublicKey;
	string sPrivateKey;

	// Crypto++ reports failures (modulus size too small, RNG errors, ...) by throwing;
	// no exception may escape into the Scol VM.
	try
	{
		AutoSeededRandomPool rng;

		InvertibleRSAFunction parameters;
		parameters.GenerateRandomWithKeySize(rng, MTOI(keysize));

		RSA::PrivateKey privateKey(parameters);
		RSA::PublicKey publicKey(parameters);

		if (!privateKey.Validate(rng, 3) || !publicKey.Validate(rng, 3))
		{
			MMechostr(0,"_RSAgetKeyPair Privilege Error\n");
			MMset(m, 0, NIL);
			return 0;
		}

		// Named encoders: Save() takes a non-const reference, and MessageEnd() flushes the encoder.
		HexEncoder publicEncoder(new StringSink(sPublicKey));
		publicKey.Save(publicEncoder);
		publicEncoder.MessageEnd();

		HexEncoder privateEncoder(new StringSink(sPrivateKey));
		privateKey.Save(privateEncoder);
		privateEncoder.MessageEnd();
	}
	catch (const std::exception& e)
	{
		MMechostr(0, "_RSAgetKeyPair Error : %s \n", e.what());
		MMset(m, 0, NIL);
		return 0;
	}

	// Replace the size argument with the result tuple [private public].
	MMpull(m);

	int k = Mpushstrbloc(m, (char *)sPrivateKey.c_str());
	if (k)
		return k;

	k = Mpushstrbloc(m, (char *)sPublicKey.c_str());
	if (k)
		return k;

	// Tuple field count; presumably the Scol-encoded integer 2 expected by MBdeftab.
	if (MMpush(m, 2*2))
		return MERRMEM;

	if ((k = MBdeftab(m)))
		return k;

#ifdef	_SCOL_DEBUG_
	MMechostr(0,"ok\n");
#endif

	return 0;
}


/*! @ingroup grpsecurity
* \brief _RSAencryptMessage : This function encrypts a message using the RSA public key (RSAES-OAEP with SHA-1)
*
* <b>Prototype:</b> fun [S S] S
* \param S : the message to encrypt (raw bytes; it must fit in one RSA block for the given key)
* \param S : the hex-encoded RSA public key to use
*
* \return S : the hex-encoded encrypted message, or nil if an argument is nil or any cryptographic error occurs
*/
int _RSAencryptMessage(mmachine m)
{
#ifdef _SCOL_DEBUG_
	MMechostr(0,"_RSAencryptMessage\n");
#endif

	int key = MTOP(MMpull(m));
	int mess = MTOP(MMpull(m));
	if (mess == NIL || key == NIL)
		return MMpush(m, NIL);

	int msize = MMsizestr(m, mess);
	byte* smess = (byte*)MMstart(m, (mess)+1);
	char* sckey = MMstartstr(m, key);
	if (smess == NULL || sckey == NULL)
		return MMpush(m, NIL);

	string sPublicKey(sckey);
	string sEncryptedHex;

	try
	{
		// Constructed inside the try: seeding from the OS may itself throw.
		AutoSeededRandomPool rng;

		RSA::PublicKey publicKey;
		HexDecoder decoder;
		decoder.Put((const byte*)sPublicKey.data(), sPublicKey.size());
		decoder.MessageEnd();
		publicKey.Load(decoder);

		//RSAES_PKCS1v15_Encryptor encrypt( publicKey );
		RSAES_OAEP_SHA_Encryptor encrypt(publicKey);

		// Throws if the message is too long for a single RSA block.
		string sEncrypted;
		StringSource source(smess, msize, true, new PK_EncryptorFilter(rng, encrypt, new StringSink(sEncrypted)));

		HexEncoder encoder(new StringSink(sEncryptedHex));
		encoder.Put((const byte*)sEncrypted.data(), sEncrypted.size());
		encoder.MessageEnd();
	}
	catch (const std::exception& e)
	{
		MMechostr(0,"_RSAencryptMessage Error : %s \n", e.what());
		return MMpush(m, NIL);
	}

	// Build the Scol string by hand (length word + NUL-terminated bytes).
	int res = MMmalloc(m, STR_SIZE(sEncryptedHex.size()), TYPEBUF);
	if (res == NIL)
		return MERRMEM;

	char* BS = MMstartstr(m, res);
	memcpy(BS, sEncryptedHex.data(), sEncryptedHex.size());
	BS[sEncryptedHex.size()] = 0;
	MMstore(m, res, 0, sEncryptedHex.size());
	int k = MMpush(m, PTOM(res));

#ifdef	_SCOL_DEBUG_
	MMechostr(0,"ok\n");
#endif

	return k;
}


/*! @ingroup grpsecurity
* \brief _RSAdecryptMessage : This function decrypts a message using the RSA private key (RSAES-OAEP with SHA-1)
*
* <b>Prototype:</b> fun [S S] S
* \param S : the hex-encoded message to decrypt
* \param S : the hex-encoded RSA private key to use
*
* \return S : the decrypted message (may contain binary bytes), or nil if an argument is nil or any cryptographic error occurs
*/
int _RSAdecryptMessage(mmachine m)
{
#ifdef _SCOL_DEBUG_
	MMechostr(0,"_RSAdecryptMessage\n");
#endif

	int key = MTOP(MMpull(m));
	int mess = MTOP(MMpull(m));
	if (mess == NIL || key == NIL)
		return MMpush(m, NIL);

	int msize = MMsizestr(m, mess);
	byte* smess = (byte*)MMstart(m, (mess)+1);
	char* sckey = MMstartstr(m, key);
	if (smess == NULL || sckey == NULL)
		return MMpush(m, NIL);

	string sPrivateKey(sckey);
	string sDecrypted;

	try
	{
		// Constructed inside the try: seeding from the OS may itself throw.
		AutoSeededRandomPool rng;

		// Hex text -> raw ciphertext bytes.
		string sMessage;
		HexDecoder mdecoder(new StringSink(sMessage));
		mdecoder.Put(smess, msize);
		mdecoder.MessageEnd();

		RSA::PrivateKey privateKey;
		HexDecoder kdecoder;
		kdecoder.Put((const byte*)sPrivateKey.data(), sPrivateKey.size());
		kdecoder.MessageEnd();
		privateKey.Load(kdecoder);

		//RSAES_PKCS1v15_Decryptor decrypt( privateKey );
		RSAES_OAEP_SHA_Decryptor decrypt(privateKey);

		StringSource source(sMessage, true, new PK_DecryptorFilter(rng, decrypt, new StringSink(sDecrypted)));
	}
	catch (const std::exception& e)
	{
		MMechostr(0, "_RSAdecryptMessage Error : %s \n", e.what());
		return MMpush(m, NIL);
	}

	// Build the Scol string by hand so embedded NUL bytes are preserved (Mpushstrbloc would stop at the first one).
	int res = MMmalloc(m, STR_SIZE(sDecrypted.size()), TYPEBUF);
	if (res == NIL)
		return MERRMEM;

	char* BS = MMstartstr(m, res);
	memcpy(BS, sDecrypted.data(), sDecrypted.size());
	BS[sDecrypted.size()] = 0;
	MMstore(m, res, 0, sDecrypted.size());
	int k = MMpush(m, PTOM(res));

#ifdef	_SCOL_DEBUG_
	MMechostr(0,"ok\n");
#endif

	return k;
}


/*! @ingroup grpsecurity
* \brief _AESgetKey : Creates a new random AES key of AES::DEFAULT_KEYLENGTH bytes
*
* <b>Prototype:</b> fun [] S
*
* \return S : the new key as an upper-case hexadecimal string
*/
// fun [] S
int _AESgetKey(mmachine m)
{
#ifdef _SCOL_DEBUG_
	MMechostr(0,"_AESgetKey\n");
#endif

	// Fill the raw key from the automatically seeded generator.
	byte rawKey[AES::DEFAULT_KEYLENGTH];
	AutoSeededRandomPool generator;
	generator.GenerateBlock(rawKey, sizeof(rawKey));

	// Hex-encode it so the key can be handed back as a printable Scol string.
	string hexKey;
	HexEncoder encoder(new StringSink(hexKey));
	encoder.Put(rawKey, sizeof(rawKey));
	encoder.MessageEnd();

	int status = Mpushstrbloc(m, (char *)hexKey.c_str());

#ifdef	_SCOL_DEBUG_
	MMechostr(0,"ok\n");
#endif

	return status;
}



/*! @ingroup grpsecurity
* \brief _AESencryptMessage : This function encrypts a message using the AES key (CBC mode, random IV)
*
* <b>Prototype:</b> fun [S S] S
* \param S : the message to encrypt
* \param S : the hex-encoded AES key to use
*
* \return S : hex(IV) followed by hex(ciphertext), or nil if an argument is nil or any cryptographic error occurs
*/
int _AESencryptMessage(mmachine m)
{
#ifdef _SCOL_DEBUG_
	MMechostr(0,"_AESencryptMessage\n");
#endif

	int key = MTOP(MMpull(m));
	int mess = MTOP(MMpull(m));
	if (mess == NIL || key == NIL)
		return MMpush(m, NIL);

	int msize = MMsizestr(m, mess);
	byte* smess = (byte*)MMstart(m, (mess)+1);
	char* sckey = MMstartstr(m, key);
	if (smess == NULL || sckey == NULL)
		return MMpush(m, NIL);

	string sKeyHex(sckey);
	string sResult;	// hex(IV) || hex(ciphertext)

	try
	{
		// Fresh random IV for every message; the generator is built inside the try because seeding may throw.
		AutoSeededRandomPool prng;
		byte iv[AES::BLOCKSIZE];
		prng.GenerateBlock(iv, sizeof(iv));

		// Hex key text -> raw key bytes.
		string sKey;
		HexDecoder kdecoder(new StringSink(sKey));
		kdecoder.Put((const byte*)sKeyHex.data(), sKeyHex.size());
		kdecoder.MessageEnd();

		// Throws if the decoded key is not a valid AES key length.
		CBC_Mode< AES >::Encryption encrypt;
		encrypt.SetKeyWithIV((const byte*)sKey.data(), sKey.size(), iv);

		// The StreamTransformationFilter adds padding as required:
		// CBC mode must be padded to the block size of the cipher.
		string sEncrypted;
		StringSource source(smess, msize, true, new StreamTransformationFilter(encrypt, new StringSink(sEncrypted)));

		// Hex encoding is byte-wise, so encoding IV then ciphertext in one pass
		// yields hex(IV) immediately followed by hex(ciphertext).
		HexEncoder encoder(new StringSink(sResult));
		encoder.Put(iv, sizeof(iv));
		encoder.Put((const byte*)sEncrypted.data(), sEncrypted.size());
		encoder.MessageEnd();
	}
	catch (const std::exception& e)
	{
		MMechostr(0,"_AESencryptMessage Error : %s \n", e.what());
		return MMpush(m, NIL);
	}

	// Build the Scol string by hand (length word + NUL-terminated bytes).
	int res = MMmalloc(m, STR_SIZE(sResult.size()), TYPEBUF);
	if (res == NIL)
		return MERRMEM;

	char* BS = MMstartstr(m, res);
	memcpy(BS, sResult.data(), sResult.size());
	BS[sResult.size()] = 0;
	MMstore(m, res, 0, sResult.size());
	int k = MMpush(m, PTOM(res));

#ifdef	_SCOL_DEBUG_
	MMechostr(0,"ok\n");
#endif

	return k;
}



/*! @ingroup grpsecurity
* \brief _AESdecryptMessage : This function decrypts a message using the AES key (CBC mode)
*
* <b>Prototype:</b> fun [S S] S
* \param S : the message to decrypt, as hex(IV) followed by hex(ciphertext)
* \param S : the hex-encoded AES key to use
*
* \return S : the decrypted message (may contain binary bytes), or nil if an argument is nil,
*             the decoded message is shorter than one AES block (no IV), or any cryptographic error occurs
*/
int _AESdecryptMessage(mmachine m)
{
#ifdef _SCOL_DEBUG_
	MMechostr(0,"_AESdecryptMessage\n");
#endif

	int key = MTOP(MMpull(m));
	int mess = MTOP(MMpull(m));
	if (mess == NIL || key == NIL)
		return MMpush(m, NIL);

	int msize = MMsizestr(m, mess);
	byte* smess = (byte*)MMstart(m, (mess)+1);
	char* sckey = MMstartstr(m, key);
	if (smess == NULL || sckey == NULL)
		return MMpush(m, NIL);

	string sKeyHex(sckey);
	string sMessage;

	try
	{
		// Hex text -> raw bytes laid out as IV || ciphertext.
		string sEncrypted;
		HexDecoder mdecoder(new StringSink(sEncrypted));
		mdecoder.Put(smess, msize);
		mdecoder.MessageEnd();

		// Hex key text -> raw key bytes.
		string sKey;
		HexDecoder kdecoder(new StringSink(sKey));
		kdecoder.Put((const byte*)sKeyHex.data(), sKeyHex.size());
		kdecoder.MessageEnd();

		// Without one full block there is no IV: reject rather than read past the buffer.
		if (sEncrypted.size() < AES::BLOCKSIZE)
		{
			MMechostr(0,"_AESdecryptMessage Error : %s \n", "message too short to contain an IV");
			return MMpush(m, NIL);
		}

		// The first AES::BLOCKSIZE (16) bytes are the IV for CBC mode; the rest is the ciphertext.
		const byte* data = (const byte*)sEncrypted.data();
		CBC_Mode< AES >::Decryption decrypt;
		decrypt.SetKeyWithIV((const byte*)sKey.data(), sKey.size(), data);

		// The StreamTransformationFilter removes padding as required.
		StringSource source(data + AES::BLOCKSIZE, sEncrypted.size() - AES::BLOCKSIZE, true,
			new StreamTransformationFilter(decrypt, new StringSink(sMessage)));
	}
	catch (const std::exception& e)
	{
		MMechostr(0,"_AESdecryptMessage Error : %s \n", e.what());
		return MMpush(m, NIL);
	}

	// Build the Scol string by hand so embedded NUL bytes are preserved (Mpushstrbloc would stop at the first one).
	int res = MMmalloc(m, STR_SIZE(sMessage.size()), TYPEBUF);
	if (res == NIL)
		return MERRMEM;

	char* BS = MMstartstr(m, res);
	memcpy(BS, sMessage.data(), sMessage.size());
	BS[sMessage.size()] = 0;
	MMstore(m, res, 0, sMessage.size());
	int k = MMpush(m, PTOM(res));

#ifdef	_SCOL_DEBUG_
	MMechostr(0,"ok\n");
#endif

	return k;
}


// Package registration tables. TplName, TplFunc, TplNArg and TplType are parallel
// arrays passed together to PKhardpak() in LoadSecurity: entry i of each describes
// the same Scol function, so all four must list the functions in the same order.

//! Nb of Scol functions or types (length of each registration table below)
#define NbTplPKG	6


/*!
*	Scol function names, as visible from Scol code
*/
char	*TplName[NbTplPKG] =
{
	"_RSAgetKeyPair",
	"_RSAencryptMessage",
	"_RSAdecryptMessage",
	"_AESgetKey",
	"_AESencryptMessage",
	"_AESdecryptMessage"
};



/*!
*	Pointers to C functions that manipulate the VM for each scol function previously defined
*/
int (*TplFunc[NbTplPKG])(mmachine m)=
{
	_RSAgetKeyPair,												// _RSAgetKeyPair
	_RSAencryptMessage,										// _RSAencryptMessage
	_RSAdecryptMessage,										// _RSAdecryptMessage
	_AESgetKey,														// _AESgetKey
	_AESencryptMessage,										// _AESencryptMessage
	_AESdecryptMessage										// _AESdecryptMessage
};



/*!
*	Nb of arguments of each scol function (must match the argument count in TplType)
*/
int TplNArg[NbTplPKG]=
{
	1,														// _RSAgetKeyPair
	2,														// _RSAencryptMessage
	2,														// _RSAdecryptMessage
	0,														// _AESgetKey
	2,														// _AESencryptMessage
	2															// _AESdecryptMessage
};



/*!
*	Prototypes of the scol functions (Scol type signatures)
*/
char* TplType[NbTplPKG]=
{
	"fun [I] [S S]",										// _RSAgetKeyPair
	"fun [S S] S",											// _RSAencryptMessage
	"fun [S S] S",											// _RSAdecryptMessage
	"fun [] S",													// _AESgetKey
	"fun [S S] S",											// _AESencryptMessage
	"fun [S S] S"												// _AESdecryptMessage
};



// Everything inside _cond and _endcond is ignored by doxygen
//! \cond
/*!
* \brief Registers the "SecurityEngine" package, built from the TplName/TplFunc/TplNArg/TplType tables, in the Scol virtual machine
* \param m : the scol machine
* \return the status returned by PKhardpak
*/
int LoadSecurity(mmachine m)
{
	// All NbTplPKG entries of the parallel tables go into a single package.
	int status = PKhardpak(m, "SecurityEngine", NbTplPKG, TplName, TplFunc, TplNArg, TplType);

	MMechostr(MSKDEBUG, "\n");

	return status;
}
//! \endcond


/*! 
* \brief Starting point of the DLL
* Initializes the plugin with the VM callback table, then registers the SecurityEngine package.
* \param m : the Scol machine
* \param w : the Scol callback table, handed to SCOLinitplugin
* \return the status of LoadSecurity (0 on success; presumably a Scol error code otherwise)
*/
extern "C" __declspec (dllexport) int ScolLoadPlugin(mmachine m, cbmachine w)
{
	SCOLinitplugin(w);

	// Report package registration failures to the VM instead of always returning success.
	return LoadSecurity(m);
}


/*! 
* \brief Ending point of the DLL
* Presumably called by the Scol VM when the plugin is unloaded.
* Nothing is released here; it always returns 0.
*/
extern "C" __declspec (dllexport) int ScolUnloadPlugin()
{
	return 0;
}
