add master password

This commit is contained in:
j3d1 2019-03-28 11:49:59 +01:00
parent 21e6c2b407
commit 2bda15fc7c
12 changed files with 1160 additions and 234 deletions

View file

@ -1,21 +1,17 @@
#include <iostream>
#include <cstring>
#include "Database.h"
using namespace std;
bool Database::open(string filename) {
if (sqlite3_open(filename.c_str(), &database) == SQLITE_OK)
return true;
return false;
}
namespace shepherd{
vector<vector<string> > Database::query2(string query) {
sqlite3_stmt *statement;
vector<vector<string> > results;
if (!strcmp(query.c_str(), ""))
return results;
if (sqlite3_prepare_v2(database, query.c_str(), -1, &statement, 0)
if (sqlite3_prepare_v2(database.handle(), query.c_str(), -1, &statement, 0)
== SQLITE_OK) {
int cols = sqlite3_column_count(statement);
int result = 0;
@ -37,13 +33,11 @@ vector<vector<string> > Database::query2(string query) {
sqlite3_finalize(statement);
}
string error = sqlite3_errmsg(database);
string error = sqlite3_errmsg(database.handle());
if (error != "not an error")
cout << query << " " << error << endl;
return results;
}
void Database::close() {
sqlite3_close(database);
}
}

View file

@ -6,64 +6,65 @@
#include <string>
#include <vector>
#include <sqlite3.h>
#include "memdb.h"
using namespace std;
class Database {
namespace shepherd {
private:
sqlite3 *database;
class Database {
class QueryStream : public std::ostream {
private:
class QueryBuf : public std::stringbuf {
private:
Database *m_db;
memdb database;
class QueryStream : public std::ostream {
private:
class QueryBuf : public std::stringbuf {
private:
Database *m_db;
public:
QueryBuf(Database *db) {
m_db = db;
}
~QueryBuf() {
pubsync();
}
int sync() {
m_db->result = m_db->query2(str());
str("");
return 0;
}
};
public:
QueryStream(Database *db) :
std::ostream(new QueryBuf(db)) {
}
~QueryStream() {
delete rdbuf();
}
};
public:
QueryBuf(Database *db) {
m_db = db;
bool open(std::string filename);
std::vector <std::vector<std::string>> query2(std::string query);
void close();
std::vector <std::vector<std::string>> result;
QueryStream query;
Database(const std::string &filename, const std::string &secret) : database(filename, secret),query(this) {
};
~Database() {
}
~QueryBuf() {
pubsync();
}
int sync() {
m_db->result = m_db->query2(str());
str("");
return 0;
}
};
public:
QueryStream(Database *db) :
std::ostream(new QueryBuf(db)) {
}
~QueryStream() {
delete rdbuf();
}
};
public:
bool open(string filename);
vector<vector<string> > query2(string query);
void close();
vector<vector<string> > result;
QueryStream query;
Database(const std::string &filename) :
query(this) {
database = NULL;
open(filename);
};
~Database() {
}
};
}
#endif

View file

@ -10,130 +10,127 @@
#include "Manager.h"
using namespace std;
Manager::Manager(string file) :
db(file) {
db.query
<< "CREATE TABLE IF NOT EXISTS passwd (type varchar(20),user varchar(255),host varchar(255),passwd varchar(512));"
<< flush;
}
Manager::~Manager() {
db.close();
}
int Manager::add(string pattern, string passwd) {
std::cmatch sm;
if (regex_match(pattern.c_str(), sm,
std::regex(
"^([a-zA-Z0-9]+):([a-zA-Z0-9/_\\.\\-]+)@([a-zA-Z0-9/_\\.\\-]+)$"))) {
db.query << "INSERT INTO passwd (type, user, host, passwd) VALUES('"
<< sm[1] << "', '" << sm[2] << "','" << sm[3] << "','" << passwd
<< "');" << flush;
} else if (regex_match(pattern.c_str(), sm,
std::regex("^([a-zA-Z0-9]+):([a-zA-Z0-9/_\\.\\-]+)$"))) {
db.query << "INSERT INTO passwd (type, user, host, passwd) VALUES('"
<< sm[1] << "', '" << sm[2] << "','" << sm[2] << "','" << passwd
<< "');" << flush;
} else {
cerr << "invalid pattern: " << pattern << endl;
}
return 0;
}
int Manager::create(string pattern) {
char charset[] = "ABCDEFGHIJKLMOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_";
int len = strlen(charset);
string secret;
for (int i = 0; i < 20; i++) {
secret += charset[rand() % len];
}
cout << secret << endl;
add(pattern, secret);
return 0;
}
int Manager::show() {
db.query << "SELECT type, user, host, passwd FROM passwd;" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
}
int Manager::clear() {
db.query << "DELETE FROM passwd;" << flush;
return 0;
}
int Manager::get(string pattern) {
std::cmatch sm;
if (regex_match(pattern.c_str(), sm,
std::regex(
"^([a-zA-Z0-9%]+):([a-zA-Z0-9%/_\\.\\-]+)@([a-zA-Z0-9%/_\\.\\-]+)$"))) {
db.query << "SELECT * FROM passwd WHERE 1";
if (string("*").compare(sm[1]) != 0)
db.query << " AND type LIKE '" << sm[1] << "'";
if (string("*").compare(sm[2]) != 0)
db.query << " AND user LIKE '" << sm[2] << "'";
if (string("*").compare(sm[2]) != 0)
db.query << " AND host LIKE '" << sm[3] << "'";
db.query << ";" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
} else if (regex_match(pattern.c_str(), sm,
std::regex("^([a-zA-Z0-9]+):([a-zA-Z0-9/_\\.\\-]+)$"))) {
namespace shepherd {
Manager::Manager(string file, const std::string &secret) :
db(file, secret) {
db.query
<< "SELECT * FROM passwd WHERE (type, user, host, passwd) VALUES('"
<< sm[1] << "', '" << sm[2] << "','" << sm[2] << "');" << flush;
<< "CREATE TABLE IF NOT EXISTS passwd (type varchar(20),user varchar(255),host varchar(255),passwd varchar(512));"
<< flush;
}
int Manager::add(string pattern, string passwd) {
std::cmatch sm;
if (regex_match(pattern.c_str(), sm,
std::regex(
"^([a-zA-Z0-9]+):([a-zA-Z0-9/_\\.\\-]+)@([a-zA-Z0-9/_\\.\\-]+)$"))) {
db.query << "INSERT INTO passwd (type, user, host, passwd) VALUES('"
<< sm[1] << "', '" << sm[2] << "','" << sm[3] << "','" << passwd
<< "');" << flush;
} else if (regex_match(pattern.c_str(), sm,
std::regex("^([a-zA-Z0-9]+):([a-zA-Z0-9/_\\.\\-]+)$"))) {
db.query << "INSERT INTO passwd (type, user, host, passwd) VALUES('"
<< sm[1] << "', '" << sm[2] << "','" << sm[2] << "','" << passwd
<< "');" << flush;
} else {
cerr << "invalid pattern: " << pattern << endl;
}
return 0;
}
int Manager::create(string pattern) {
char charset[] = "ABCDEFGHIJKLMOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_";
int len = strlen(charset);
string secret;
for (int i = 0; i < 20; i++) {
secret += charset[rand() % len];
}
cout << secret << endl;
add(pattern, secret);
return 0;
}
int Manager::show() {
db.query << "SELECT type, user, host, passwd FROM passwd;" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
} else {
cout << "fehler: " << pattern << endl;
return 1;
}
int Manager::clear() {
db.query << "DELETE FROM passwd;" << flush;
return 0;
}
int Manager::get(string pattern) {
std::cmatch sm;
if (regex_match(pattern.c_str(), sm,
std::regex(
"^([a-zA-Z0-9%]+):([a-zA-Z0-9%/_\\.\\-]+)@([a-zA-Z0-9%/_\\.\\-]+)$"))) {
db.query << "SELECT * FROM passwd WHERE 1";
if (string("*").compare(sm[1]) != 0)
db.query << " AND type LIKE '" << sm[1] << "'";
if (string("*").compare(sm[2]) != 0)
db.query << " AND user LIKE '" << sm[2] << "'";
if (string("*").compare(sm[2]) != 0)
db.query << " AND host LIKE '" << sm[3] << "'";
db.query << ";" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
} else if (regex_match(pattern.c_str(), sm,
std::regex("^([a-zA-Z0-9]+):([a-zA-Z0-9/_\\.\\-]+)$"))) {
db.query
<< "SELECT * FROM passwd WHERE (type, user, host, passwd) VALUES('"
<< sm[1] << "', '" << sm[2] << "','" << sm[2] << "');" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
} else {
cout << "fehler: " << pattern << endl;
return 1;
}
}
int Manager::del(string pattern) {
std::cmatch sm;
if (regex_match(pattern.c_str(), sm,
std::regex(
"^([a-zA-Z0-9%]+):([a-zA-Z0-9%/_\\.\\-]+)@([a-zA-Z0-9%/_\\.\\-]+)$"))) {
db.query << "DELETE FROM passwd WHERE 1";
if (string("*").compare(sm[1]) != 0)
db.query << " AND type LIKE '" << sm[1] << "'";
if (string("*").compare(sm[2]) != 0)
db.query << " AND user LIKE '" << sm[2] << "'";
if (string("*").compare(sm[2]) != 0)
db.query << " AND host LIKE '" << sm[3] << "'";
db.query << ";" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
} else if (regex_match(pattern.c_str(), sm,
std::regex("^([a-zA-Z0-9]+):([a-zA-Z0-9/_\\.\\-]+)$"))) {
db.query
<< "DELETE FROM passwd WHERE (type, user, host, passwd) VALUES('"
<< sm[1] << "', '" << sm[2] << "','" << sm[2] << "');" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
} else {
cout << "fehler: " << pattern << endl;
return 1;
}
}
}
int Manager::del(string pattern) {
std::cmatch sm;
if (regex_match(pattern.c_str(), sm,
std::regex(
"^([a-zA-Z0-9%]+):([a-zA-Z0-9%/_\\.\\-]+)@([a-zA-Z0-9%/_\\.\\-]+)$"))) {
db.query << "DELETE FROM passwd WHERE 1";
if (string("*").compare(sm[1]) != 0)
db.query << " AND type LIKE '" << sm[1] << "'";
if (string("*").compare(sm[2]) != 0)
db.query << " AND user LIKE '" << sm[2] << "'";
if (string("*").compare(sm[2]) != 0)
db.query << " AND host LIKE '" << sm[3] << "'";
db.query << ";" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
} else if (regex_match(pattern.c_str(), sm,
std::regex("^([a-zA-Z0-9]+):([a-zA-Z0-9/_\\.\\-]+)$"))) {
db.query
<< "DELETE FROM passwd WHERE (type, user, host, passwd) VALUES('"
<< sm[1] << "', '" << sm[2] << "','" << sm[2] << "');" << flush;
for (vector<string> row : db.result) {
cout << row.at(0) << ":" << row.at(1) << "@" << row.at(2) << "\t"
<< row.at(3) << endl;
}
return 0;
} else {
cout << "fehler: " << pattern << endl;
return 1;
}
}

View file

@ -12,26 +12,25 @@
#include <string>
#include "Database.h"
class Manager {
public:
Manager(std::string);
namespace shepherd {
class Manager {
public:
Manager(std::string, const std::string &secret);
~Manager();
int add(std::string pattern, std::string passwd);
int add(std::string pattern, std::string passwd);
int create(std::string pattern);
int create(std::string pattern);
int show();
int show();
int clear();
int clear();
int get(std::string pattern);
int get(std::string pattern);
int del(std::string pattern);
private:
Database db;
};
int del(std::string pattern);
private:
Database db;
};
}
#endif /* MANAGER_H_ */

191
src/crypto.cpp Normal file
View file

@ -0,0 +1,191 @@
//
// Created by jedi on 3/16/19.
//
#include <fstream>
#include <iostream>
#include "crypto.h"
#include <unistd.h>
#include <termios.h>
#define CONTEXT "SHEPHERD"
namespace shepherd::crypto {
using namespace std;
static char getch() {
char buf = 0;
struct termios old = {0};
if (tcgetattr(0, &old) < 0)
perror("tcsetattr()");
old.c_lflag &= ~ICANON;
old.c_lflag &= ~ECHO;
old.c_cc[VMIN] = 1;
old.c_cc[VTIME] = 0;
if (tcsetattr(0, TCSANOW, &old) < 0)
perror("tcsetattr ICANON");
if (read(0, &buf, 1) < 0)
perror ("read()");
old.c_lflag |= ICANON;
old.c_lflag |= ECHO;
if (tcsetattr(0, TCSADRAIN, &old) < 0)
perror ("tcsetattr ~ICANON");
return (buf);
}
std::string read_pw()
{
std::string buf;
char ch=1;
while(ch)
{
ch = getch();
// Enter
if(ch == '\n' || ch == '\r')
break;
// Backspace
else if(ch == 127 && buf.size()>0)
{
buf.pop_back();
}
// Next char
else if(isprint(ch, cout.getloc()))
{
buf.push_back(ch);
}
//else{
// cout << (int)ch << endl;
//}
}
return buf;
}
static kdf_t kdf(std::string secret, herd_file_t *file) {
kdf_t ret;
//master key
if (crypto_pwhash
(ret.master_key, sizeof ret.master_key, secret.c_str(), secret.size(), file->salt,
crypto_pwhash_OPSLIMIT_INTERACTIVE, crypto_pwhash_MEMLIMIT_INTERACTIVE,
crypto_pwhash_ALG_DEFAULT) != 0) {
/* out of memory */
std::cerr << "errör" << std::endl;
exit(1);
}
//used key
crypto_kdf_derive_from_key(ret.key, sizeof ret.key, 2, CONTEXT, ret.master_key);
/*printf("salt:\t");
for (auto e: file->salt) {
printf("%02x", e);
}
printf("\n");
printf("master key:\t");
for (auto e: ret.master_key) {
printf("%02x", e);
}
printf("\n");
printf("key:\t");
for (auto e: ret.key) {
printf("%02x", e);
}
printf("\n");*/
return ret;
}
void save(const std::string filename, bytes buf) {
std::ofstream ofs(filename, std::ios::out | std::ios::binary);
ofs.write((char *) buf.data(), buf.size());
ofs.close();
}
void save(const std::string filename, bytes buf, const std::string secret) {
herd_file_t *file;
file = (herd_file_t *) malloc(sizeof(herd_file_t) + buf.size() + crypto_aead_xchacha20poly1305_ietf_ABYTES);
file->version = 2;
randombytes_buf(file->salt, sizeof file->salt);
kdf_t pw_key = kdf(secret,file);
randombytes_buf(file->nonce, sizeof file->nonce);
crypto_aead_xchacha20poly1305_ietf_encrypt(file->ciphertext, &file->ciphertext_len,
buf.data(), buf.size(),
NULL, 0,
NULL, file->nonce, pw_key.key);
unsigned char *raw = (unsigned char *) file;
size_t len = sizeof(herd_file_t) + file->ciphertext_len;
bytes blob(raw, &raw[len]);
save(filename, blob);
free(file);
}
bytes load(const std::string filename) {
std::ifstream ifs(filename, std::ios::binary | std::ios::ate);
std::ifstream::pos_type pos = ifs.tellg();
bytes result(pos);
ifs.seekg(0, std::ios::beg);
ifs.read((char *) result.data(), pos);
return result;
}
bytes load(const std::string filename, const std::string secret) {
bytes buf = load(filename);
if (buf.size() <= sizeof(herd_file_t)) {
std::cerr << "buf.size()<= sizeof(herd_file_t)" << std::endl;
exit(1);
}
herd_file_t *file = (herd_file_t *) buf.data();
if (buf.size() < sizeof(herd_file_t) + file->ciphertext_len) {
std::cerr << "buf.size()< sizeof(herd_file_t)+file->ciphertext_len" << std::endl;
std::cerr << buf.size() << "< " << sizeof(herd_file_t) + file->ciphertext_len << std::endl;
exit(1);
}
std::cout << "version: " << file->version << std::endl;
kdf_t pw_key = kdf(secret, file);
bytes decrypted(file->ciphertext_len);
unsigned long long decrypted_len;
if (crypto_aead_xchacha20poly1305_ietf_decrypt(decrypted.data(), &decrypted_len,
NULL,
file->ciphertext, file->ciphertext_len,
NULL, 0,
file->nonce, pw_key.key) != 0) {
/* message forged! */
std::cerr << "errör" << std::endl;
exit(1);
}
decrypted.resize(decrypted_len);
return decrypted;
}
}

39
src/crypto.h Normal file
View file

@ -0,0 +1,39 @@
//
// Created by jedi on 3/16/19.
//
#include <sodium.h>
#include <vector>
#ifndef SHEPHERD_CRYPTO_H
#define SHEPHERD_CRYPTO_H
typedef std::vector<uint8_t> bytes;
namespace shepherd::crypto{
typedef struct {
uint8_t master_key[crypto_kdf_KEYBYTES];
unsigned char key[crypto_aead_xchacha20poly1305_ietf_KEYBYTES];
} kdf_t;
typedef struct{
uint32_t version;
unsigned char nonce[crypto_aead_xchacha20poly1305_ietf_NPUBBYTES];
unsigned char salt[crypto_pwhash_SALTBYTES];
unsigned long long ciphertext_len;
unsigned char ciphertext[];
} herd_file_t;
std::string read_pw();
void save(const std::string filename, bytes buf, const std::string secret);
void save(const std::string filename, bytes buf);
bytes load(const std::string filename, const std::string secret);
bytes load(const std::string filename);
}
#endif //SHEPHERD_CRYPTO_H

5
src/memdb.cpp Normal file
View file

@ -0,0 +1,5 @@
//
// Created by jedi on 3/12/19.
//
#include "memdb.h"

60
src/memdb.h Normal file
View file

@ -0,0 +1,60 @@
//
// Created by jedi on 3/12/19.
//
#include <string>
#include <sys/stat.h>
#include <iostream>
#include <filesystem>
#include "spmemvfs.h"
#include "crypto.h"
#ifndef SHEPHERD_MEMDB_H
#define SHEPHERD_MEMDB_H
namespace shepherd{
class memdb{
private:
spmemvfs_db_t db_;
spmembuffer_t * mem_;
bytes buf_;
std::string path_;
std::string secret_;
void load(){
if(std::filesystem::exists(path_)){
buf_ = crypto::load(path_,secret_);
if(buf_.size()==0){
std::cerr << "no dbfile found" << std::endl;
exit(1);
}
}else{
std::cerr << "Error: No such file" << std::endl;
buf_ = crypto::load("/home/jedi/.shepherd/passwd.db");
}
mem_->total = mem_->used = buf_.size();
mem_->data = (char*) buf_.data();
}
public:
void save(){
bytes buf(mem_->data, &mem_->data[mem_->used]);
crypto::save(path_,buf,secret_);
}
memdb(const std::string &path, const std::string &secret):path_(path),secret_(secret){
mem_ = (spmembuffer_t*)calloc( sizeof( spmembuffer_t ), 1 );
spmemvfs_env_init();
load();
spmemvfs_open_db( &db_, path.c_str(), mem_ );
}
sqlite3* handle(){
return db_.handle;
}
~memdb(){
save();
spmemvfs_close_db( &db_ );
spmemvfs_env_fini();
}
};
}
#endif //SHEPHERD_MEMDB_H

View file

@ -4,6 +4,8 @@
#include <getopt.h>
#include "Manager.h"
#include <filesystem>
#define no_argument 0
#define required_argument 1
#define optional_argument 2
@ -14,31 +16,27 @@ constexpr unsigned int arg(const char* str, int h = 0) {
return !str[h] ? 5381 : (arg(str, h + 1) * 33) ^ str[h];
}
using namespace shepherd;
int main(int argc, char *argv[]) {
srand(time(NULL));
int opt, index;
string file = getenv("HOME");
file += "/.shepherd/passwd.db";
file += "/.shepherd/passwd.herd";
const struct option longopts[] = {
{"version", no_argument, 0, 'v'},
{
"help", no_argument, 0, 'h'},
{"help", no_argument, 0, 'h'},
{"reverse", no_argument, 0, 'r'},
{
"permanent", no_argument, 0, 's'},
{"permanent", no_argument, 0, 's'},
{"debug", optional_argument, 0, 'd'},
{
"password", required_argument, 0, 'p'},
{"password", required_argument, 0, 'p'},
{"user", required_argument, 0, 'u'},
{
"interface", required_argument, 0, 'i'},
{"interface", required_argument, 0, 'i'},
{"header", required_argument, 0, 'b'},
{
"hex", required_argument, 0, 'x'},
{"hex", required_argument, 0, 'x'},
{"file", required_argument, 0, 'f'},
{
"timeout", required_argument, 0, 't'},
{"timeout", required_argument, 0, 't'},
{"wait", required_argument, 0, 'w'},
{
0, 0, 0, 0},};
@ -84,35 +82,58 @@ int main(int argc, char *argv[]) {
}
}
Manager mg(file);
switch (argc - optind) {
case 0:
mg.show();
break;
std::cout << "Password: " << std::flush;
case 1:
mg.get(argv[optind]);
break;
case 2:
if (!strcmp(argv[optind], "clear")) {
if (!strcmp(argv[optind + 1], "all")) {
mg.clear();
cerr << file << " cleared" << endl;
std::string pw = crypto::read_pw();
std::cout << std::endl;
//char c;
//while (std::cin.get(c)) // loop getting single characters
// std::cout << c;
//std::cin.close();
/*if(!std::filesystem::exists("~/.shepherd/passwd.herd")){
std::cerr << "default file not found" << std::endl;
}*/
//shepherd::memdb db(file,"foo");
//db.save();
//exit(0);
{
Manager mg(file, pw);
switch (argc - optind) {
case 0:
mg.show();
break;
case 1:
mg.get(argv[optind]);
break;
case 2:
if (!strcmp(argv[optind], "clear")) {
if (!strcmp(argv[optind + 1], "all")) {
mg.clear();
cerr << file << " cleared" << endl;
}
} else if (!strcmp(argv[optind], "gen")) {
mg.create(argv[optind + 1]);
} else if (!strcmp(argv[optind], "del")) {
mg.del(argv[optind + 1]);
} else if (!strcmp(argv[optind], "add")) {
string pw;
cout << argv[optind + 1] << ": " << flush;
cin >> pw;
mg.add(argv[optind + 1], pw);
} else {
cerr << argv[0] << " <account identifier>" << endl;
}
} else if (!strcmp(argv[optind], "gen")) {
mg.create(argv[optind + 1]);
} else if (!strcmp(argv[optind], "del")) {
mg.del(argv[optind + 1]);
} else if (!strcmp(argv[optind], "add")) {
string pw;
cout << argv[optind + 1] << ": " << flush;
cin >> pw;
mg.add(argv[optind + 1], pw);
} else {
cerr << argv[0] << " <account identifier>" << endl;
}
break;
break;
}
}
exit (EXIT_SUCCESS);
}

8
src/shepherd.h Normal file
View file

@ -0,0 +1,8 @@
//
// Created by jedi on 3/12/19.
//
#ifndef SHEPHERD_SHEPHERD_H
#define SHEPHERD_SHEPHERD_H
#endif //SHEPHERD_SHEPHERD_H

547
src/spmemvfs.c Normal file
View file

@ -0,0 +1,547 @@
/*
* BSD 2-Clause License
*
* Copyright 2009 Stephen Liu
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include "spmemvfs.h"
#include "sqlite3.h"
/* Useful macros used in several places */
#define SPMEMVFS_MIN(x,y) ((x)<(y)?(x):(y))
#define SPMEMVFS_MAX(x,y) ((x)>(y)?(x):(y))
static void spmemvfsDebug(const char *format, ...){
#if defined(SPMEMVFS_DEBUG)
char logTemp[ 1024 ] = { 0 };
va_list vaList;
va_start( vaList, format );
vsnprintf( logTemp, sizeof( logTemp ), format, vaList );
va_end ( vaList );
if( strchr( logTemp, '\n' ) ) {
printf( "%s", logTemp );
} else {
printf( "%s\n", logTemp );
}
#endif
}
//===========================================================================
typedef struct spmemfile_t {
sqlite3_file base;
char * path;
int flags;
spmembuffer_t * mem;
} spmemfile_t;
static int spmemfileClose( sqlite3_file * file );
static int spmemfileRead( sqlite3_file * file, void * buffer, int len, sqlite3_int64 offset );
static int spmemfileWrite( sqlite3_file * file, const void * buffer, int len, sqlite3_int64 offset );
static int spmemfileTruncate( sqlite3_file * file, sqlite3_int64 size );
static int spmemfileSync( sqlite3_file * file, int flags );
static int spmemfileFileSize( sqlite3_file * file, sqlite3_int64 * size );
static int spmemfileLock( sqlite3_file * file, int type );
static int spmemfileUnlock( sqlite3_file * file, int type );
static int spmemfileCheckReservedLock( sqlite3_file * file, int * result );
static int spmemfileFileControl( sqlite3_file * file, int op, void * arg );
static int spmemfileSectorSize( sqlite3_file * file );
static int spmemfileDeviceCharacteristics( sqlite3_file * file );
static sqlite3_io_methods g_spmemfile_io_memthods = {
1, /* iVersion */
spmemfileClose, /* xClose */
spmemfileRead, /* xRead */
spmemfileWrite, /* xWrite */
spmemfileTruncate, /* xTruncate */
spmemfileSync, /* xSync */
spmemfileFileSize, /* xFileSize */
spmemfileLock, /* xLock */
spmemfileUnlock, /* xUnlock */
spmemfileCheckReservedLock, /* xCheckReservedLock */
spmemfileFileControl, /* xFileControl */
spmemfileSectorSize, /* xSectorSize */
spmemfileDeviceCharacteristics /* xDeviceCharacteristics */
};
int spmemfileClose( sqlite3_file * file )
{
spmemfile_t * memfile = (spmemfile_t*)file;
spmemvfsDebug( "call %s( %p )", __func__, memfile );
if( SQLITE_OPEN_MAIN_DB & memfile->flags ) {
// noop
} else {
if( NULL != memfile->mem ) {
if( memfile->mem->data ) free( memfile->mem->data );
free( memfile->mem );
}
}
free( memfile->path );
return SQLITE_OK;
}
int spmemfileRead( sqlite3_file * file, void * buffer, int len, sqlite3_int64 offset )
{
spmemfile_t * memfile = (spmemfile_t*)file;
spmemvfsDebug( "call %s( %p, ..., %d, %lld ), len %d",
__func__, memfile, len, offset, memfile->mem->used );
if( ( offset + len ) > memfile->mem->used ) {
return SQLITE_IOERR_SHORT_READ;
}
memcpy( buffer, memfile->mem->data + offset, len );
return SQLITE_OK;
}
int spmemfileWrite( sqlite3_file * file, const void * buffer, int len, sqlite3_int64 offset )
{
spmemfile_t * memfile = (spmemfile_t*)file;
spmembuffer_t * mem = memfile->mem;
spmemvfsDebug( "call %s( %p, ..., %d, %lld ), len %d",
__func__, memfile, len, offset, mem->used );
if( ( offset + len ) > mem->total ) {
int newTotal = 2 * ( offset + len + mem->total );
char * newBuffer = (char*)realloc( mem->data, newTotal );
if( NULL == newBuffer ) {
return SQLITE_NOMEM;
}
mem->total = newTotal;
mem->data = newBuffer;
}
memcpy( mem->data + offset, buffer, len );
mem->used = SPMEMVFS_MAX( mem->used, offset + len );
return SQLITE_OK;
}
int spmemfileTruncate( sqlite3_file * file, sqlite3_int64 size )
{
spmemfile_t * memfile = (spmemfile_t*)file;
spmemvfsDebug( "call %s( %p )", __func__, memfile );
memfile->mem->used = SPMEMVFS_MIN( memfile->mem->used, size );
return SQLITE_OK;
}
int spmemfileSync( sqlite3_file * file, int flags )
{
spmemvfsDebug( "call %s( %p )", __func__, file );
return SQLITE_OK;
}
int spmemfileFileSize( sqlite3_file * file, sqlite3_int64 * size )
{
spmemfile_t * memfile = (spmemfile_t*)file;
spmemvfsDebug( "call %s( %p )", __func__, memfile );
* size = memfile->mem->used;
return SQLITE_OK;
}
int spmemfileLock( sqlite3_file * file, int type )
{
spmemvfsDebug( "call %s( %p )", __func__, file );
return SQLITE_OK;
}
int spmemfileUnlock( sqlite3_file * file, int type )
{
spmemvfsDebug( "call %s( %p )", __func__, file );
return SQLITE_OK;
}
int spmemfileCheckReservedLock( sqlite3_file * file, int * result )
{
spmemvfsDebug( "call %s( %p )", __func__, file );
*result = 0;
return SQLITE_OK;
}
int spmemfileFileControl( sqlite3_file * file, int op, void * arg )
{
spmemvfsDebug( "call %s( %p )", __func__, file );
return SQLITE_OK;
}
int spmemfileSectorSize( sqlite3_file * file )
{
spmemvfsDebug( "call %s( %p )", __func__, file );
return 0;
}
int spmemfileDeviceCharacteristics( sqlite3_file * file )
{
spmemvfsDebug( "call %s( %p )", __func__, file );
return 0;
}
//===========================================================================
typedef struct spmemvfs_cb_t {
void * arg;
spmembuffer_t * ( * load ) ( void * args, const char * path );
} spmemvfs_cb_t;
typedef struct spmemvfs_t {
sqlite3_vfs base;
spmemvfs_cb_t cb;
sqlite3_vfs * parent;
} spmemvfs_t;
static int spmemvfsOpen( sqlite3_vfs * vfs, const char * path, sqlite3_file * file, int flags, int * outflags );
static int spmemvfsDelete( sqlite3_vfs * vfs, const char * path, int syncDir );
static int spmemvfsAccess( sqlite3_vfs * vfs, const char * path, int flags, int * result );
static int spmemvfsFullPathname( sqlite3_vfs * vfs, const char * path, int len, char * fullpath );
static void * spmemvfsDlOpen( sqlite3_vfs * vfs, const char * path );
static void spmemvfsDlError( sqlite3_vfs * vfs, int len, char * errmsg );
static void ( * spmemvfsDlSym ( sqlite3_vfs * vfs, void * handle, const char * symbol ) ) ( void );
static void spmemvfsDlClose( sqlite3_vfs * vfs, void * handle );
static int spmemvfsRandomness( sqlite3_vfs * vfs, int len, char * buffer );
static int spmemvfsSleep( sqlite3_vfs * vfs, int microseconds );
static int spmemvfsCurrentTime( sqlite3_vfs * vfs, double * result );
static spmemvfs_t g_spmemvfs = {
{
1, /* iVersion */
0, /* szOsFile */
0, /* mxPathname */
0, /* pNext */
SPMEMVFS_NAME, /* zName */
0, /* pAppData */
spmemvfsOpen, /* xOpen */
spmemvfsDelete, /* xDelete */
spmemvfsAccess, /* xAccess */
spmemvfsFullPathname, /* xFullPathname */
spmemvfsDlOpen, /* xDlOpen */
spmemvfsDlError, /* xDlError */
spmemvfsDlSym, /* xDlSym */
spmemvfsDlClose, /* xDlClose */
spmemvfsRandomness, /* xRandomness */
spmemvfsSleep, /* xSleep */
spmemvfsCurrentTime /* xCurrentTime */
},
{ 0 },
0 /* pParent */
};
int spmemvfsOpen( sqlite3_vfs * vfs, const char * path, sqlite3_file * file, int flags, int * outflags )
{
spmemvfs_t * memvfs = (spmemvfs_t*)vfs;
spmemfile_t * memfile = (spmemfile_t*)file;
spmemvfsDebug( "call %s( %p(%p), %s, %p, %x, %p )\n",
__func__, vfs, &g_spmemvfs, path, file, flags, outflags );
memset( memfile, 0, sizeof( spmemfile_t ) );
memfile->base.pMethods = &g_spmemfile_io_memthods;
memfile->flags = flags;
memfile->path = strdup( path );
if( SQLITE_OPEN_MAIN_DB & memfile->flags ) {
memfile->mem = memvfs->cb.load( memvfs->cb.arg, path );
} else {
memfile->mem = (spmembuffer_t*)calloc( sizeof( spmembuffer_t ), 1 );
}
return memfile->mem ? SQLITE_OK : SQLITE_ERROR;
}
int spmemvfsDelete( sqlite3_vfs * vfs, const char * path, int syncDir )
{
spmemvfsDebug( "call %s( %p(%p), %s, %d )\n",
__func__, vfs, &g_spmemvfs, path, syncDir );
return SQLITE_OK;
}
int spmemvfsAccess( sqlite3_vfs * vfs, const char * path, int flags, int * result )
{
* result = 0;
return SQLITE_OK;
}
int spmemvfsFullPathname( sqlite3_vfs * vfs, const char * path, int len, char * fullpath )
{
strncpy( fullpath, path, len );
fullpath[ len - 1 ] = '\0';
return SQLITE_OK;
}
void * spmemvfsDlOpen( sqlite3_vfs * vfs, const char * path )
{
return NULL;
}
void spmemvfsDlError( sqlite3_vfs * vfs, int len, char * errmsg )
{
// noop
}
void ( * spmemvfsDlSym ( sqlite3_vfs * vfs, void * handle, const char * symbol ) ) ( void )
{
return NULL;
}
void spmemvfsDlClose( sqlite3_vfs * vfs, void * handle )
{
// noop
}
int spmemvfsRandomness( sqlite3_vfs * vfs, int len, char * buffer )
{
return SQLITE_OK;
}
int spmemvfsSleep( sqlite3_vfs * vfs, int microseconds )
{
return SQLITE_OK;
}
int spmemvfsCurrentTime( sqlite3_vfs * vfs, double * result )
{
return SQLITE_OK;
}
//===========================================================================
int spmemvfs_init( spmemvfs_cb_t * cb )
{
sqlite3_vfs * parent = NULL;
if( g_spmemvfs.parent ) return SQLITE_OK;
parent = sqlite3_vfs_find( 0 );
g_spmemvfs.parent = parent;
g_spmemvfs.base.mxPathname = parent->mxPathname;
g_spmemvfs.base.szOsFile = sizeof( spmemfile_t );
g_spmemvfs.cb = * cb;
return sqlite3_vfs_register( (sqlite3_vfs*)&g_spmemvfs, 0 );
}
//===========================================================================
typedef struct spmembuffer_link_t {
char * path;
spmembuffer_t * mem;
struct spmembuffer_link_t * next;
} spmembuffer_link_t;
spmembuffer_link_t * spmembuffer_link_remove( spmembuffer_link_t ** head, const char * path )
{
spmembuffer_link_t * ret = NULL;
spmembuffer_link_t ** iter = head;
for( ; NULL != *iter; ) {
spmembuffer_link_t * curr = *iter;
if( 0 == strcmp( path, curr->path ) ) {
ret = curr;
*iter = curr->next;
break;
} else {
iter = &( curr->next );
}
}
return ret;
}
void spmembuffer_link_free( spmembuffer_link_t * iter )
{
free( iter->path );
free( iter->mem->data );
free( iter->mem );
free( iter );
}
//===========================================================================
typedef struct spmemvfs_env_t {
spmembuffer_link_t * head;
sqlite3_mutex * mutex;
} spmemvfs_env_t;
static spmemvfs_env_t * g_spmemvfs_env = NULL;
static spmembuffer_t * load_cb( void * arg, const char * path )
{
spmembuffer_t * ret = NULL;
spmemvfs_env_t * env = (spmemvfs_env_t*)arg;
sqlite3_mutex_enter( env->mutex );
{
spmembuffer_link_t * toFind = spmembuffer_link_remove( &( env->head ), path );
if( NULL != toFind ) {
ret = toFind->mem;
free( toFind->path );
free( toFind );
}
}
sqlite3_mutex_leave( env->mutex );
return ret;
}
int spmemvfs_env_init()
{
int ret = 0;
if( NULL == g_spmemvfs_env ) {
spmemvfs_cb_t cb;
g_spmemvfs_env = (spmemvfs_env_t*)calloc( sizeof( spmemvfs_env_t ), 1 );
g_spmemvfs_env->mutex = sqlite3_mutex_alloc( SQLITE_MUTEX_FAST );
cb.arg = g_spmemvfs_env;
cb.load = load_cb;
ret = spmemvfs_init( &cb );
}
return ret;
}
void spmemvfs_env_fini()
{
if( NULL != g_spmemvfs_env ) {
spmembuffer_link_t * iter = NULL;
sqlite3_vfs_unregister( (sqlite3_vfs*)&g_spmemvfs );
g_spmemvfs.parent = NULL;
sqlite3_mutex_free( g_spmemvfs_env->mutex );
iter = g_spmemvfs_env->head;
for( ; NULL != iter; ) {
spmembuffer_link_t * next = iter->next;
spmembuffer_link_free( iter );
iter = next;
}
free( g_spmemvfs_env );
g_spmemvfs_env = NULL;
}
}
int spmemvfs_open_db( spmemvfs_db_t * db, const char * path, spmembuffer_t * mem )
{
int ret = 0;
spmembuffer_link_t * iter = NULL;
memset( db, 0, sizeof( spmemvfs_db_t ) );
iter = (spmembuffer_link_t*)calloc( sizeof( spmembuffer_link_t ), 1 );
iter->path = strdup( path );
iter->mem = mem;
sqlite3_mutex_enter( g_spmemvfs_env->mutex );
{
iter->next = g_spmemvfs_env->head;
g_spmemvfs_env->head = iter;
}
sqlite3_mutex_leave( g_spmemvfs_env->mutex );
ret = sqlite3_open_v2( path, &(db->handle),
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, SPMEMVFS_NAME );
if( 0 == ret ) {
db->mem = mem;
} else {
sqlite3_mutex_enter( g_spmemvfs_env->mutex );
{
iter = spmembuffer_link_remove( &(g_spmemvfs_env->head), path );
if( NULL != iter ) spmembuffer_link_free( iter );
}
sqlite3_mutex_leave( g_spmemvfs_env->mutex );
}
return ret;
}
int spmemvfs_close_db( spmemvfs_db_t * db )
{
int ret = 0;
if( NULL == db ) return 0;
if( NULL != db->handle ) {
ret = sqlite3_close( db->handle );
db->handle = NULL;
}
if( NULL != db->mem ) {
//if( NULL != db->mem->data ) free( db->mem->data );
free( db->mem );
db->mem = NULL;
}
return ret;
}

64
src/spmemvfs.h Normal file
View file

@ -0,0 +1,64 @@
/*
* BSD 2-Clause License
*
* Copyright 2009 Stephen Liu
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef __spmemvfs_h__
#define __spmemvfs_h__
#ifdef __cplusplus
extern "C" {
#endif
#include "sqlite3.h"
#define SPMEMVFS_NAME "spmemvfs"
typedef struct spmembuffer_t {
char * data;
int used;
int total;
} spmembuffer_t;
typedef struct spmemvfs_db_t {
sqlite3 * handle;
spmembuffer_t * mem;
} spmemvfs_db_t;
int spmemvfs_env_init();
void spmemvfs_env_fini();
int spmemvfs_open_db( spmemvfs_db_t * db, const char * path, spmembuffer_t * mem );
int spmemvfs_close_db( spmemvfs_db_t * db );
#ifdef __cplusplus
}
#endif
#endif