/*
 * Decompiled with CFR 0.152.
 */
package org.sonar.python.checks.hotspots;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.sonar.check.Rule;
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionCheck;
import org.sonar.plugins.python.api.SubscriptionContext;
import org.sonar.plugins.python.api.symbols.ClassSymbol;
import org.sonar.plugins.python.api.symbols.Symbol;
import org.sonar.plugins.python.api.symbols.Usage;
import org.sonar.plugins.python.api.tree.AssignmentStatement;
import org.sonar.plugins.python.api.tree.CallExpression;
import org.sonar.plugins.python.api.tree.ClassDef;
import org.sonar.plugins.python.api.tree.Decorator;
import org.sonar.plugins.python.api.tree.DictionaryLiteral;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.HasSymbol;
import org.sonar.plugins.python.api.tree.KeyValuePair;
import org.sonar.plugins.python.api.tree.ListLiteral;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.plugins.python.api.tree.RegularArgument;
import org.sonar.plugins.python.api.tree.StringLiteral;
import org.sonar.plugins.python.api.tree.SubscriptionExpression;
import org.sonar.plugins.python.api.tree.Tree;
import org.sonar.python.checks.Expressions;
import org.sonar.python.tree.TreeUtils;

/**
 * Rule S4502: reports code that appears to disable CSRF protection.
 *
 * <p>Covered patterns: a Django {@code settings.py} MIDDLEWARE list without
 * {@code CsrfViewMiddleware}, {@code csrf_exempt}-style decorators and calls, falsy
 * {@code WTF_CSRF_*} configuration entries, Flask-WTF forms configured with CSRF turned off,
 * and {@code flask.Flask} apps never passed to a CSRF-protection call in the same file.
 */
@Rule(key="S4502")
public class CsrfDisabledCheck
extends PythonSubscriptionCheck {
    private static final String MESSAGE = "Make sure disabling CSRF protection is safe here.";
    private static final String CSRF_VIEW_MIDDLEWARE = "django.middleware.csrf.CsrfViewMiddleware";
    private static final String FLASK_FORM_FQN = "flask_wtf.FlaskForm";
    /** Fully qualified names of decorators/functions that exempt a view from CSRF checks. */
    private static final Set<String> DANGEROUS_DECORATORS = new HashSet<>(Arrays.asList("django.views.decorators.csrf.csrf_exempt", "flask_wtf.csrf.CSRFProtect.exempt"));
    /** Fully qualified names of callees that, when given the Flask app, enable CSRF protection. */
    private static final Set<String> CSRF_PROTECT_FQNS = new HashSet<>(Arrays.asList("flask_wtf.csrf.CSRFProtect", "flask_wtf.csrf.CSRFProtect.init_app", "flask_wtf.CSRFProtect", "flask_wtf.CSRFProtect.init_app"));
    /**
     * Fallback used when the callee symbol cannot be resolved: one regex per dotted name
     * component, e.g. matching a callee written as {@code csrf.init_app}.
     */
    private static final List<Pattern> CSRF_INIT_APP_CALLEE_PATTERNS = Arrays.asList(Pattern.compile("(csrf|CSRF)"), Pattern.compile("init_app"));

    @Override
    public void initialize(SubscriptionCheck.Context context) {
        context.registerSyntaxNodeConsumer(Tree.Kind.ASSIGNMENT_STMT, CsrfDisabledCheck::djangoMiddlewareArrayCheck);
        context.registerSyntaxNodeConsumer(Tree.Kind.DECORATOR, CsrfDisabledCheck::decoratorCsrfExemptCheck);
        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, CsrfDisabledCheck::functionCsrfExemptCheck);
        context.registerSyntaxNodeConsumer(Tree.Kind.ASSIGNMENT_STMT, CsrfDisabledCheck::flaskWtfCsrfEnabledFalseCheck);
        context.registerSyntaxNodeConsumer(Tree.Kind.CLASSDEF, CsrfDisabledCheck::metaCheck);
        context.registerSyntaxNodeConsumer(Tree.Kind.CALL_EXPR, CsrfDisabledCheck::formInstantiationCheck);
        context.registerSyntaxNodeConsumer(Tree.Kind.ASSIGNMENT_STMT, CsrfDisabledCheck::improperlyConfiguredFlaskApp);
    }

    /**
     * In a file named {@code settings.py}, flags an assignment to {@code MIDDLEWARE} whose list
     * literal contains at least one string starting with "django" but no
     * {@link #CSRF_VIEW_MIDDLEWARE} entry.
     */
    private static void djangoMiddlewareArrayCheck(SubscriptionContext subscriptionContext) {
        if (!"settings.py".equals(subscriptionContext.pythonFile().fileName())) {
            return;
        }
        AssignmentStatement asgn = (AssignmentStatement) subscriptionContext.syntaxNode();
        Expression assignedValue = asgn.assignedValue();
        // The "django" prefix test is what identifies this list as a Django middleware list.
        boolean isMiddlewareAssignment = isLhsCalled("MIDDLEWARE").test(asgn)
            && isListAnyMatch(isStringSatisfying(s -> s.startsWith("django"))).test(assignedValue);
        if (!isMiddlewareAssignment) {
            return;
        }
        boolean containsCsrfViewMiddleware = isListAnyMatch(isStringSatisfying(CSRF_VIEW_MIDDLEWARE::equals)).test(assignedValue);
        if (!containsCsrfViewMiddleware) {
            subscriptionContext.addIssue(asgn.lastToken(), MESSAGE);
        }
    }

    /** Returns a predicate that holds when any left-hand-side target is a plain name equal to {@code lhsName}. */
    private static Predicate<AssignmentStatement> isLhsCalled(String lhsName) {
        return asgn -> asgn.lhsExpressions().stream()
            .flatMap(exprList -> exprList.expressions().stream())
            .anyMatch(expr -> expr.is(Tree.Kind.NAME) && lhsName.equals(((Name) expr).name()));
    }

    /** Returns a predicate that holds for string literals whose unquoted value satisfies {@code pred}. */
    private static Predicate<Expression> isStringSatisfying(Predicate<String> pred) {
        return expr -> expr.is(Tree.Kind.STRING_LITERAL) && pred.test(((StringLiteral) expr).trimmedQuotesValue());
    }

    /**
     * Returns a predicate that holds for a non-null list literal with at least one element
     * satisfying {@code pred}.
     */
    private static Predicate<Expression> isListAnyMatch(Predicate<Expression> pred) {
        return expr -> expr != null
            && expr.is(Tree.Kind.LIST_LITERAL)
            && ((ListLiteral) expr).elements().expressions().stream().anyMatch(pred);
    }

    /**
     * Flags a decorator whose dotted name has a component containing "csrf" and a component
     * containing "exempt" (case-insensitive; both may be the same component).
     */
    private static void decoratorCsrfExemptCheck(SubscriptionContext subscriptionContext) {
        Decorator decorator = (Decorator) subscriptionContext.syntaxNode();
        // Typed list: the previous raw List made the lambdas below see Object and fail to compile.
        List<String> names = decorator.name().names().stream().map(Name::name).collect(Collectors.toList());
        boolean mentionsCsrf = names.stream().anyMatch(s -> s.toLowerCase(Locale.US).contains("csrf"));
        boolean mentionsExempt = names.stream().anyMatch(s -> s.toLowerCase(Locale.US).contains("exempt"));
        if (mentionsCsrf && mentionsExempt) {
            subscriptionContext.addIssue(decorator.lastToken(), MESSAGE);
        }
    }

    /**
     * Flags a call whose callee resolves to one of {@link #DANGEROUS_DECORATORS}, such as
     * {@code csrf_exempt(view)}.
     */
    private static void functionCsrfExemptCheck(SubscriptionContext subscriptionContext) {
        CallExpression callExpr = (CallExpression) subscriptionContext.syntaxNode();
        Optional.ofNullable(callExpr.calleeSymbol())
            .map(Symbol::fullyQualifiedName)
            .filter(DANGEROUS_DECORATORS::contains)
            .ifPresent(fqn -> subscriptionContext.addIssue(callExpr.callee().lastToken(), MESSAGE));
    }

    /**
     * Flags {@code x["WTF_CSRF_ENABLED"] = <falsy>} or {@code x["WTF_CSRF_CHECK_DEFAULT"] = <falsy>},
     * where falsiness is decided by {@code Expressions.isFalsy}.
     */
    private static void flaskWtfCsrfEnabledFalseCheck(SubscriptionContext subscriptionContext) {
        AssignmentStatement asgn = (AssignmentStatement) subscriptionContext.syntaxNode();
        Predicate<Expression> isCsrfConfigKey = isStringSatisfying(s -> "WTF_CSRF_ENABLED".equals(s) || "WTF_CSRF_CHECK_DEFAULT".equals(s));
        boolean isWtfCsrfEnabledSubscription = asgn.lhsExpressions().stream()
            .flatMap(exprList -> exprList.expressions().stream())
            .filter(expr -> expr.is(Tree.Kind.SUBSCRIPTION))
            .flatMap(s -> ((SubscriptionExpression) s).subscripts().expressions().stream())
            .anyMatch(isCsrfConfigKey);
        if (isWtfCsrfEnabledSubscription && Expressions.isFalsy(asgn.assignedValue())) {
            subscriptionContext.addIssue(asgn.assignedValue(), MESSAGE);
        }
    }

    /**
     * For a class named {@code Meta} whose nearest enclosing class can be or extend
     * {@code flask_wtf.FlaskForm}, flags every direct-body {@code csrf = <falsy>} assignment.
     */
    private static void metaCheck(SubscriptionContext subscriptionContext) {
        ClassDef classDef = (ClassDef) subscriptionContext.syntaxNode();
        if (!"Meta".equals(classDef.name().name())) {
            return;
        }
        boolean isWithinFlaskForm = Optional.ofNullable(TreeUtils.firstAncestorOfKind(classDef, Tree.Kind.CLASSDEF))
            .map(parentClassDef -> ((ClassDef) parentClassDef).name().symbol())
            .filter(s -> s.is(Symbol.Kind.CLASS))
            .map(ClassSymbol.class::cast)
            .filter(parentClassSymbol -> parentClassSymbol.canBeOrExtend(FLASK_FORM_FQN))
            .isPresent();
        if (!isWithinFlaskForm) {
            return;
        }
        classDef.body().statements().forEach(stmt -> {
            if (!stmt.is(Tree.Kind.ASSIGNMENT_STMT)) {
                return;
            }
            AssignmentStatement asgn = (AssignmentStatement) stmt;
            if (isLhsCalled("csrf").test(asgn) && Expressions.isFalsy(asgn.assignedValue())) {
                subscriptionContext.addIssue(asgn.assignedValue(), MESSAGE);
            }
        });
    }

    /**
     * For a call whose callee is a class that can be or extend {@code flask_wtf.FlaskForm},
     * inspects each regular argument for CSRF-disabling settings.
     */
    private static void formInstantiationCheck(SubscriptionContext subscriptionContext) {
        CallExpression callExpr = (CallExpression) subscriptionContext.syntaxNode();
        boolean isFlaskFormInstantiation = Optional.ofNullable(callExpr.calleeSymbol())
            .filter(s -> s.is(Symbol.Kind.CLASS))
            .map(ClassSymbol.class::cast)
            .filter(c -> c.canBeOrExtend(FLASK_FORM_FQN))
            .isPresent();
        if (!isFlaskFormInstantiation) {
            return;
        }
        callExpr.arguments().forEach(arg -> {
            if (arg instanceof RegularArgument) {
                searchForProblemsInFormInitializationArguments((RegularArgument) arg)
                    .ifPresent(badExpr -> subscriptionContext.addIssue(badExpr, MESSAGE));
            }
        });
    }

    /**
     * Returns the offending expression for {@code csrf_enabled=<falsy>} or for a
     * {@code meta={...}} dictionary literal whose effective {@code 'csrf'} value is falsy;
     * otherwise returns empty.
     */
    private static Optional<Expression> searchForProblemsInFormInitializationArguments(RegularArgument regArg) {
        String name = Optional.ofNullable(regArg.keywordArgument()).map(Name::name).orElse(null);
        if ("csrf_enabled".equals(name) && Expressions.isFalsy(regArg.expression())) {
            return Optional.of(regArg.expression());
        }
        if ("meta".equals(name)) {
            return Optional.ofNullable(regArg.expression())
                .filter(s -> s.is(Tree.Kind.DICTIONARY_LITERAL))
                .map(DictionaryLiteral.class::cast)
                .flatMap(CsrfDisabledCheck::searchForBadCsrfSettingInDictionary);
        }
        return Optional.empty();
    }

    /**
     * Returns the value of the effective {@code 'csrf'} entry of a dictionary literal when it is falsy.
     *
     * <p>Python keeps the last occurrence of a duplicated key in a dict display, so the last
     * {@code 'csrf'} key/value pair is the one that takes effect, not the first.
     */
    private static Optional<Expression> searchForBadCsrfSettingInDictionary(DictionaryLiteral dict) {
        return dict.elements().stream()
            .filter(e -> e.is(Tree.Kind.KEY_VALUE_PAIR))
            .map(KeyValuePair.class::cast)
            .filter(kvp -> kvp.key() != null && isStringSatisfying("csrf"::equals).test(kvp.key()))
            .reduce((earlier, later) -> later)
            .map(KeyValuePair::value)
            .filter(value -> Expressions.isFalsy(value));
    }

    /**
     * Flags {@code <name> = flask.Flask(...)} when no usage of that name's symbol sits inside a
     * CSRF-enabling call (see {@link #isWithinCsrfEnablingStatement}). Only the first
     * left-hand-side target is considered.
     */
    private static void improperlyConfiguredFlaskApp(SubscriptionContext subscriptionContext) {
        AssignmentStatement asgn = (AssignmentStatement) subscriptionContext.syntaxNode();
        if (!isFlaskAppInstantiation(asgn.assignedValue())) {
            return;
        }
        boolean isCsrfEnabledInThisFile = asgn.lhsExpressions().stream()
            .flatMap(exprList -> exprList.expressions().stream())
            .findFirst()
            .filter(s -> s.is(Tree.Kind.NAME))
            .map(Name.class::cast)
            .map(HasSymbol::symbol)
            .map(Symbol::usages)
            .filter(usages -> usages.stream().anyMatch(CsrfDisabledCheck::isWithinCsrfEnablingStatement))
            .isPresent();
        if (!isCsrfEnabledInThisFile) {
            subscriptionContext.addIssue(asgn.assignedValue(), MESSAGE);
        }
    }

    /** Returns true if {@code expr} is a call whose callee resolves to {@code flask.Flask}. */
    private static boolean isFlaskAppInstantiation(Expression expr) {
        if (!expr.is(Tree.Kind.CALL_EXPR)) {
            return false;
        }
        Symbol calleeSymbol = ((CallExpression) expr).calleeSymbol();
        return calleeSymbol != null && "flask.Flask".equals(calleeSymbol.fullyQualifiedName());
    }

    /**
     * Splits a name or a chain of qualified expressions ({@code a.b.c}) into its components, in
     * source order; returns empty if any part of the chain is not a name.
     */
    private static Optional<ArrayList<String>> extractQualifiedNameComponents(Expression expr) {
        if (expr.is(Tree.Kind.NAME)) {
            ArrayList<String> res = new ArrayList<>();
            res.add(((Name) expr).name());
            return Optional.of(res);
        }
        if (expr.is(Tree.Kind.QUALIFIED_EXPR)) {
            QualifiedExpression qe = (QualifiedExpression) expr;
            return extractQualifiedNameComponents(qe.qualifier()).map(list -> {
                list.add(qe.name().name());
                return list;
            });
        }
        return Optional.empty();
    }

    /**
     * Returns true if {@code expr} is a dotted name with exactly as many components as
     * {@code patternsToMatch}, each fully matching the pattern at the same position.
     */
    private static boolean checkNestedQualifiedExpressions(List<Pattern> patternsToMatch, Expression expr) {
        Optional<ArrayList<String>> nameFragmentsOpt = extractQualifiedNameComponents(expr);
        if (!nameFragmentsOpt.isPresent()) {
            return false;
        }
        List<String> nameFragments = nameFragmentsOpt.get();
        if (nameFragments.size() != patternsToMatch.size()) {
            return false;
        }
        for (int i = 0; i < nameFragments.size(); i++) {
            if (!patternsToMatch.get(i).matcher(nameFragments.get(i)).matches()) {
                return false;
            }
        }
        return true;
    }

    /** Returns true if the usage is nested inside a call that enables CSRF protection. */
    private static boolean isWithinCsrfEnablingStatement(Usage u) {
        return isWithinCall(CSRF_PROTECT_FQNS, CSRF_INIT_APP_CALLEE_PATTERNS, u.tree());
    }

    /**
     * Returns true if the nearest enclosing call of {@code t} has a callee whose fully qualified
     * name is in {@code expectedCalleeFqns}, or, failing that, whose dotted callee text matches
     * {@code fallbackCalleeRegexes} component by component.
     */
    private static boolean isWithinCall(Set<String> expectedCalleeFqns, List<Pattern> fallbackCalleeRegexes, Tree t) {
        Tree callExprTree = TreeUtils.firstAncestorOfKind(t, Tree.Kind.CALL_EXPR);
        if (callExprTree == null) {
            return false;
        }
        CallExpression callExpr = (CallExpression) callExprTree;
        Symbol callExprSymb = callExpr.calleeSymbol();
        if (callExprSymb != null && expectedCalleeFqns.contains(callExprSymb.fullyQualifiedName())) {
            return true;
        }
        return checkNestedQualifiedExpressions(fallbackCalleeRegexes, callExpr.callee());
    }
}

