 /* This program is free software. It comes without any warranty, to
  * the extent permitted by applicable law. You can redistribute it
  * and/or modify it under the terms of the Do What The Fuck You Want
  * To Public License, Version 2, as published by Sam Hocevar. See
  * http://sam.zoy.org/wtfpl/COPYING for more details. */

// TODO: add command-line switches for the action to take on the new image (save/load), the verbosity level, and so on.
#include <elf.h>
#include <sys/mman.h>
#include <unistd.h>
#include <sys/types.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <memory.h>
#include <errno.h>
#include <sys/syscall.h>
/* Report errno via perror(), store x in the caller's `ret` variable and jump
 * to the caller's `out:` cleanup label. Only usable inside a function that
 * declares both `ret` and `out`. */
#define ERR_RETURN(x) do { perror(NULL); ret = x; goto out; } while (0)
/* Convert a byte offset into an index into an `unsigned long` array.
 * The shift by 2 assumes 4-byte words, i.e. sizeof(unsigned long) == 4
 * (a 32-bit build); on an LP64 host the indexing would be wrong. */
#define WORD(x) ((x)>>2)

/* Length in bytes of a SHA-1 digest. */
#define SHA_DIGEST_LENGTH 20

/* SHA-1 context and API. These are only declared here; the implementation
 * is presumably provided by another translation unit at link time — TODO
 * confirm. */
typedef struct {
    unsigned long state[5];
    unsigned long count[2];
    unsigned char buffer[64];
} SHA1_CTX;
void SHA1Init(SHA1_CTX* context);
void SHA1Update(SHA1_CTX* context, const unsigned char* data, unsigned int len);
void SHA1Final(unsigned char digest[20], SHA1_CTX* context);

typedef struct ret_description
{
	unsigned long istack;
	unsigned long iregs;
} ret_desc;

/* Names of SHF_ALLOC sections that calculate_mot_hash() leaves out of the
 * module hash even though they are allocated. NULL-terminated. */
const char *excluded_sections[] = {
	".gnu.linkonce.this_module",
	".data.percpu",
	".modinfo",
	NULL
};

/* Return the size in bytes of the file open on fd, or 0 when it cannot be
 * determined (for example, an unseekable descriptor). As a side effect the
 * file offset is rewound to the start of the file. */
size_t filesize(int fd)
{
	const off_t end_pos = lseek(fd, 0, SEEK_END);

	lseek(fd, 0, SEEK_SET);
	return (end_pos == (off_t)-1) ? 0 : (size_t)end_pos;
}

/* Look up a section by name.
 *
 * Scans sections[1..shnum) (section 0 is the ELF null section) and returns
 * the index of the first one whose name, taken from the string table
 * secnames, equals name. When look_alloc is non-zero, sections without
 * SHF_ALLOC are skipped. Returns 0 when nothing matches. */
int find_section(const Elf32_Shdr *sections, int shnum, const char *secnames, const char *name, int look_alloc)
{
	const Elf32_Shdr *cur = sections + 1;
	int idx;

	for (idx = 1; idx < shnum; ++idx, ++cur) {
		if (look_alloc && !(cur->sh_flags & SHF_ALLOC))
			continue;
		if (strcmp(secnames + cur->sh_name, name) == 0)
			return idx;
	}
	return 0;
}

/* Calculate the SHA-1 hash of a module image. This could be done more
 * simply, but the code is meant to behave like the kernel's checking
 * routine, so the section selection and the hashing order are kept as-is.
 *
 * Hashed, in section-index order:
 *   - every SHF_ALLOC section that is not SHT_NOBITS, except those listed
 *     in excluded_sections[];
 *   - the last SHT_SYMTAB section and the section its sh_link refers to.
 *
 * hash:    receives SHA_DIGEST_LENGTH bytes.
 * elf_hdr: start of the in-memory ELF32 image.
 * Returns 0 on success, 1 if the header is unusable (no sections, or
 * e_shstrndx out of range) or the allocation fails; hash is untouched then.
 */
int calculate_mot_hash(char hash[SHA_DIGEST_LENGTH], const Elf32_Ehdr *elf_hdr)
{
	int i, shnum;
	const Elf32_Shdr *sections;
	const char *secnames;
	const char **pname;
	SHA1_CTX ctx;
	char *hashmap; /* hashmap[i] != 0: include section i; 0: exclude it */

	shnum = elf_hdr->e_shnum;
	/* hashmap[0] and sections[e_shstrndx] are accessed below. */
	if (shnum == 0 || elf_hdr->e_shstrndx >= shnum)
		return 1;
	hashmap = calloc((size_t)shnum, sizeof(*hashmap)); /* hashmap[0] stays 0 */
	if (hashmap == NULL)
		return 1;
	sections = (const Elf32_Shdr *)((const char *)elf_hdr + elf_hdr->e_shoff);

	for (i = 1; i < shnum; ++i)
		hashmap[i] = (sections[i].sh_flags & SHF_ALLOC) && (sections[i].sh_type != SHT_NOBITS);

	secnames = (const char *)elf_hdr + sections[elf_hdr->e_shstrndx].sh_offset;
	for (pname = excluded_sections; *pname != NULL; ++pname) {
		int index = find_section(sections, shnum, secnames, *pname, 1);
		if (index != 0)
			hashmap[index] = 0;
	}

	/* The symbol table (last one found) and its linked string table are
	 * hashed even though they are not allocated sections. */
	for (i = shnum - 1; i >= 1; --i) {
		if (sections[i].sh_type == SHT_SYMTAB) {
			hashmap[i] = 1;
			/* sh_link comes from the file: never index past hashmap. */
			if (sections[i].sh_link < (Elf32_Word)shnum)
				hashmap[sections[i].sh_link] = 1;
			break;
		}
	}

	SHA1Init(&ctx);
	for (i = 1; i < shnum; ++i) {
		if (hashmap[i])
			SHA1Update(&ctx, (const unsigned char *)elf_hdr + sections[i].sh_offset, sections[i].sh_size);
	}
	free(hashmap);
	SHA1Final((unsigned char *)hash, &ctx);
	return 0;
}

/* Check whether the instruction word at byte offset 0x144 of the module's
 * .text (pText, textsize bytes) is an unconditional ARM branch ("b",
 * condition AL, so the top byte is 0xea).
 *
 * If it is, store the branch displacement relative to that instruction in
 * *rel_offset and return 0x144. The 24-bit immediate is sign-extended,
 * multiplied by 4 and increased by 8 (an ARM "b" is taken relative to
 * pc, which reads as the instruction address + 8).
 * Otherwise return -1 and leave *rel_offset unchanged.
 *
 * The word index is offset >> 2, as with WORD(). This assumes 4-byte
 * words, i.e. a 32-bit build.
 */
int detect_allowed_path_insn(unsigned long *pText, size_t textsize, int *rel_offset)
{
	const int offset = 0x144;
	unsigned long insn;
	long imm24;

	if (textsize < offset + sizeof(unsigned long))
		return -1;
	insn = pText[offset >> 2];
	if ((insn >> 24) != 0xea)
		return -1;

	imm24 = (long)(insn & 0xffffff);
	if (imm24 & 0x800000)
		imm24 -= 0x1000000; /* sign-extend the 24-bit branch immediate */
	*rel_offset = (int)(imm24 * 4 + 8);
	return offset;
}

/* Scan the instructions after byte offset `off` in the text (pText, size
 * bytes) for a short function epilogue:
 *
 *     sub   sp, rX, #N      (word & 0xfff0f000) == 0xe240d000
 *     ldmia sp, {...}       (word & 0xffff0000) == 0xe89d0000
 *
 * The ldmia must appear within the max_delta instructions that follow the
 * sub. If it does not, the search starts over, looking for a new sub.
 *
 * On success the two instruction words are stored in r->istack and
 * r->iregs, and 1 is returned; otherwise 0 is returned. Both fields are
 * zeroed on entry, so they never hold indeterminate values. They are
 * always printed in hex as a debugging aid.
 * Only words lying entirely inside [0, size) are read.
 */
int find_epilogue(unsigned long *pText, int off, size_t size, ret_desc *r)
{
	enum { EPI_START, EPI_SP_DELTA_FOUND, EPI_DONE } state = EPI_START;
	const int max_delta = 2;
	int delta = 0;
	size_t pos;

	r->istack = 0;
	r->iregs = 0;

	if (off >= 0) {
		for (pos = (size_t)off + 4; pos + 4 <= size; pos += 4) {
			unsigned long insn;

			if (delta >= max_delta) {
				/* ldmia did not follow the sub closely enough: restart. */
				state = EPI_START;
				delta = 0;
			}
			if (state == EPI_DONE)
				break;
			delta++;

			insn = pText[WORD(pos)];
			switch (state) {
			case EPI_START: /* need to find the stack pointer adjustment */
				if ((insn & 0xfff0f000) == 0xe240d000) { /* sub sp, rX, N */
					r->istack = insn;
					state = EPI_SP_DELTA_FOUND;
					delta = 0;
				}
				break;
			case EPI_SP_DELTA_FOUND: /* need to find the register pop */
				if ((insn & 0xffff0000) == 0xe89d0000) { /* ldmia sp, { ... } */
					r->iregs = insn;
					state = EPI_DONE;
					delta = 0;
				}
				break;
			default:
				break;
			}
		}
	}
	printf("%08lx %08lx\n", r->istack, r->iregs);
	return (state == EPI_DONE);
}

/* Return the index of the first STT_OBJECT symbol in syms[0..count) whose
 * st_shndx equals sect_idx, or 0 if there is none. In a well-formed symbol
 * table, entry 0 is the null symbol, so 0 also serves as "not found". */
int find_section_symbol(Elf32_Sym *syms, int count, int sect_idx)
{
	int n = 0;

	while (n < count) {
		const Elf32_Sym *sym = &syms[n];

		if (ELF32_ST_TYPE(sym->st_info) == STT_OBJECT && sym->st_shndx == sect_idx)
			return n;
		++n;
	}
	return 0;
}

/* If file offset w is at or after `offset`, move it by `delta` and print
 * "fixed". Relies on variables named `offset` and `delta` in the calling
 * scope. */
#define FIXUP_OFF(w) do { if ((w) >= offset) { w += delta; printf("fixed\n"); }} while (0)
/* Pointer to the byte `off` bytes past `base` (GNU void-pointer arithmetic). */
#define ELF_RELATIVE(base, off) ((void*)((void*)(base) + (off)))

/* Shift every file offset in the ELF32 image at elf_hdr that is >= offset
 * by delta bytes: e_phoff, e_shoff, each program header's p_offset, and each
 * section header's sh_offset (section 0 is skipped).
 *
 * The header tables are located through the already-updated e_phoff and
 * e_shoff. This therefore expects the image bytes to have been moved
 * before the call. */
void fixup_offsets(Elf32_Ehdr *elf_hdr, Elf32_Off offset, int delta)
{
	Elf32_Phdr *phdrs;
	Elf32_Shdr *shdrs;
	int k;

	FIXUP_OFF(elf_hdr->e_phoff);
	FIXUP_OFF(elf_hdr->e_shoff);

	phdrs = (Elf32_Phdr *)ELF_RELATIVE(elf_hdr, elf_hdr->e_phoff);
	for (k = 0; k < elf_hdr->e_phnum; ++k)
		FIXUP_OFF(phdrs[k].p_offset);

	shdrs = (Elf32_Shdr *)ELF_RELATIVE(elf_hdr, elf_hdr->e_shoff);
	for (k = 1; k < elf_hdr->e_shnum; ++k)
		FIXUP_OFF(shdrs[k].sh_offset);
}

/* Write the hook stub into the .gnu.linkonce.this_module data at `text`.
 *
 * Word 3 (byte 0xC) is set to 0x75 (the name "u", per the original comment).
 * Around word index e = WORD(entry) the following are stored, in order:
 *   e-3  saddr
 *   e-2  mov   r0, #0
 *   e-1  bx    lr
 *   e+0  sub   r0, pc, #20       r0 = address of word e-3 (ARM pc reads as +8)
 *   e+1  ldmia r0, {r4, r5, r6}  r4 = saddr, r5/r6 = the two words above
 *   e+2  stmia r4, {r5, r6}      store "mov r0, #0; bx lr" at saddr
 *   e+3  mvn   r0, #1            r0 = ~1
 *   e+4  r.istack                epilogue words found by find_epilogue()
 *   e+5  r.iregs
 *
 * `size` is not used. The caller must ensure the section holds at least
 * WORD(entry) + 6 words and that entry >= 12. Assumes 4-byte unsigned long
 * (see WORD()).
 */
void patch_this_module(void *text, size_t size, unsigned int entry, unsigned long saddr, ret_desc r)
{
	unsigned long *words = (unsigned long *)text;
	const unsigned int e = WORD(entry);
	const unsigned long stub[] = {
		saddr,
		0xe3a00000, /* mov     r0, #0 */
		0xe12fff1e, /* bx      lr */
		0xe24f0014, /* sub     r0, pc, #20 */
		0xe8900070, /* ldmia   r0, {r4, r5, r6} */
		0xe8840060, /* stmia   r4, {r5, r6} */
		0xe3e00001, /* mvn     r0, #1 */
		r.istack,
		r.iregs,
	};
	unsigned int k;

	(void)size;
	words[WORD(0xC)] = 0x00000075; /* name: "u" */
	for (k = 0; k < sizeof(stub) / sizeof(stub[0]); ++k)
		words[e - 3 + k] = stub[k];
}

/* Print msg (must be a string literal; it is concatenated with "\n") to
 * stdout and return ret from the calling function. */
#define VERBOSE_RET(msg, ret) do { printf(msg "\n"); return ret; } while (0)

/* Build a patched copy of the module image at base (size bytes).
 *
 *  1. Require the word at .text+0x144 to be a "b" with displacement 0x28
 *     (detect_allowed_path_insn).
 *  2. Find a "sub sp, ..." / "ldmia sp, {...}" epilogue after it
 *     (find_epilogue).
 *  3. Insert one R_ARM_JUMP24 relocation for that branch, against the
 *     STT_OBJECT symbol of .gnu.linkonce.this_module, right after the
 *     contents of .rel.text. Every byte behind it moves by
 *     sizeof(Elf32_Rel), and the header offsets are fixed up
 *     (fixup_offsets).
 *  4. Write the stub into .gnu.linkonce.this_module (patch_this_module).
 *
 * Returns a malloc()ed image that the caller must free(), and stores its
 * length in *newsize. On failure it prints a reason and returns NULL.
 * Assumes the caller has checked that the section header table and
 * e_shstrndx lie inside the image.
 */
Elf32_Ehdr *try_motsecurity_patch(const Elf32_Ehdr *base, size_t size, unsigned long saddr, size_t *newsize)
{
	Elf32_Ehdr *new_base;
	int rel_offset;
	int text_idx, reltext_idx, this_m_idx, symtab_idx, symindex;
	const char *secnames;
	const Elf32_Shdr *sections, *reltext, *new_sections, *this_m;
	Elf32_Rel rel;
	unsigned long *pText;
	size_t textsize, split, new_size;
	int insn_off;
	ret_desc r;

	sections = (const Elf32_Shdr *)((const char *)base + base->e_shoff);
	secnames = (const char *)base + sections[base->e_shstrndx].sh_offset;

	text_idx = find_section(sections, base->e_shnum, secnames, ".text", 0);
	if (text_idx == 0)
		VERBOSE_RET("Failed to find .text section", NULL);

	pText = (unsigned long *)((const char *)base + sections[text_idx].sh_offset);
	textsize = sections[text_idx].sh_size;

	if ((insn_off = detect_allowed_path_insn(pText, textsize, &rel_offset)) == -1)
		VERBOSE_RET("Can't find insn", NULL);
	if (rel_offset != 0x28)
		VERBOSE_RET("I won't patch this because of offset..", NULL);

	printf("Seems to be useful module...\n");
	printf("Now i'll try to find function epilogue\n");

	if (find_epilogue(pText, insn_off, textsize, &r) == 0)
		VERBOSE_RET("Can't find epilogue", NULL);

	printf("It seems to be ok, let's go\n");
	symtab_idx = find_section(sections, base->e_shnum, secnames, ".symtab", 0);
	reltext_idx = find_section(sections, base->e_shnum, secnames, ".rel.text", 0);
	this_m_idx = find_section(sections, base->e_shnum, secnames, ".gnu.linkonce.this_module", 0);
	if (symtab_idx == 0)
		VERBOSE_RET("No .symtab section", NULL);
	if (reltext_idx == 0)
		VERBOSE_RET("No .rel.text section", NULL);
	if (this_m_idx == 0)
		VERBOSE_RET("No .gnu.linkonce.this_module section", NULL);

	symindex = find_section_symbol((Elf32_Sym *)((const char *)base + sections[symtab_idx].sh_offset),
	                               sections[symtab_idx].sh_size / sizeof(Elf32_Sym), this_m_idx);
	if (symindex == 0)
		VERBOSE_RET("Can't find symbol for section .gnu.linkonce.this_module", NULL);

	/* split = end of .rel.text contents = insertion point of the new entry. */
	reltext = sections + reltext_idx;
	split = (size_t)reltext->sh_offset + reltext->sh_size;
	if (split > size)
		VERBOSE_RET(".rel.text extends past the end of the image", NULL);
	new_size = size + sizeof(Elf32_Rel);

	new_base = malloc(new_size); /* room for one additional relocation entry */
	if (new_base == NULL)
		VERBOSE_RET("Can't alloc image", NULL);

	/* [0, split) unchanged | new Elf32_Rel | [split, size) shifted. The entry
	 * is copied with memcpy, so the store works even if split is unaligned. */
	memcpy(new_base, base, split);
	rel.r_offset = insn_off;
	rel.r_info = ELF32_R_INFO(symindex, R_ARM_JUMP24);
	memcpy((char *)new_base + split, &rel, sizeof(rel));
	memcpy((char *)new_base + split + sizeof(rel), (const char *)base + split, size - split);
	/* NOTE(review): fixup_offsets() moves every sh_offset >= reltext->sh_offset,
	 * including .rel.text's own, and keeps its sh_size. .rel.text therefore
	 * now ends with the new entry and no longer covers its original first
	 * entry. Presumably intentional — confirm. */
	fixup_offsets(new_base, reltext->sh_offset, sizeof(Elf32_Rel));

	/* Locate this_module through the NEW headers: if its data lies after
	 * .rel.text, it has moved by sizeof(Elf32_Rel). */
	new_sections = (const Elf32_Shdr *)((const char *)new_base + new_base->e_shoff);
	this_m = new_sections + this_m_idx;
	/* patch_this_module() writes words up to WORD(rel_offset) + 5. */
	if ((size_t)this_m->sh_offset + this_m->sh_size > new_size ||
	    this_m->sh_size < (WORD((size_t)rel_offset) + 6) * sizeof(unsigned long)) {
		free(new_base);
		VERBOSE_RET("Section .gnu.linkonce.this_module is too small to patch", NULL);
	}
	patch_this_module((char *)new_base + this_m->sh_offset, this_m->sh_size, rel_offset, saddr, r);
	*newsize = new_size;
	return new_base;
}

/* Map an errno value set by a failed init_module call to a readable
 * message. A few codes get module-specific wording; all others fall back
 * to strerror(). */
const char *moderror(int err)
{
	static const struct {
		int code;
		const char *text;
	} known[] = {
		{ ENOEXEC, "Invalid module format" },
		{ ENOENT,  "Unknown symbol in module" },
		{ ESRCH,   "Module has wrong symbol version" },
		{ EINVAL,  "Invalid parameters" },
	};
	size_t k;

	for (k = 0; k < sizeof(known) / sizeof(known[0]); ++k) {
		if (known[k].code == err)
			return known[k].text;
	}
	return strerror(err);
}

/* Usage: <prog> <elf file> <addr>
 *
 * Maps the module read-only and prints its hash. It then builds a patched
 * copy with try_motsecurity_patch(), passing <addr> (hexadecimal) as saddr,
 * and loads that copy with init_module only if it hashes identically to the
 * original.
 *
 * Exit status:
 *   0  module inserted
 *   1  usage error or invalid address
 *   2  open failed
 *   3  empty or unseekable file
 *   4  mmap failed
 *   5  init_module failed
 *   6  hash mismatch
 *   7  patching failed
 *   8  not a usable ELF32 file
 *   9  hashing failed
 */
int main(int argc, char **argv)
{
	int ret = 0, fd = -1;
	size_t i, size = 0, newsize = 0;
	char old_hash[SHA_DIGEST_LENGTH], new_hash[sizeof(old_hash)];
	char *kmod_map = NULL;
	void *map;
	char *endp;
	const Elf32_Ehdr *elf_hdr;
	Elf32_Ehdr *new_kmod = NULL;
	unsigned long saddr;

	if (argc < 3)
	{
		printf("Usage: %s <elf file> <addr>\n", argv[0]);
		return 1;
	}
	errno = 0;
	saddr = strtoul(argv[2], &endp, 16);
	if (endp == argv[2] || *endp != '\0' || errno == ERANGE) {
		fprintf(stderr, "Invalid address: %s\n", argv[2]);
		return 1;
	}
	printf("Address: %08lx\n", saddr);
	if ( (fd = open(argv[1], O_RDONLY)) < 0 )
		ERR_RETURN(2);
	if ( (size = filesize(fd)) == 0 )
		ERR_RETURN(3);
	/* Keep kmod_map NULL on failure so cleanup never munmap()s MAP_FAILED. */
	map = mmap(NULL, size, PROT_READ, MAP_PRIVATE, fd, (off_t)0);
	if (map == MAP_FAILED)
		ERR_RETURN(4);
	kmod_map = map;

	/* The file is untrusted: validate everything later code indexes blindly
	 * (ELF header, section header table, e_shstrndx). */
	elf_hdr = (const Elf32_Ehdr *)kmod_map;
	if (size < sizeof(*elf_hdr) ||
	    memcmp(elf_hdr->e_ident, ELFMAG, SELFMAG) != 0 ||
	    elf_hdr->e_ident[EI_CLASS] != ELFCLASS32 ||
	    elf_hdr->e_shentsize != sizeof(Elf32_Shdr) ||
	    elf_hdr->e_shnum == 0 ||
	    elf_hdr->e_shstrndx >= elf_hdr->e_shnum ||
	    elf_hdr->e_shoff > size ||
	    (size - elf_hdr->e_shoff) / sizeof(Elf32_Shdr) < elf_hdr->e_shnum) {
		fprintf(stderr, "%s: not a usable 32-bit ELF module\n", argv[1]);
		ret = 8;
		goto out;
	}

	if (calculate_mot_hash(old_hash, elf_hdr) != 0) {
		fprintf(stderr, "Failed to hash module\n");
		ret = 9;
		goto out;
	}
	for (i = 0; i < sizeof(old_hash); ++i)
		printf("%02hhx", old_hash[i]);
	printf("\n");

	new_kmod = try_motsecurity_patch(elf_hdr, size, saddr, &newsize);
	if (new_kmod == NULL) {
		fprintf(stderr, "failed to apply motsecurity fixes\n");
		ret = 7;
	} else if (calculate_mot_hash(new_hash, new_kmod) != 0) {
		fprintf(stderr, "Failed to hash patched module\n");
		ret = 9;
	} else if (memcmp(new_hash, old_hash, sizeof(new_hash)) != 0) {
		printf("Hashes are not the same!\n");
		ret = 6;
	} else {
		printf("Hashes are the same, keep going\n");
		printf("Ok then? Trying to insert module\n");
		/* Uses the outer `ret`: a shadowing local here previously made the
		 * exit status 0 even when init_module failed. */
		if (syscall(__NR_init_module, new_kmod, newsize, "") != 0) {
			printf("Failed to insert module: %s\n", moderror(errno));
			ret = 5;
		} else {
			printf("Hang on, guys! Module is inserted.\n");
			ret = 0;
		}
	}
out:
	free(new_kmod);
	if (kmod_map != NULL)
		munmap(kmod_map, size);
	if (fd >= 0)
		close(fd);
	return ret;
}
