40 #include <type_traits> 
   46 using namespace stream_write_qualified_char_as_number;
 
   48 class ExplainProvenanceImpl : 
public ExplainProvenance {
 
   52     ExplainProvenanceImpl(SouffleProgram& prog) : ExplainProvenance(prog) {
 
   56     void setup()
 override {
 
   58         for (
auto& 
rel : prog.getAllRelations()) {
 
   59             std::string name = 
rel->getName();
 
   62             if (name.find(
"@info") == std::string::npos) {
 
   68                 std::vector<std::string> bodyLiterals;
 
   73                 for (
size_t i = 1; 
i + 1 < 
rel->getArity(); 
i++) {
 
   76                     bodyLiterals.push_back(bodyLit);
 
   82                 info.insert({std::make_pair(name.substr(0, name.find(
".@info")), ruleNum), bodyLiterals});
 
   83                 rules.insert({std::make_pair(name.substr(0, name.find(
".@info")), ruleNum), 
rule});
 
   89             std::string relName, std::vector<RamDomain> tuple, 
int ruleNum, 
int levelNum, 
size_t depthLimit) {
 
   90         std::stringstream joinedArgs;
 
   91         joinedArgs << 
join(decodeArguments(relName, tuple), 
", ");
 
   92         auto joinedArgsStr = joinedArgs.str();
 
   96             return mk<LeafNode>(relName + 
"(" + joinedArgsStr + 
")");
 
   99         assert(info.find(std::make_pair(relName, ruleNum)) != info.end() && 
"invalid rule for tuple");
 
  102         if (depthLimit <= 1) {
 
  103             tuple.push_back(ruleNum);
 
  104             tuple.push_back(levelNum);
 
  108             auto it = std::find(subproofs.begin(), subproofs.end(), 
tuple);
 
  109             if (it != subproofs.end()) {
 
  110                 idx = it - subproofs.begin();
 
  112                 subproofs.push_back(tuple);
 
  113                 idx = subproofs.size() - 1;
 
  116             return mk<LeafNode>(
"subproof " + relName + 
"(" + std::to_string(idx) + 
")");
 
  119         tuple.push_back(levelNum);
 
  122                 mk<InnerNode>(relName + 
"(" + joinedArgsStr + 
")", 
"(R" + std::to_string(ruleNum) + 
")");
 
  125         std::vector<RamDomain> ret;
 
  128         prog.executeSubroutine(relName + 
"_" + std::to_string(ruleNum) + 
"_subproof", tuple, ret);
 
  131         size_t tupleCurInd = 0;
 
  132         auto bodyRelations = info[std::make_pair(relName, ruleNum)];
 
  135         for (
auto it = bodyRelations.begin() + 1; it < bodyRelations.end(); it++) {
 
  136             std::string bodyLiteral = *it;
 
  138             std::string bodyRel = 
splitString(bodyLiteral, 
',')[0];
 
  141             assert(bodyRel.size() > 0 && 
"body of a relation should have positive length");
 
  142             bool isConstraint = 
contains(constraintList, bodyRel);
 
  145             auto bodyRelAtomName = bodyRel;
 
  146             if (bodyRel[0] == 
'!' && bodyRel != 
"!=") {
 
  147                 bodyRelAtomName = bodyRel.substr(1);
 
  152             size_t auxiliaryArity;
 
  159                 arity = prog.getRelation(bodyRelAtomName)->getArity();
 
  160                 auxiliaryArity = prog.getRelation(bodyRelAtomName)->getAuxiliaryArity();
 
  162             auto tupleEnd = tupleCurInd + arity;
 
  165             std::vector<RamDomain> subproofTuple;
 
  167             for (; tupleCurInd < tupleEnd - auxiliaryArity; tupleCurInd++) {
 
  168                 subproofTuple.push_back(ret[tupleCurInd]);
 
  171             int subproofRuleNum = ret[tupleCurInd];
 
  172             int subproofLevelNum = ret[tupleCurInd + 1];
 
  177             if (bodyRel[0] == 
'!' && bodyRel != 
"!=") {
 
  178                 std::stringstream joinedTuple;
 
  179                 joinedTuple << 
join(decodeArguments(bodyRelAtomName, subproofTuple), 
", ");
 
  180                 auto joinedTupleStr = joinedTuple.str();
 
  181                 internalNode->add_child(mk<LeafNode>(bodyRel + 
"(" + joinedTupleStr + 
")"));
 
  182                 internalNode->setSize(internalNode->getSize() + 1);
 
  184             } 
else if (isConstraint) {
 
  185                 std::stringstream joinedConstraint;
 
  190                     joinedConstraint << subproofTuple[0] << 
" " << bodyRel << 
" " << subproofTuple[1];
 
  192                     joinedConstraint << bodyRel << 
"(\"" << symTable.resolve(subproofTuple[0]) << 
"\", \"" 
  193                                      << symTable.resolve(subproofTuple[1]) << 
"\")";
 
  196                 internalNode->add_child(mk<LeafNode>(joinedConstraint.str()));
 
  197                 internalNode->setSize(internalNode->getSize() + 1);
 
  201                         explain(bodyRel, subproofTuple, subproofRuleNum, subproofLevelNum, depthLimit - 1);
 
  202                 internalNode->setSize(internalNode->getSize() + child->getSize());
 
  203                 internalNode->add_child(std::move(child));
 
  206             tupleCurInd = tupleEnd;
 
  212     Own<TreeNode> 
explain(std::string relName, std::vector<std::string> args, 
size_t depthLimit)
 override {
 
  213         auto tuple = argsToNums(relName, args);
 
  215             return mk<LeafNode>(
"Relation not found");
 
  218         std::tuple<int, int> tupleInfo = findTuple(relName, 
tuple);
 
  220         int ruleNum = std::get<0>(tupleInfo);
 
  221         int levelNum = std::get<1>(tupleInfo);
 
  223         if (ruleNum < 0 || levelNum == -1) {
 
  224             return mk<LeafNode>(
"Tuple not found");
 
  227         return explain(relName, 
tuple, ruleNum, levelNum, depthLimit);
 
  230     Own<TreeNode> explainSubproof(std::string relName, 
RamDomain subproofNum, 
size_t depthLimit)
 override {
 
  231         if (subproofNum >= (
int)subproofs.size()) {
 
  232             return mk<LeafNode>(
"Subproof not found");
 
  235         auto tup = subproofs[subproofNum];
 
  237         auto rel = prog.getRelation(relName);
 
  240         ruleNum = tup[
rel->getArity() - 
rel->getAuxiliaryArity()];
 
  243         levelNum = tup[
rel->getArity() - 
rel->getAuxiliaryArity() + 1];
 
  245         tup.erase(tup.begin() + 
rel->getArity() - 
rel->getAuxiliaryArity(), tup.end());
 
  247         return explain(relName, tup, ruleNum, levelNum, depthLimit);
 
  250     std::vector<std::string> explainNegationGetVariables(
 
  251             std::string relName, std::vector<std::string> args, 
size_t ruleNum)
 override {
 
  252         std::vector<std::string> variables;
 
  255         std::tuple<int, int> foundTuple = findTuple(relName, argsToNums(relName, args));
 
  256         if (std::get<0>(foundTuple) != -1 || std::get<1>(foundTuple) != -1) {
 
  258             return std::vector<std::string>({
"@"});
 
  262         auto atoms = info[std::make_pair(relName, ruleNum)];
 
  266         if (atoms.size() <= 1) {
 
  267             return std::vector<std::string>({
"@fact"});
 
  273         auto isVariable = [&](std::string arg) {
 
  274             return !(
isNumber(arg.c_str()) || arg[0] == 
'\"' || arg == 
"_");
 
  279         std::map<std::string, std::string> headVariableMapping;
 
  280         for (
size_t i = 0; 
i < headVariables.size(); 
i++) {
 
  281             if (!isVariable(headVariables[
i])) {
 
  285             if (headVariableMapping.find(headVariables[
i]) == headVariableMapping.end()) {
 
  286                 headVariableMapping[headVariables[
i]] = args[
i];
 
  288                 if (headVariableMapping[headVariables[
i]] != args[
i]) {
 
  289                     return std::vector<std::string>({
"@non_matching"});
 
  295         std::vector<std::string> uniqueBodyVariables;
 
  296         for (
auto it = atoms.begin() + 1; it < atoms.end(); it++) {
 
  301             for (
auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) {
 
  302                 if (!isVariable(*atomIt)) {
 
  306                 if (!
contains(uniqueBodyVariables, *atomIt) && !
contains(headVariables, *atomIt)) {
 
  307                     uniqueBodyVariables.push_back(*atomIt);
 
  312         return uniqueBodyVariables;
 
  315     Own<TreeNode> explainNegation(std::string relName, 
size_t ruleNum, 
const std::vector<std::string>& tuple,
 
  316             std::map<std::string, std::string>& bodyVariables)
 override {
 
  318         std::vector<std::string> uniqueVariables;
 
  321         std::map<std::string, char> variableTypes;
 
  324         auto atoms = info[std::make_pair(relName, ruleNum)];
 
  329         uniqueVariables.insert(uniqueVariables.end(), headVariables.begin(), headVariables.end());
 
  331         auto isVariable = [&](std::string arg) {
 
  332             return !(
isNumber(arg.c_str()) || arg[0] == 
'\"' || arg == 
"_");
 
  336         for (
auto it = atoms.begin() + 1; it < atoms.end(); it++) {
 
  341             for (
auto atomIt = atomRepresentation.begin() + 1; atomIt < atomRepresentation.end(); atomIt++) {
 
  342                 if (!
contains(uniqueVariables, *atomIt) && !
contains(headVariables, *atomIt)) {
 
  344                     if (!isVariable(*atomIt)) {
 
  348                     uniqueVariables.push_back(*atomIt);
 
  350                     if (!
contains(constraintList, atomRepresentation[0])) {
 
  352                         auto currentRel = prog.getRelation(atomRepresentation[0]);
 
  353                         assert(currentRel != 
nullptr &&
 
  354                                 (
"relation " + atomRepresentation[0] + 
" doesn't exist").c_str());
 
  355                         variableTypes[*atomIt] =
 
  356                                 *currentRel->getAttrType(atomIt - atomRepresentation.begin() - 1);
 
  357                     } 
else if (atomIt->find(
"agg_") != std::string::npos) {
 
  358                         variableTypes[*atomIt] = 
'i';
 
  364         std::vector<RamDomain> args;
 
  366         size_t varCounter = 0;
 
  372         auto tupleNums = argsToNums(relName, tuple);
 
  373         args.insert(args.end(), tupleNums.begin(), tupleNums.end());
 
  374         varCounter += tuple.size();
 
  376         while (varCounter < uniqueVariables.size()) {
 
  377             auto var = uniqueVariables[varCounter];
 
  378             auto varValue = bodyVariables[var];
 
  379             if (variableTypes[var] == 
's') {
 
  380                 if (varValue.size() >= 2 && varValue[0] == 
'"' && varValue[varValue.size() - 1] == 
'"') {
 
  381                     auto originalStr = varValue.substr(1, varValue.size() - 2);
 
  382                     args.push_back(symTable.lookup(originalStr));
 
  385                     args.push_back(symTable.lookup(varValue));
 
  388                 args.push_back(std::stoi(varValue));
 
  395         std::vector<RamDomain> ret;
 
  398         prog.executeSubroutine(relName + 
"_" + std::to_string(ruleNum) + 
"_negation_subproof", args, ret);
 
  401         assert(ret.size() == atoms.size() - 1);
 
  404         std::stringstream joinedArgsStr;
 
  405         joinedArgsStr << 
join(tuple, 
",");
 
  406         auto internalNode = mk<InnerNode>(
 
  407                 relName + 
"(" + joinedArgsStr.str() + 
")", 
"(R" + std::to_string(ruleNum) + 
")");
 
  410         for (
size_t i = 0; 
i < headVariables.size(); 
i++) {
 
  411             bodyVariables[headVariables[
i]] = tuple[
i];
 
  416         int literalCounter = 1;
 
  419             bool atomExists = 
true;
 
  420             if (returnCounter == 0) {
 
  425             auto atomRepresentation = 
splitString(atoms[literalCounter], 
',');
 
  426             std::string bodyRel = atomRepresentation[0];
 
  429             bool isConstraint = 
contains(constraintList, bodyRel);
 
  432             auto bodyRelAtomName = bodyRel;
 
  433             if (bodyRel[0] == 
'!' && bodyRel != 
"!=") {
 
  434                 bodyRelAtomName = bodyRel.substr(1);
 
  438             std::stringstream childLabel;
 
  441                 assert(atomRepresentation.size() == 3 && 
"not a binary constraint");
 
  443                 childLabel << bodyVariables[atomRepresentation[1]] << 
" " << bodyRel << 
" " 
  444                            << bodyVariables[atomRepresentation[2]];
 
  446                 childLabel << bodyRel << 
"(";
 
  447                 for (
size_t i = 1; 
i < atomRepresentation.size(); 
i++) {
 
  449                     if (!isVariable(atomRepresentation[
i])) {
 
  450                         childLabel << atomRepresentation[
i];
 
  452                         childLabel << bodyVariables[atomRepresentation[
i]];
 
  454                     if (
i < atomRepresentation.size() - 1) {
 
  463                 childLabel << 
" ✓";
 
  468             internalNode->add_child(mk<LeafNode>(childLabel.str()));
 
  469             internalNode->setSize(internalNode->getSize() + 1);
 
  477     std::string getRule(std::string relName, 
size_t ruleNum)
 override {
 
  478         auto key = make_pair(relName, ruleNum);
 
  480         auto rule = rules.find(key);
 
  481         if (
rule == rules.end()) {
 
  482             return "Rule not found";
 
  488     std::vector<std::string> getRules(
const std::string& relName)
 override {
 
  489         std::vector<std::string> relRules;
 
  491         for (
auto& 
rule : rules) {
 
  492             if (
rule.first.first == relName) {
 
  493                 relRules.push_back(
rule.second);
 
  500     std::string measureRelation(std::string relName)
 override {
 
  501         auto rel = prog.getRelation(relName);
 
  503         if (
rel == 
nullptr) {
 
  504             return "No relation found\n";
 
  508         int skip = 
size / 10;
 
  514         std::stringstream 
ss;
 
  523             if (numTuples % skip != 0) {
 
  528             std::vector<RamDomain> currentTuple;
 
  529             for (arity_type 
i = 0; 
i < 
rel->getPrimaryArity(); 
i++) {
 
  531                 if (*
rel->getAttrType(
i) == 
's') {
 
  534                     n = symTable.lookupExisting(s);
 
  535                 } 
else if (*
rel->getAttrType(
i) == 
'f') {
 
  539                 } 
else if (*
rel->getAttrType(
i) == 
'u') {
 
  547                 currentTuple.push_back(
n);
 
  556             std::cout << 
"Tuples expanded: " 
  557                       << 
explain(relName, currentTuple, ruleNum, levelNum, 10000)->getSize();
 
  564                     std::chrono::duration_cast<std::chrono::duration<double>>(tupleEnd - tupleStart);
 
  566             std::cout << 
", Time: " << tupleDuration.count() << 
"\n";
 
  570         auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(after_time - before_time);
 
  572         ss << 
"total: " << proc << 
" ";
 
  578     void printRulesJSON(std::ostream& os)
 override {
 
  579         os << 
"\"rules\": [\n";
 
  581         for (
auto const& cur : rules) {
 
  587             os << 
"\t{ \"rule-number\": \"(R" << cur.first.second << 
")\", \"rule\": \"" 
  593     void queryProcess(
const std::vector<std::pair<std::string, std::vector<std::string>>>& rels)
 override {
 
  594         std::regex varRegex(
"[a-zA-Z_][a-zA-Z_0-9]*", std::regex_constants::extended);
 
  595         std::regex symbolRegex(
"\"([^\"]*)\"", std::regex_constants::extended);
 
  596         std::regex numberRegex(
"[0-9]+", std::regex_constants::extended);
 
  598         std::smatch argsMatcher;
 
  601         std::map<std::string, Equivalence> nameToEquivalence;
 
  607         std::vector<Relation*> varRels;
 
  613         for (
const auto& 
rel : rels) {
 
  616             std::vector<RamDomain> constTuple;
 
  619                 std::cout << 
"Relation <" << 
rel.first << 
"> does not exist" << std::endl;
 
  623             if (
relation->getPrimaryArity() != 
rel.second.size()) {
 
  624                 std::cout << 
"<" + 
rel.first << 
"> has arity of " << 
relation->getPrimaryArity() << std::endl;
 
  629             bool containVar = 
false;
 
  630             for (
size_t j = 0; 
j < 
rel.second.size(); ++
j) {
 
  632                 if (std::regex_match(
rel.second[
j], argsMatcher, varRegex)) {
 
  634                     auto nameToEquivalenceIter = nameToEquivalence.find(argsMatcher[0]);
 
  637                     if (nameToEquivalenceIter == nameToEquivalence.end()) {
 
  638                         nameToEquivalence.insert(
 
  640                                                          std::make_pair(idx, 
j))});
 
  642                         nameToEquivalenceIter->second.push_back(std::make_pair(idx, 
j));
 
  650                         if (!std::regex_match(
rel.second[
j], argsMatcher, symbolRegex)) {
 
  651                             std::cout << argsMatcher.str(0) << 
" does not match type defined in relation" 
  655                         rd = prog.getSymbolTable().lookup(argsMatcher[1]);
 
  659                             std::cout << 
rel.second[
j] << 
" does not match type defined in relation" 
  667                             std::cout << 
rel.second[
j] << 
" does not match type defined in relation" 
  675                             std::cout << 
rel.second[
j] << 
" does not match type defined in relation" 
  684                 constConstraints.
push_back(std::make_pair(std::make_pair(idx, 
j), rd));
 
  686                     constTuple.push_back(rd);
 
  692                 bool tupleExist = containsTuple(
relation, constTuple);
 
  702                     std::cout << 
"false." << std::endl;
 
  703                     std::cout << 
"Tuple " << 
rel.first << 
"(";
 
  704                     for (
size_t l = 0; 
l < 
rel.second.size() - 1; ++
l) {
 
  705                         std::cout << 
rel.second[
l] << 
", ";
 
  707                     std::cout << 
rel.second.back() << 
") does not exist" << std::endl;
 
  718         if (varRels.size() == 0) {
 
  719             std::cout << 
"true." << std::endl;
 
  724         findQuerySolution(varRels, nameToEquivalence, constConstraints);
 
  728     std::map<std::pair<std::string, size_t>, std::vector<std::string>> info;
 
  729     std::map<std::pair<std::string, size_t>, std::string> rules;
 
  730     std::vector<std::vector<RamDomain>> subproofs;
 
  731     std::vector<std::string> constraintList = {
 
  732             "=", 
"!=", 
"<", 
"<=", 
">=", 
">", 
"match", 
"contains", 
"not_match", 
"not_contains"};
 
  734     std::tuple<int, int> findTuple(
const std::string& relName, std::vector<RamDomain> tup) {
 
  735         auto rel = prog.getRelation(relName);
 
  738             return std::make_tuple(-1, -1);
 
  744             std::vector<RamDomain> currentTuple;
 
  748                 if (*
rel->getAttrType(
i) == 
's') {
 
  751                     n = symTable.lookupExisting(s);
 
  752                 } 
else if (*
rel->getAttrType(
i) == 
'f') {
 
  756                 } 
else if (*
rel->getAttrType(
i) == 
'u') {
 
  764                 currentTuple.push_back(
n);
 
  779                 return std::make_tuple(ruleNum, levelNum);
 
  784         return std::make_tuple(-1, -1);
 
  794     void findQuerySolution(
const std::vector<Relation*>& varRels,
 
  795             const std::map<std::string, Equivalence>& nameToEquivalence,
 
  796             const ConstConstraint& constConstraints) {
 
  798         std::vector<Relation::iterator> varRelationIterators;
 
  803         size_t solutionCount = 0;
 
  804         std::stringstream solution;
 
  808             bool isSolution = 
true;
 
  811             std::vector<tuple> element;
 
  812             for (
auto it : varRelationIterators) {
 
  813                 element.push_back(*it);
 
  816             for (
auto var : nameToEquivalence) {
 
  817                 if (!var.second.verify(element)) {
 
  824                 isSolution = constConstraints.verify(element);
 
  828                 std::cout << solution.str();  
 
  829                 solution.str(std::string());  
 
  832                 for (
auto&& var : nameToEquivalence) {
 
  833                     auto idx = var.second.getFirstIdx();
 
  834                     auto raw = element[idx.first][idx.second];
 
  836                     solution << var.second.getSymbol() << 
" = ";
 
  837                     switch (var.second.getType()) {
 
  838                         case 'i': solution << ramBitCast<RamSigned>(raw); 
break;
 
  839                         case 'f': solution << ramBitCast<RamFloat>(raw); 
break;
 
  840                         case 'u': solution << ramBitCast<RamUnsigned>(raw); 
break;
 
  841                         case 's': solution << prog.getSymbolTable().resolve(raw); 
break;
 
  842                         default: 
fatal(
"invalid type: `%c`", var.second.getType());
 
  845                     auto sep = ++c < nameToEquivalence.size() ? 
", " : 
" ";
 
  851                 if (1 < solutionCount) {
 
  852                     for (std::string input; getline(std::cin, input);) {
 
  853                         if (input == 
";") 
break;   
 
  854                         if (input == 
".") 
return;  
 
  856                         std::cout << 
"use ; to find next solution, use . to break from current query\n";
 
  862             size_t i = varRels.size() - 1;
 
  863             bool terminate = 
true;
 
  864             for (
auto it = varRelationIterators.rbegin(); it != varRelationIterators.rend(); ++it) {
 
  865                 if ((++(*it)) != varRels[
i]->end()) {
 
  869                     (*it) = varRels[
i]->begin();
 
  876                 if (solutionCount == 0) {
 
  877                     std::cout << 
"false." << std::endl;
 
  880                     std::cout << solution.str() << 
"." << std::endl;
 
  888     bool containsTuple(Relation* 
relation, 
const std::vector<RamDomain>& constTuple) {
 
  889         bool tupleExist = 
false;
 
  892             for (
size_t j = 0; 
j < constTuple.size(); ++
j) {
 
  893                 if (constTuple[
j] != (*it)[
j]) {